2021-05-25 14:58:20 +02:00
|
|
|
/**
|
|
|
|
* HandPose module entry point
|
|
|
|
*/
|
|
|
|
|
2021-04-09 14:07:58 +02:00
|
|
|
import { log, join } from '../helpers';
|
2020-11-18 14:26:28 +01:00
|
|
|
import * as tf from '../../dist/tfjs.esm.js';
|
2020-11-10 02:13:38 +01:00
|
|
|
import * as handdetector from './handdetector';
|
2020-12-10 21:46:45 +01:00
|
|
|
import * as handpipeline from './handpipeline';
|
2021-05-22 18:33:19 +02:00
|
|
|
import { Hand } from '../result';
|
2021-05-23 03:47:59 +02:00
|
|
|
import { GraphModel } from '../tfjs/types';
|
2020-11-04 07:11:24 +01:00
|
|
|
|
2021-04-25 22:56:10 +02:00
|
|
|
/**
 * Named groups of hand landmark indices.
 *
 * Each entry lists positions in a prediction's `landmarks` array; together the
 * groups cover indices 0 through 20. `predict` uses these lists to build the
 * per-part `annotations` of every detected hand, in the key order given here.
 */
const meshAnnotations = {
  // four consecutive landmark indices per finger
  thumb: [1, 2, 3, 4],
  indexFinger: [5, 6, 7, 8],
  middleFinger: [9, 10, 11, 12],
  ringFinger: [13, 14, 15, 16],
  pinky: [17, 18, 19, 20],
  // single landmark at index 0
  palmBase: [0],
};
|
2020-10-12 01:22:43 +02:00
|
|
|
|
2021-05-23 03:47:59 +02:00
|
|
|
// Module-level state shared by `load` and `predict`.

/** Detector graph model cached by `load`; holds `null` when the detector was not requested. */
let handDetectorModel: GraphModel | null;

/** Skeleton (landmark) graph model cached by `load`; holds `null` when landmarks were not requested. */
let handPoseModel: GraphModel | null;

/** Pipeline built by `load` from the two models above; `predict` calls `estimateHands` on it. */
let handPipeline: handpipeline.HandPipeline;
|
2020-10-12 01:22:43 +02:00
|
|
|
|
2021-05-22 18:33:19 +02:00
|
|
|
/**
 * Runs the hand pipeline on `input` and converts every raw prediction into a `Hand`.
 *
 * For each prediction the result contains:
 * - `id`: index of the prediction in the pipeline output
 * - `confidence`: prediction confidence rounded to two decimals
 * - `box`: `[x, y, width, height]` in input coordinates
 * - `boxRaw`: the same box with x/width divided by `input.shape[2]` and
 *   y/height divided by `input.shape[1]`
 *   (presumably a [batch, height, width, channels] tensor -- confirm against callers)
 * - `landmarks`: the landmark points as returned by the pipeline
 * - `annotations`: landmark points grouped per key of `meshAnnotations`
 *   (left empty when the prediction has no landmarks)
 *
 * `load` must have been called first so that the shared pipeline exists.
 *
 * @param input image tensor handed to the pipeline; only its `shape` is read here
 * @param config configuration object forwarded to the pipeline
 * @returns detected hands, or an empty array when the pipeline returns nothing
 */
