human/src/handpose/handpose.ts

68 lines
3.1 KiB
TypeScript
Raw Normal View History

2021-04-09 14:07:58 +02:00
import { log, join } from '../helpers';
2020-11-18 14:26:28 +01:00
import * as tf from '../../dist/tfjs.esm.js';
2020-11-10 02:13:38 +01:00
import * as handdetector from './handdetector';
2020-12-10 21:46:45 +01:00
import * as handpipeline from './handpipeline';
2021-05-22 18:33:19 +02:00
import { Hand } from '../result';
2020-11-04 07:11:24 +01:00
2021-04-25 22:56:10 +02:00
/**
 * Landmark indices that make up each named hand part.
 *
 * `predict` builds `annotations[part]` by picking `landmarks[index]` for every index listed here.
 * Key order is significant: it determines the key order of each produced `annotations` object.
 * Typed as a string-keyed record so `predict` can look parts up via `Object.keys(...)` without an
 * implicit-any index error.
 */
const meshAnnotations: Readonly<Record<string, readonly number[]>> = {
  thumb: [1, 2, 3, 4],
  indexFinger: [5, 6, 7, 8],
  middleFinger: [9, 10, 11, 12],
  ringFinger: [13, 14, 15, 16],
  pinky: [17, 18, 19, 20],
  palmBase: [0],
};
2020-10-12 01:22:43 +02:00
2021-04-25 22:56:10 +02:00
// Module-level state, populated lazily:
// - handDetectorModel / handPoseModel are assigned by `load()` (null when a model was not requested)
// - handPipeline is (re)built by `load()` from those models and consumed by `predict()`
let handDetectorModel, handPoseModel, handPipeline;
2020-10-12 01:22:43 +02:00
2021-05-22 18:33:19 +02:00
/**
 * Runs the hand pipeline on an input tensor and converts every raw prediction into a `Hand` result.
 *
 * `load()` must have been called first, since this reads the module-level `handPipeline`.
 *
 * @param input image tensor; `input.shape[2]` is used as the image width and `input.shape[1]` as the
 *   height (presumably NHWC `[1, height, width, channels]` — TODO confirm against caller)
 * @param config configuration forwarded unchanged to `handPipeline.estimateHands`
 * @returns one `Hand` per prediction (empty array when the pipeline yields nothing). A prediction
 *   without a `box` gets all-zero `box` and `boxRaw` instead of throwing.
 */
export async function predict(input, config): Promise<Hand[]> {
  const predictions = await handPipeline.estimateHands(input, config);
  if (!predictions) return [];

  const width = input.shape[2];
  const height = input.shape[1];
  const hands: Hand[] = [];

  for (let i = 0; i < predictions.length; i++) {
    const prediction = predictions[i];

    // group landmark points by named hand part
    const annotations = {};
    if (prediction.landmarks) {
      for (const key of Object.keys(meshAnnotations)) {
        annotations[key] = meshAnnotations[key].map((index) => prediction.landmarks[index]);
      }
    }

    // both boxes share the same guard so a missing box never dereferences undefined
    let box: [number, number, number, number] = [0, 0, 0, 0];
    let boxRaw: [number, number, number, number] = [0, 0, 0, 0];
    if (prediction.box) {
      const { topLeft, bottomRight } = prediction.box;
      // pixel box: [x, y, width, height] clamped to the image bounds
      const left = Math.max(0, topLeft[0]);
      const top = Math.max(0, topLeft[1]);
      box = [
        left,
        top,
        Math.min(width, bottomRight[0]) - left,
        Math.min(height, bottomRight[1]) - top,
      ];
      // normalized box: [x, y, width, height] divided by image size, intentionally not clamped
      boxRaw = [
        topLeft[0] / width,
        topLeft[1] / height,
        (bottomRight[0] - topLeft[0]) / width,
        (bottomRight[1] - topLeft[1]) / height,
      ];
    }

    hands.push({
      id: i,
      confidence: Math.round(100 * prediction.confidence) / 100, // two decimal places
      box,
      boxRaw,
      landmarks: prediction.landmarks,
      annotations,
    });
  }
  return hands;
}
2021-04-25 22:56:10 +02:00
/**
 * Loads the hand detector and hand skeleton graph models (or reuses the cached ones) and rebuilds
 * the module-level `handPipeline` from them.
 *
 * The detector is requested only when `config.hand.enabled` is truthy and the skeleton only when
 * `config.hand.landmarks` is truthy; a model that is not requested is stored as `null`. Both models
 * are (re)loaded whenever either cached model is missing.
 *
 * @param config configuration; reads `hand.enabled`, `hand.landmarks`, `hand.detector.modelPath`,
 *   `hand.skeleton.modelPath`, `modelBasePath` and `debug`
 * @returns `[handDetectorModel, handPoseModel]`; either entry may be `null`
 */
export async function load(config): Promise<[Object, Object]> {
  if (!handDetectorModel || !handPoseModel) {
    [handDetectorModel, handPoseModel] = await Promise.all([
      config.hand.enabled ? tf.loadGraphModel(join(config.modelBasePath, config.hand.detector.modelPath), { fromTFHub: config.hand.detector.modelPath.includes('tfhub.dev') }) : null,
      config.hand.landmarks ? tf.loadGraphModel(join(config.modelBasePath, config.hand.skeleton.modelPath), { fromTFHub: config.hand.skeleton.modelPath.includes('tfhub.dev') }) : null,
    ]);
    if (config.hand.enabled) {
      if (!handDetectorModel || !handDetectorModel.modelUrl) log('load model failed:', config.hand.detector.modelPath);
      else if (config.debug) log('load model:', handDetectorModel.modelUrl);
      // the skeleton model is only requested when landmarks are enabled; a deliberately
      // skipped model must not be reported as a load failure
      if (config.hand.landmarks) {
        if (!handPoseModel || !handPoseModel.modelUrl) log('load model failed:', config.hand.skeleton.modelPath);
        else if (config.debug) log('load model:', handPoseModel.modelUrl);
      }
    }
  } else if (config.debug) {
    log('cached model:', handDetectorModel.modelUrl);
    log('cached model:', handPoseModel.modelUrl);
  }
  // the pipeline is rebuilt on every call, including when cached models are reused
  const handDetector = new handdetector.HandDetector(handDetectorModel);
  handPipeline = new handpipeline.HandPipeline(handDetector, handPoseModel);
  return [handDetectorModel, handPoseModel];
}