// File: human/src/handpose/handpose.js (65 lines, 2.3 KiB, JavaScript)
// Blame: 2020-10-12 01:22:43 +02:00
const tf = require('@tensorflow/tfjs');
// Blame: 2020-10-14 17:43:33 +02:00
const hand = require('./handdetector');
// Blame: 2020-10-12 01:22:43 +02:00
const keypoints = require('./keypoints');
const pipe = require('./pipeline');
/**
 * Thin wrapper around a hand-detection pipeline: converts the input into a
 * batched float tensor, runs the pipeline on it and reshapes each prediction
 * into a plain result object.
 */
class HandPose {
  /**
   * @param {object} pipeline - Object exposing `estimateHands(image, config)`.
   */
  constructor(pipeline) {
    this.pipeline = pipeline;
  }

  /**
   * Runs the pipeline on `input` and returns one entry per prediction.
   *
   * @param {*} input - A `tf.Tensor`, or anything accepted by `tf.browser.fromPixels`.
   * @param {object} config - Read for `skipFrames`, `minConfidence` and `maxHands`
   *   (copied onto the instance) and forwarded unchanged to the pipeline.
   * @returns {Promise<Array<{confidence: number, box: (number[]|number), landmarks: *, annotations: object}>>}
   *   `box` is `[topLeft[0], topLeft[1], bottomRight[0] - topLeft[0], bottomRight[1] - topLeft[1]]`,
   *   or `0` when the prediction carries no box. `annotations` maps every key of
   *   `keypoints.MESH_ANNOTATIONS` to the landmarks at the listed indices.
   *   The array is empty when the pipeline returns nothing or any entry is falsy.
   */
  async estimateHands(input, config) {
    this.skipFrames = config.skipFrames;
    this.detectionConfidence = config.minConfidence;
    this.maxHands = config.maxHands;

    // Build the batched image inside tidy so intermediates created here are
    // released; a local is used instead of reassigning the `input` parameter.
    const image = tf.tidy(() => {
      const source = input instanceof tf.Tensor ? input : tf.browser.fromPixels(input);
      return source.toFloat().expandDims(0);
    });

    let predictions;
    try {
      predictions = await this.pipeline.estimateHands(image, config);
    } finally {
      // Release the image even when the pipeline rejects; otherwise the tensor leaks.
      image.dispose();
    }

    const hands = [];
    if (!predictions) return hands;

    for (const prediction of predictions) {
      // NOTE(review): a single falsy entry discards every hand collected so far.
      // Kept as-is because callers may rely on it - confirm whether skipping is intended.
      if (!prediction) return [];

      const annotations = {};
      for (const [key, indices] of Object.entries(keypoints.MESH_ANNOTATIONS)) {
        annotations[key] = indices.map((index) => prediction.landmarks[index]);
      }

      let box = 0;
      if (prediction.box) {
        const { topLeft, bottomRight } = prediction.box;
        box = [topLeft[0], topLeft[1], bottomRight[0] - topLeft[0], bottomRight[1] - topLeft[1]];
      }

      hands.push({
        confidence: prediction.confidence || 0,
        box,
        landmarks: prediction.landmarks,
        annotations,
      });
    }
    return hands;
  }
}
// Expose the HandPose class on the CommonJS exports object.
exports.HandPose = HandPose;
/**
 * Loads and parses the anchors JSON document.
 *
 * When `tf.env().features.IS_NODE` is set, the file is read from disk after
 * removing the first `file://` occurrence from `url`; otherwise it is fetched
 * through `tf.util.fetch`.
 *
 * @param {string} url - Location of the anchors JSON.
 * @returns {Promise<*>} The parsed JSON content.
 * @throws Rejects when the file cannot be read/fetched or is not valid JSON.
 */
async function loadAnchors(url) {
  if (tf.env().features.IS_NODE) {
    // eslint-disable-next-line global-require
    const fs = require('fs');
    // Use the promise API: awaiting readFileSync does nothing and blocks the event loop.
    const data = await fs.promises.readFile(url.replace('file://', ''), 'utf8');
    return JSON.parse(data);
  }
  const response = await tf.util.fetch(url);
  return response.json();
}
/**
 * Builds a ready-to-use HandPose instance.
 *
 * Starts three loads concurrently - the anchors (`config.detector.anchors`),
 * the detector graph model (`config.detector.modelPath`) and the skeleton
 * graph model (`config.skeleton.modelPath`) - then wires them into a
 * HandDetector, a HandPipeline and finally a HandPose.
 *
 * @param {object} config - Configuration; also passed to HandDetector and HandPipeline.
 * @returns {Promise<HandPose>} The assembled estimator.
 */
async function load(config) {
  // A model path containing 'tfhub.dev' is loaded with the fromTFHub flag set.
  const fetchGraphModel = (modelPath) => tf.loadGraphModel(modelPath, { fromTFHub: modelPath.includes('tfhub.dev') });

  const pending = [
    loadAnchors(config.detector.anchors),
    fetchGraphModel(config.detector.modelPath),
    fetchGraphModel(config.skeleton.modelPath),
  ];
  const [anchorData, detectorModel, skeletonModel] = await Promise.all(pending);

  const handDetector = new hand.HandDetector(detectorModel, anchorData, config);
  return new HandPose(new pipe.HandPipeline(handDetector, skeletonModel, config));
}
// Expose the async load() factory on the CommonJS exports object.
exports.load = load;