human/src/handpose/handpose.ts

69 lines
2.5 KiB
TypeScript
Raw Normal View History

2020-11-04 20:59:30 +01:00
// https://storage.googleapis.com/tfjs-models/demos/handpose/index.html
2021-02-08 17:39:09 +01:00
import { log } from '../log';
2020-11-18 14:26:28 +01:00
import * as tf from '../../dist/tfjs.esm.js';
2020-11-10 02:13:38 +01:00
import * as handdetector from './handdetector';
2020-12-10 21:46:45 +01:00
import * as handpipeline from './handpipeline';
2020-11-10 02:13:38 +01:00
import * as anchors from './anchors';
2020-11-04 07:11:24 +01:00
/**
 * Maps each finger (plus the palm base) to the indices of its keypoints within a
 * hand prediction's `landmarks` array; used to build each hand's `annotations`.
 *
 * Typed as a string-keyed record so it can be indexed with keys obtained from
 * `Object.keys()` without implicit-any errors under `strict`.
 */
const MESH_ANNOTATIONS: Record<string, number[]> = {
  thumb: [1, 2, 3, 4],
  indexFinger: [5, 6, 7, 8],
  middleFinger: [9, 10, 11, 12],
  ringFinger: [13, 14, 15, 16],
  pinky: [17, 18, 19, 20],
  palmBase: [0],
};
2020-10-12 01:22:43 +02:00
2021-02-08 17:39:09 +01:00
/**
 * Wraps a hand pipeline and reshapes its raw predictions into the per-hand
 * result objects returned to callers.
 */
export class HandPose {
  handPipeline: any;

  constructor(handPipeline) {
    this.handPipeline = handPipeline;
  }

  /** Returns the finger/palm name -> landmark-index map used to build `annotations`. */
  static getAnnotations() {
    return MESH_ANNOTATIONS;
  }

  /**
   * Runs the hand pipeline on `input` and converts each prediction.
   *
   * @param input value forwarded to the pipeline; `input.shape[2]` and `input.shape[1]`
   *   are used as the width and height limits when clipping boxes
   *   (NOTE(review): assumes an NHWC tensor, i.e. [batch, height, width, channels] — confirm).
   * @param config configuration forwarded unchanged to the pipeline.
   * @returns one `{ confidence, box, landmarks, annotations }` object per prediction, or `[]`
   *   when the pipeline returns nothing. `box` is `[x, y, width, height]` clipped to the input,
   *   or `0` when the prediction carries no box. `annotations` is empty when the prediction
   *   carries no landmarks.
   */
  async estimateHands(input, config) {
    const predictions = await this.handPipeline.estimateHands(input, config);
    if (!predictions) return [];
    const hands = [];
    for (const prediction of predictions) {
      // Group landmark points by finger using the index map.
      const annotations: Record<string, unknown> = {};
      if (prediction.landmarks) {
        for (const key of Object.keys(MESH_ANNOTATIONS)) {
          annotations[key] = MESH_ANNOTATIONS[key].map((index) => prediction.landmarks[index]);
        }
      }
      // `0` (not an array) when there is no box — kept for compatibility with existing callers.
      let box: number[] | 0 = 0;
      if (prediction.box) {
        const { topLeft, bottomRight } = prediction.box;
        const x = Math.max(0, topLeft[0]);
        const y = Math.max(0, topLeft[1]);
        // Measure width/height from the clamped origin so the box ends exactly at the
        // clipped bottom-right corner instead of overshooting by the amount clipped off.
        const width = Math.min(input.shape[2], bottomRight[0]) - x;
        const height = Math.min(input.shape[1], bottomRight[1]) - y;
        box = [x, y, width, height];
      }
      hands.push({
        confidence: prediction.confidence,
        box,
        landmarks: prediction.landmarks,
        annotations,
      });
    }
    return hands;
  }
}
2021-02-08 17:39:09 +01:00
/**
 * Derives a short model name for log output: the text between the first `/` and the
 * last `.` of the path. Falls back to the full path when it has no such span, instead
 * of throwing on a failed match.
 */
function modelNameForLog(modelPath: string): string {
  const match = modelPath.match(/\/(.*)\./);
  return match ? match[1] : modelPath;
}

/**
 * Loads the hand detector model (when `config.hand.enabled`) and the hand skeleton model
 * (when `config.hand.landmarks`) in parallel, and wires them into a `HandPose`.
 *
 * A disabled model is passed on as `null`.
 *
 * @param config reads `hand.enabled`, `hand.landmarks`, `hand.inputSize`,
 *   `hand.detector.modelPath` and `hand.skeleton.modelPath`.
 * @returns the assembled `HandPose` instance.
 */
export async function load(config) {
  // Paths hosted on tfhub.dev need the `fromTFHub` loader option.
  const loadModel = (modelPath) => tf.loadGraphModel(modelPath, { fromTFHub: modelPath.includes('tfhub.dev') });
  const [handDetectorModel, handPoseModel] = await Promise.all([
    config.hand.enabled ? loadModel(config.hand.detector.modelPath) : null,
    config.hand.landmarks ? loadModel(config.hand.skeleton.modelPath) : null,
  ]);
  const handDetector = new handdetector.HandDetector(handDetectorModel, config.hand.inputSize, anchors.anchors);
  const handPipeline = new handpipeline.HandPipeline(handDetector, handPoseModel, config.hand.inputSize);
  const handPose = new HandPose(handPipeline);
  if (config.hand.enabled) log(`load model: ${modelNameForLog(config.hand.detector.modelPath)}`);
  if (config.hand.landmarks) log(`load model: ${modelNameForLog(config.hand.skeleton.modelPath)}`);
  return handPose;
}