// human/src/hand/handpose.ts
/**
 * HandPose model implementation
 *
 * Based on: [**MediaPipe HandPose**](https://drive.google.com/file/d/1sv4sSb9BSNVZhLzxXJ0jBv9DqD-4jnAz/view)
 */
import { log } from '../util/util';
import * as handdetector from './handposedetector';
import * as handpipeline from './handposepipeline';
import * as fingerPose from './fingerpose';
import { loadModel } from '../tfjs/load';
import type { HandResult, Box, Point } from '../result';
import type { Tensor, GraphModel } from '../tfjs/types';
import type { Config } from '../config';
import { env } from '../util/env';

/**
 * Landmark indices that belong to each named part of the hand.
 *
 * `predict` walks the keys of this object in declaration order and, for each key,
 * picks the listed indices out of a prediction's `landmarks` array.
 */
const meshAnnotations = {
  thumb: [1, 2, 3, 4],
  index: [5, 6, 7, 8],
  middle: [9, 10, 11, 12],
  ring: [13, 14, 15, 16],
  pinky: [17, 18, 19, 20],
  palm: [0],
};
// Module-level state shared by the functions in this file.
// All three bindings start out unassigned and are filled in lazily.
let handPipeline: handpipeline.HandPipeline; // built by initPipeline() from the two models below
let handPoseModel: GraphModel | null; // set by loadSkeleton()
let handDetectorModel: GraphModel | null; // set by loadDetect()
/**
 * Builds the module-level hand pipeline from the currently loaded models.
 *
 * A detector wrapper is created whenever the detector model is available;
 * the pipeline itself is only assigned when the pose model is available as well.
 * If the detector model is missing nothing happens.
 */
export function initPipeline() {
  if (!handDetectorModel) return;
  const detector = new handdetector.HandDetector(handDetectorModel);
  if (handPoseModel) {
    handPipeline = new handpipeline.HandPipeline(detector, handPoseModel);
  }
}
/**
 * Runs the hand pipeline on `input` and converts every raw prediction into a `HandResult`.
 *
 * For each prediction the bounding box is computed from the landmarks when there are any,
 * and from the detector box otherwise. `boxRaw` is that box divided by the input dimensions.
 * When a prediction has neither landmarks nor a detector box, both boxes are all zeros.
 *
 * @param input tensor handed to the pipeline; `input.shape[2]` limits/normalizes the first
 * coordinate of a box and `input.shape[1]` the second one
 * @param config configuration forwarded to `handPipeline.estimateHands`; `config.debug` enables
 * a log message when the pipeline is unavailable
 * @returns one result per prediction, or an empty array when the pipeline could not be
 * initialized or reported no predictions
 */
