human/src/body/movenet.ts

197 lines
8.7 KiB
TypeScript
Raw Normal View History

/**
* MoveNet model implementation
*
* Based on: [**MoveNet**](https://blog.tensorflow.org/2021/05/next-generation-pose-detection-with-movenet-and-tensorflowjs.html)
*/
2021-09-27 19:58:13 +02:00
import { log, join } from '../util/util';
import * as box from '../util/box';
import * as tf from '../../dist/tfjs.esm.js';
import * as coords from './movenetcoords';
2021-10-14 18:26:59 +02:00
import * as fix from './movenetfix';
import type { BodyKeypoint, BodyResult, Box, Point } from '../result';
2021-09-13 19:28:35 +02:00
import type { GraphModel, Tensor } from '../tfjs/types';
import type { Config } from '../config';
2021-09-26 01:14:03 +02:00
import { fakeOps } from '../tfjs/backend';
2021-09-27 19:58:13 +02:00
import { env } from '../util/env';
2021-09-17 17:23:00 +02:00
/** Loaded MoveNet graph; cleared by `load()` whenever `env.initial` is set so the next call reloads it. */
let model: GraphModel | null;

/** Model input resolution in pixels; resolved by `load()` from the model's input signature (0 until then). */
let inputSize = 0;

/**
 * Frames seen since the last real inference.
 * Starts saturated so that the very first `predict()` call is never treated as a skippable frame.
 */
let skipped = Number.MAX_SAFE_INTEGER;

// const boxExpandFact = 1.5; // increase to 150%

/** State carried between frames so that skipped frames can return the previous detections. */
const cache: { boxes: Box[]; bodies: BodyResult[] } = {
  boxes: [], // unused: only ever cleared by predict()
  bodies: [], // results of the most recent inference
};
2021-06-03 15:41:53 +02:00
/**
 * Loads the MoveNet graph model, or reuses the already-loaded instance, and resolves the model input size.
 *
 * @param config - library configuration; reads `modelBasePath`, `body.modelPath` and `debug`
 * @returns the loaded (or cached) model instance
 */
export async function load(config: Config): Promise<GraphModel> {
  if (env.initial) model = null; // environment was (re)initialized: force a fresh load
  if (!model) {
    fakeOps(['size'], config); // NOTE(review): presumably provides a stand-in for the `size` op; see tfjs/backend
    model = await tf.loadGraphModel(join(config.modelBasePath, config.body.modelPath || '')) as unknown as GraphModel;
    if (!model || !model['modelUrl']) log('load model failed:', config.body.modelPath);
    else if (config.debug) log('load model:', model['modelUrl']);
  } else if (config.debug) log('cached model:', model['modelUrl']);
  // Read the input resolution from the model signature. Guard every step so that a failed load
  // (already logged above) or missing input metadata does not throw a TypeError here.
  const shape = model?.inputs?.[0]?.shape;
  inputSize = shape?.[2] ?? 0; // 0 means "unknown"; predict() bails out when the shape is missing
  if (inputSize === -1) inputSize = 256; // dynamic input dimension: fall back to 256px
  return model;
}
2021-09-27 14:53:41 +02:00
/**
 * Converts one raw model keypoint into a `BodyKeypoint`.
 *
 * `y` and `x` are the model's normalized coordinates. They are mapped into `inputBox`, whose x-range
 * is `inputBox[1]..inputBox[3]` and y-range is `inputBox[0]..inputBox[2]`, so `[0, 0, 1, 1]` is the
 * identity mapping. Pixel positions use `image.shape[2]` as width and `image.shape[1]` as height.
 */
function createKeypoint(index: number, y: number, x: number, score: number, image: Tensor, inputBox: Box): BodyKeypoint {
  const positionRaw: Point = [
    (inputBox[3] - inputBox[1]) * x + inputBox[1],
    (inputBox[2] - inputBox[0]) * y + inputBox[0],
  ];
  return {
    score: Math.round(100 * score) / 100,
    part: coords.kpt[index],
    positionRaw,
    position: [ // normalized to input image size
      Math.round((image.shape[2] || 0) * positionRaw[0]),
      Math.round((image.shape[1] || 0) * positionRaw[1]),
    ],
  };
}

/**
 * Builds skeleton line segments for every named chain in `coords.connected`.
 *
 * A segment is emitted between two consecutive parts only if both were detected and both scores exceed
 * `minConfidence`. Keypoint scores are rounded to two decimals, so this check can reject points that
 * passed the raw-score filter.
 */
function buildAnnotations(keypoints: Array<BodyKeypoint>, minConfidence: number): Record<string, Point[][]> {
  const annotations: Record<string, Point[][]> = {};
  for (const [name, indexes] of Object.entries(coords.connected)) {
    const pt: Array<Point[]> = [];
    for (let i = 0; i < indexes.length - 1; i++) {
      const pt0 = keypoints.find((kp) => kp.part === indexes[i]);
      const pt1 = keypoints.find((kp) => kp.part === indexes[i + 1]);
      if (pt0 && pt1 && pt0.score > minConfidence && pt1.score > minConfidence) pt.push([pt0.position, pt1.position]);
    }
    annotations[name] = pt;
  }
  return annotations;
}

