// Source: human/src/hand/handpose.js (85 lines, 3.0 KiB, JavaScript)

/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
2020-11-04 20:59:30 +01:00
// https://storage.googleapis.com/tfjs-models/demos/handpose/index.html
2020-11-10 02:13:38 +01:00
import * as tf from '@tensorflow/tfjs/dist/tf.es2017.js';
import * as handdetector from './handdetector';
import * as pipeline from './handpipeline';
import * as anchors from './anchors';
2020-11-04 07:11:24 +01:00
/**
 * Named groups of landmark indices.
 *
 * Each value is a list of positions into a prediction's `landmarks` array;
 * `HandPose.estimateHands` uses these lists to pick the landmarks that belong
 * to each named part. The indices used here cover 0 through 20.
 */
const MESH_ANNOTATIONS = {
  // four consecutive indices per finger
  thumb: [1, 2, 3, 4],
  indexFinger: [5, 6, 7, 8],
  middleFinger: [9, 10, 11, 12],
  ringFinger: [13, 14, 15, 16],
  pinky: [17, 18, 19, 20],
  // single index for the base of the palm
  palmBase: [0],
};
2020-10-12 01:22:43 +02:00
/**
 * Wraps a hand pipeline and reshapes its raw predictions into plain
 * result objects (confidence, box, landmarks, named annotations).
 */
class HandPose {
  /**
   * @param {object} pipe - Pipeline object; must expose an async
   *   `estimateHands(input, config)` method.
   */
  constructor(pipe) {
    this.pipeline = pipe;
  }

  /**
   * @returns {object} The module-level MESH_ANNOTATIONS map itself (not a copy).
   */
  static getAnnotations() {
    return MESH_ANNOTATIONS;
  }

  /**
   * Runs the pipeline on `input` and converts every prediction into a result object.
   *
   * @param {*} input - Passed through unchanged to the pipeline.
   * @param {object} config - Passed through unchanged to the pipeline.
   * @returns {Promise<Array<{confidence: *, box: (number[]|number), landmarks: *, annotations: object}>>}
   *   One entry per prediction; an empty array when the pipeline returns a falsy value.
   */
  async estimateHands(input, config) {
    const predictions = await this.pipeline.estimateHands(input, config);
    if (!predictions) return [];

    const hands = [];
    for (const prediction of predictions) {
      // Group landmarks by name; stays empty when the prediction has no landmarks.
      const annotations = {};
      if (prediction.landmarks) {
        for (const [key, indices] of Object.entries(MESH_ANNOTATIONS)) {
          annotations[key] = indices.map((index) => prediction.landmarks[index]);
        }
      }

      // Convert the two-corner box into [topLeft[0], topLeft[1], extent0, extent1]
      // (presumably [x, y, width, height] - confirm against the pipeline).
      // The numeric 0 fallback for a missing box is kept for backward compatibility.
      const box = prediction.box
        ? [
          prediction.box.topLeft[0],
          prediction.box.topLeft[1],
          prediction.box.bottomRight[0] - prediction.box.topLeft[0],
          prediction.box.bottomRight[1] - prediction.box.topLeft[1],
        ]
        : 0;

      hands.push({
        confidence: prediction.confidence,
        box,
        landmarks: prediction.landmarks,
        annotations,
      });
    }
    return hands;
  }
}
// Expose the HandPose class on the CommonJS-style `exports` object.
// NOTE(review): this file also uses ES `import` statements; presumably the
// bundler reconciles the two module styles - confirm.
exports.HandPose = HandPose;
/**
 * Loads the detector and skeleton graph models in parallel and wires them
 * into a ready-to-use HandPose instance.
 *
 * @param {object} config
 * @param {{modelPath: string}} config.detector - Location of the hand detector model.
 * @param {{modelPath: string}} config.skeleton - Location of the hand skeleton model.
 * @param {*} config.inputSize - Forwarded to both HandDetector and HandPipeline.
 * @returns {Promise<HandPose>} Resolves once both models are loaded.
 */
async function load(config) {
  // Short label for log output: the text between the first '/' and the last '.'.
  // Falls back to the whole path when that pattern is absent, so logging can
  // never make an otherwise successful load reject.
  const modelLabel = (modelPath) => {
    const match = modelPath.match(/\/(.*)\./);
    return match ? match[1] : modelPath;
  };

  const detectorPath = config.detector.modelPath;
  const skeletonPath = config.skeleton.modelPath;

  // Paths that mention 'tfhub.dev' are loaded with the fromTFHub option enabled.
  const [handDetectorModel, handPoseModel] = await Promise.all([
    tf.loadGraphModel(detectorPath, { fromTFHub: detectorPath.includes('tfhub.dev') }),
    tf.loadGraphModel(skeletonPath, { fromTFHub: skeletonPath.includes('tfhub.dev') }),
  ]);

  const detector = new handdetector.HandDetector(handDetectorModel, config.inputSize, anchors.anchors);
  const pipe = new pipeline.HandPipeline(detector, handPoseModel, config.inputSize);
  const handpose = new HandPose(pipe);

  // eslint-disable-next-line no-console
  console.log(`Human: load model: ${modelLabel(detectorPath)}`);
  // eslint-disable-next-line no-console
  console.log(`Human: load model: ${modelLabel(skeletonPath)}`);
  return handpose;
}
// Expose the async model loader on the CommonJS-style `exports` object.
exports.load = load;