export async function predict(input: Tensor, config: Config): Promise<HandResult[]> {
  if (!handPipeline) initPipeline();
  if (!handPipeline) { // initPipeline() assigns the pipeline only once both models are loaded
    if (config.debug) log('handpose: pipeline not initialized, models are not loaded');
    return [];
  }
  const predictions = await handPipeline.estimateHands(input, config);
  if (!predictions) return [];
  // loop-invariant input dimensions; presumably the tensor is [batch, height, width, channels] — confirm
  // the `|| 0` fallback is kept from the original expressions, so an unknown dimension yields non-finite boxRaw values
  const inputWidth = input.shape[2] || 0;
  const inputHeight = input.shape[1] || 0;
  const hands: HandResult[] = [];
  for (let i = 0; i < predictions.length; i++) {
    const prediction = predictions[i];
    const annotations = {};
    if (prediction.landmarks) {
      for (const key of Object.keys(meshAnnotations)) {
        annotations[key] = meshAnnotations[key].map((index) => prediction.landmarks[index]);
      }
    }
    const keypoints = prediction.landmarks as unknown as Point[];
    // NOTE(review): the two maximums start at 0, so if every landmark coordinate is negative
    // the far edge of the box stays pinned at 0 — confirm whether that is intended
    let box: Box = [Number.MAX_SAFE_INTEGER, Number.MAX_SAFE_INTEGER, 0, 0]; // minimums start high so the first comparison always wins
    let boxRaw: Box = [0, 0, 0, 0];
    if (keypoints && keypoints.length > 0) { // if we have landmarks, calculate box based on landmarks
      for (const pt of keypoints) {
        if (pt[0] < box[0]) box[0] = pt[0];
        if (pt[1] < box[1]) box[1] = pt[1];
        if (pt[0] > box[2]) box[2] = pt[0];
        if (pt[1] > box[3]) box[3] = pt[1];
      }
      // convert the far corner into a size relative to the near corner
      box[2] -= box[0];
      box[3] -= box[1];
      boxRaw = [box[0] / inputWidth, box[1] / inputHeight, box[2] / inputWidth, box[3] / inputHeight];
    } else if (prediction.box) { // otherwise use box from prediction, clamped to the input
      const { topLeft, bottomRight } = prediction.box;
      const left = Math.max(0, topLeft[0]);
      const top = Math.max(0, topLeft[1]);
      box = [
        Math.trunc(left),
        Math.trunc(top),
        Math.trunc(Math.min(inputWidth, bottomRight[0]) - left),
        Math.trunc(Math.min(inputHeight, bottomRight[1]) - top),
      ];
      // the raw box is intentionally not clamped, matching the original computation
      boxRaw = [
        topLeft[0] / inputWidth,
        topLeft[1] / inputHeight,
        (bottomRight[0] - topLeft[0]) / inputWidth,
        (bottomRight[1] - topLeft[1]) / inputHeight,
      ];
    } else { // neither landmarks nor a detector box: report empty boxes instead of throwing
      box = [0, 0, 0, 0];
    }
    const landmarks = fingerPose.analyze(keypoints);
    hands.push({
      id: i,
      // scores are rounded to two decimals
      score: Math.round(100 * prediction.confidence) / 100,
      boxScore: Math.round(100 * prediction.boxConfidence) / 100,
      fingerScore: Math.round(100 * prediction.fingerConfidence) / 100,
      label: 'hand',
      box,
      boxRaw,
      keypoints,
      annotations: annotations as HandResult['annotations'],
      landmarks: landmarks as HandResult['landmarks'],
    });
  }
  return hands;
}
/**
 * Returns the hand detector model, loading it on first use.
 *
 * The cached model is dropped first when `env.initial` is set. A cached model is
 * returned as-is (and reported through `log` when `config.debug` is on); otherwise
 * the model is loaded from `config.hand.detector?.modelPath` and cached.
 *
 * @param config configuration providing the model path and the debug flag
 * @returns the cached or freshly loaded detector model
 */
export async function loadDetect(config: Config): Promise<GraphModel> {
  if (env.initial) handDetectorModel = null;
  if (handDetectorModel) {
    if (config.debug) log('cached model:', handDetectorModel['modelUrl']);
    return handDetectorModel;
  }
  handDetectorModel = await loadModel(config.hand.detector?.modelPath);
  return handDetectorModel;
}
/**
 * Returns the hand skeleton (pose) model, loading it on first use.
 *
 * The cached model is dropped first when `env.initial` is set. A cached model is
 * returned as-is (and reported through `log` when `config.debug` is on); otherwise
 * the model is loaded from `config.hand.skeleton?.modelPath` and cached.
 *
 * @param config configuration providing the model path and the debug flag
 * @returns the cached or freshly loaded skeleton model
 */
export async function loadSkeleton(config: Config): Promise<GraphModel> {
  if (env.initial) handPoseModel = null;
  if (handPoseModel) {
    if (config.debug) log('cached model:', handPoseModel['modelUrl']);
    return handPoseModel;
  }
  handPoseModel = await loadModel(config.hand.skeleton?.modelPath);
  return handPoseModel;
}