/**
 * Parses single-pose MoveNet output.
 *
 * `res[0][0]` holds one `[y, x, score]` row per keypoint. Rows whose score is not above
 * `config.body.minConfidence` are dropped. The body score is the highest remaining (rounded) keypoint score.
 * Exactly one body is always returned, even when no keypoint passes the threshold.
 */
async function parseSinglePose(res, config, image, inputBox) {
  const kpt = res[0][0];
  const keypoints: Array<BodyKeypoint> = [];
  for (let id = 0; id < kpt.length; id++) {
    const kptScore = kpt[id][2];
    if (kptScore > config.body.minConfidence) keypoints.push(createKeypoint(id, kpt[id][0], kpt[id][1], kptScore, image, inputBox));
  }
  const score = keypoints.reduce((max, kp) => (kp.score > max ? kp.score : max), 0);
  const newBox = box.calc(keypoints.map((pt) => pt.position), [image.shape[2], image.shape[1]]);
  const annotations = buildAnnotations(keypoints, config.body.minConfidence || 0);
  const body: BodyResult = { id: 0, score, box: newBox.box, boxRaw: newBox.boxRaw, keypoints, annotations };
  fix.bodyParts(body);
  const bodies: Array<BodyResult> = [body];
  return bodies;
}

/**
 * Parses multi-pose MoveNet output.
 *
 * Each row of `res[0]` is one candidate. It holds 17 keypoints as `[y, x, score]` triplets (offsets 0..50),
 * followed by the model's own box and score: the overall score is at offset `51 + 4`.
 * Candidates whose overall score is not above `config.body.minConfidence` are skipped. The result is sorted
 * by score, descending, and truncated to `config.body.maxDetected`.
 */
async function parseMultiPose(res, config, image, inputBox) {
  const bodies: Array<BodyResult> = [];
  for (let id = 0; id < res[0].length; id++) {
    const kpt = res[0][id];
    const totalScore = Math.round(100 * kpt[51 + 4]) / 100;
    if (totalScore > config.body.minConfidence) {
      const keypoints: Array<BodyKeypoint> = [];
      for (let i = 0; i < 17; i++) {
        const kptScore = kpt[3 * i + 2];
        if (kptScore > config.body.minConfidence) keypoints.push(createKeypoint(i, kpt[3 * i + 0], kpt[3 * i + 1], kptScore, image, inputBox));
      }
      const newBox = box.calc(keypoints.map((pt) => pt.position), [image.shape[2], image.shape[1]]);
      // movenet-multipose has built-in box details (unused; the box is recomputed from keypoints instead)
      // const boxRaw: Box = [kpt[51 + 1], kpt[51 + 0], kpt[51 + 3] - kpt[51 + 1], kpt[51 + 2] - kpt[51 + 0]];
      const annotations = buildAnnotations(keypoints, config.body.minConfidence || 0);
      const body: BodyResult = { id, score: totalScore, box: newBox.box, boxRaw: newBox.boxRaw, keypoints: [...keypoints], annotations };
      fix.bodyParts(body);
      bodies.push(body);
    }
  }
  bodies.sort((a, b) => b.score - a.score);
  if (bodies.length > config.body.maxDetected) bodies.length = config.body.maxDetected;
  return bodies;
}
2021-09-27 14:53:41 +02:00
/**
 * Runs MoveNet body detection on an image tensor.
 *
 * When `config.skipFrame` is set and no more than `config.body.skipFrames` frames have passed since the last
 * inference, the cached results are returned without running the model. Otherwise:
 * - the input is prepared as a squared input by `fix.padInput`;
 * - the model runs once on the full frame;
 * - the output is parsed as single-pose when its `shape[2]` is 17, and as multi-pose otherwise;
 * - each body is rescaled and jitter-filtered by the `fix` helpers.
 *
 * @param input - image tensor; width is read from `shape[2]` and height from `shape[1]`
 * @param config - library configuration
 * @returns the detected bodies (also stored in the module cache); empty if the model is not loaded
 */
export async function predict(input: Tensor, config: Config): Promise<BodyResult[]> {
  if (!model || !model?.inputs[0].shape) return []; // something is wrong with the model
  if (!config.skipFrame) cache.boxes.length = 0; // caching not allowed: drop any cached boxes
  skipped++; // count this frame
  if (config.skipFrame && (skipped <= (config.body.skipFrames || 0))) {
    return cache.bodies; // inside the skip window: return cached results without running anything
  }
  skipped = 0;
  const net = model; // capture: load() may reset the module-level model while we are awaiting
  const t: Record<string, Tensor> = {};
  try {
    t.input = fix.padInput(input, inputSize);
    t.res = net.predict(t.input) as Tensor;
    const res = await t.res.array();
    const bodies = (t.res.shape[2] === 17)
      ? await parseSinglePose(res, config, input, [0, 0, 1, 1])
      : await parseMultiPose(res, config, input, [0, 0, 1, 1]);
    for (const body of bodies) {
      fix.rescaleBody(body, [input.shape[2] || 1, input.shape[1] || 1]);
      fix.jitter(body.keypoints);
    }
    cache.bodies = bodies; // only publish fully post-processed results
  } finally {
    // release every intermediate tensor even if inference or parsing throws
    Object.keys(t).forEach((tensor) => tf.dispose(t[tensor]));
  }
  return cache.bodies;
}