export async function predict(input, config): Promise<Hand[]> {
  const predictions = await handPipeline.estimateHands(input, config);
  if (!predictions) return [];
  const hands: Array<Hand> = [];
  for (let i = 0; i < predictions.length; i++) {
    const prediction = predictions[i];

    // group landmark points by hand part
    const annotations = {};
    if (prediction.landmarks) {
      for (const key of Object.keys(meshAnnotations)) {
        // @ts-ignore landmarks are not undefined
        annotations[key] = meshAnnotations[key].map((index) => prediction.landmarks[index]);
      }
    }

    const landmarks = prediction.landmarks as unknown as Array<[number, number, number]>;

    let box: [number, number, number, number] = [Number.MAX_SAFE_INTEGER, Number.MAX_SAFE_INTEGER, 0, 0]; // maximums so conditionals work
    let boxRaw: [number, number, number, number] = [0, 0, 0, 0];
    if (landmarks && landmarks.length > 0) { // if we have landmarks, calculate box based on landmarks
      for (const pt of landmarks) {
        if (pt[0] < box[0]) box[0] = pt[0];
        if (pt[1] < box[1]) box[1] = pt[1];
        if (pt[0] > box[2]) box[2] = pt[0];
        if (pt[1] > box[3]) box[3] = pt[1];
      }
      // box[2] and box[3] hold the max x/y so far; turn them into width/height
      box[2] -= box[0];
      box[3] -= box[1];
      boxRaw = [box[0] / input.shape[2], box[1] / input.shape[1], box[2] / input.shape[2], box[3] / input.shape[1]];
    } else if (prediction.box) { // otherwise use box from prediction
      const { topLeft, bottomRight } = prediction.box;
      // pixel box is clamped to the input bounds
      const left = Math.max(0, topLeft[0]);
      const top = Math.max(0, topLeft[1]);
      box = [
        left,
        top,
        Math.min(input.shape[2], bottomRight[0]) - left,
        Math.min(input.shape[1], bottomRight[1]) - top,
      ];
      // normalized box uses the unclamped corners
      boxRaw = [
        topLeft[0] / input.shape[2],
        topLeft[1] / input.shape[1],
        (bottomRight[0] - topLeft[0]) / input.shape[2],
        (bottomRight[1] - topLeft[1]) / input.shape[1],
      ];
    } else {
      // neither landmarks nor a detector box: report an empty box instead of
      // dereferencing the missing box (boxRaw keeps its zero initializer)
      box = [0, 0, 0, 0];
    }

    hands.push({ id: i, confidence: Math.round(100 * prediction.confidence) / 100, box, boxRaw, landmarks, annotations });
  }
  return hands;
}
|
2020-10-14 19:23:02 +02:00
|
|
|
|
2021-05-23 03:47:59 +02:00
|
|
|
/**
 * Loads the hand detector and hand skeleton graph models and (re)builds the
 * shared hand pipeline used by `predict`.
 *
 * Models are fetched only while at least one of the two cached models is
 * missing; both are then requested in parallel:
 * - the detector model when `config.hand.enabled` is set
 * - the skeleton model when `config.hand.landmarks` is set
 * A model that is not requested is stored as `null`. When both models are
 * already cached nothing is fetched. In either case a new detector wrapper and
 * pipeline are created from the current models before returning.
 *
 * @param config configuration object; reads `modelBasePath`, `debug` and the `hand` section
 * @returns the `[detector model, skeleton model]` pair currently cached
 */
export async function load(config): Promise<[unknown, unknown]> {
  if (!handDetectorModel || !handPoseModel) {
    // @ts-ignore type mismatch on GraphModel
    [handDetectorModel, handPoseModel] = await Promise.all([
      config.hand.enabled ? tf.loadGraphModel(join(config.modelBasePath, config.hand.detector.modelPath), { fromTFHub: config.hand.detector.modelPath.includes('tfhub.dev') }) : null,
      config.hand.landmarks ? tf.loadGraphModel(join(config.modelBasePath, config.hand.skeleton.modelPath), { fromTFHub: config.hand.skeleton.modelPath.includes('tfhub.dev') }) : null,
    ]);
    if (config.hand.enabled) {
      if (!handDetectorModel || !handDetectorModel['modelUrl']) log('load model failed:', config.hand.detector.modelPath);
      else if (config.debug) log('load model:', handDetectorModel['modelUrl']);
      // the skeleton model is only requested when landmarks are enabled, so a
      // missing skeleton model is a failure only in that case
      if (config.hand.landmarks) {
        if (!handPoseModel || !handPoseModel['modelUrl']) log('load model failed:', config.hand.skeleton.modelPath);
        else if (config.debug) log('load model:', handPoseModel['modelUrl']);
      }
    }
  } else {
    if (config.debug) log('cached model:', handDetectorModel['modelUrl']);
    if (config.debug) log('cached model:', handPoseModel['modelUrl']);
  }
  const handDetector = new handdetector.HandDetector(handDetectorModel);
  handPipeline = new handpipeline.HandPipeline(handDetector, handPoseModel);
  return [handDetectorModel, handPoseModel];
}
|