implement box caching for movenet

pull/356/head
Vladimir Mandic 2021-09-27 08:53:41 -04:00
parent 04406afcf2
commit 561d25cfc9
5 changed files with 104 additions and 87 deletions

View File

@@ -105,7 +105,7 @@ const ui = {
lastFrame: 0, // time of last frame processing
viewportSet: false, // internal, has custom viewport been set
background: null, // holds instance of segmentation background image
-exceptionHandler: true, // should capture all unhandled exceptions
+exceptionHandler: false, // should capture all unhandled exceptions
// webrtc
useWebRTC: false, // use webrtc as camera source instead of local webcam

View File

@@ -6,7 +6,7 @@
* - Hand Tracking: [**HandTracking**](https://github.com/victordibia/handtracking)
*/
-import { log, join } from '../util';
+import { log, join, scaleBox } from '../util';
import * as tf from '../../dist/tfjs.esm.js';
import type { HandResult } from '../result';
import type { GraphModel, Tensor } from '../tfjs/types';
@@ -21,18 +21,10 @@ const modelOutputNodes = ['StatefulPartitionedCall/Postprocessor/Slice', 'Statef
const inputSize = [[0, 0], [0, 0]];
-const classes = [
-'hand',
-'fist',
-'pinch',
-'point',
-'face',
-'tip',
-'pinchtip',
-];
+const classes = ['hand', 'fist', 'pinch', 'point', 'face', 'tip', 'pinchtip'];
let skipped = 0;
-let outputSize;
+let outputSize: [number, number] = [0, 0];
type HandDetectResult = {
id: number,
@@ -145,31 +137,6 @@ async function detectHands(input: Tensor, config: Config): Promise<HandDetectRes
return hands;
}
-function updateBoxes(h, keypoints) {
-const finger = [keypoints.map((pt) => pt[0]), keypoints.map((pt) => pt[1])]; // all fingers coords
-const minmax = [Math.min(...finger[0]), Math.max(...finger[0]), Math.min(...finger[1]), Math.max(...finger[1])]; // find min and max coordinates for x and y of all fingers
-const center = [(minmax[0] + minmax[1]) / 2, (minmax[2] + minmax[3]) / 2]; // find center x and y coord of all fingers
-const diff = Math.max(center[0] - minmax[0], center[1] - minmax[2], -center[0] + minmax[1], -center[1] + minmax[3]) * boxScaleFact; // largest distance from center in any direction
-h.box = [
-Math.trunc(center[0] - diff),
-Math.trunc(center[1] - diff),
-Math.trunc(2 * diff),
-Math.trunc(2 * diff),
-] as [number, number, number, number];
-h.boxRaw = [ // work backwards
-h.box[0] / outputSize[0],
-h.box[1] / outputSize[1],
-h.box[2] / outputSize[0],
-h.box[3] / outputSize[1],
-] as [number, number, number, number];
-h.yxBox = [ // work backwards
-h.boxRaw[1],
-h.boxRaw[0],
-h.boxRaw[3] + h.boxRaw[1],
-h.boxRaw[2] + h.boxRaw[0],
-] as [number, number, number, number];
-}
async function detectFingers(input: Tensor, h: HandDetectResult, config: Config): Promise<HandResult> {
const hand: HandResult = {
id: h.id,
@@ -201,7 +168,10 @@ async function detectFingers(input: Tensor, h: HandDetectResult, config: Config)
(h.box[3] * coord[1] / inputSize[1][1]) + h.box[1],
(h.box[2] + h.box[3]) / 2 / inputSize[1][0] * coord[2],
]);
-updateBoxes(h, hand.keypoints); // replace detected box with box calculated around keypoints
+const updatedBox = scaleBox(hand.keypoints, boxScaleFact, outputSize); // replace detected box with box calculated around keypoints
+h.box = updatedBox.box;
+h.boxRaw = updatedBox.boxRaw;
+h.yxBox = updatedBox.yxBox;
hand.box = h.box;
hand.landmarks = fingerPose.analyze(hand.keypoints) as HandResult['landmarks']; // calculate finger landmarks
for (const key of Object.keys(fingerMap)) { // map keypoints to per-finger annotations
@@ -222,16 +192,13 @@ export async function predict(input: Tensor, config: Config): Promise<HandResult
if ((skipped < (config.hand.skipFrames || 0)) && config.skipFrame) { // just run finger detection while reusing cached boxes
skipped++;
hands = await Promise.all(cache.fingerBoxes.map((hand) => detectFingers(input, hand, config))); // run from finger box cache
-// console.log('SKIP', skipped, hands.length, cache.handBoxes.length, cache.fingerBoxes.length, cache.tmpBoxes.length);
} else { // calculate new boxes and run finger detection
skipped = 0;
hands = await Promise.all(cache.fingerBoxes.map((hand) => detectFingers(input, hand, config))); // run from finger box cache
-// console.log('CACHE', skipped, hands.length, cache.handBoxes.length, cache.fingerBoxes.length, cache.tmpBoxes.length);
if (hands.length !== config.hand.maxDetected) { // run hand detection only if we don't have enough hands in cache
cache.handBoxes = await detectHands(input, config);
const newHands = await Promise.all(cache.handBoxes.map((hand) => detectFingers(input, hand, config)));
hands = hands.concat(newHands);
-// console.log('DETECT', skipped, hands.length, cache.handBoxes.length, cache.fingerBoxes.length, cache.tmpBoxes.length);
}
}
cache.fingerBoxes = [...cache.tmpBoxes]; // repopulate cache with validated hands
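
For reference, the caching flow in the hunk above, in isolation: refine cached boxes on every frame, and only run the full hand detector when outside the skip window and the cache holds fewer hands than maxDetected. A minimal TypeScript sketch, with hypothetical detect/refine stand-ins for detectHands/detectFingers (not the library API):

type Box = [number, number, number, number];
declare function detect(frame: unknown): Promise<Box[]>; // hypothetical: full hand detector
declare function refine(frame: unknown, box: Box): Promise<Box>; // hypothetical: per-box finger model

let skipped = 0;
let cachedBoxes: Box[] = [];

async function predictSketch(frame: unknown, config: { skipFrame: boolean, skipFrames: number, maxDetected: number }) {
  let hands = await Promise.all(cachedBoxes.map((box) => refine(frame, box))); // always start from cached boxes
  if (config.skipFrame && skipped < config.skipFrames) {
    skipped++; // inside the skip window: cached boxes only, detector skipped entirely
  } else {
    skipped = 0;
    if (hands.length !== config.maxDetected) { // cache holds too few hands: run the detector
      cachedBoxes = await detect(frame);
      hands = hands.concat(await Promise.all(cachedBoxes.map((box) => refine(frame, box))));
    }
  }
  return hands;
}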

View File

@@ -458,7 +458,7 @@ export class Human {
// run body: can be posenet, blazepose, efficientpose, movenet
this.analyze('Start Body:');
this.state = 'detect:body';
-const bodyConfig = this.config.body.maxDetected === -1 ? mergeDeep(this.config, { body: { maxDetected: 1 * (faceRes as FaceResult[]).length } }) : this.config; // autodetect number of bodies
+const bodyConfig = this.config.body.maxDetected === -1 ? mergeDeep(this.config, { body: { maxDetected: this.config.face.enabled ? 1 * (faceRes as FaceResult[]).length : 1 } }) : this.config; // autodetect number of bodies
if (this.config.async) {
if (this.config.body.modelPath?.includes('posenet')) bodyRes = this.config.body.enabled ? posenet.predict(img.tensor, bodyConfig) : [];
else if (this.config.body.modelPath?.includes('blazepose')) bodyRes = this.config.body.enabled ? blazepose.predict(img.tensor, bodyConfig) : [];
@@ -479,7 +479,7 @@
// run handpose
this.analyze('Start Hand:');
this.state = 'detect:hand';
-const handConfig = this.config.hand.maxDetected === -1 ? mergeDeep(this.config, { hand: { maxDetected: 2 * (faceRes as FaceResult[]).length } }) : this.config; // autodetect number of hands
+const handConfig = this.config.hand.maxDetected === -1 ? mergeDeep(this.config, { hand: { maxDetected: this.config.face.enabled ? 2 * (faceRes as FaceResult[]).length : 1 } }) : this.config; // autodetect number of hands
if (this.config.async) {
if (this.config.hand.detector?.modelPath?.includes('handdetect')) handRes = this.config.hand.enabled ? handpose.predict(img.tensor, handConfig) : [];
else if (this.config.hand.detector?.modelPath?.includes('handtrack')) handRes = this.config.hand.enabled ? handtrack.predict(img.tensor, handConfig) : [];
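
Both hunks implement the same convention: maxDetected === -1 means autodetect, scaling the limit by the number of detected faces, with a fallback of 1 when face detection is disabled (an empty faceRes would otherwise pin the limit to 0). As a sketch, with a hypothetical helper name (the real code inlines this via mergeDeep):

function resolveMaxDetected(requested: number, perFace: number, faceEnabled: boolean, faceCount: number): number {
  if (requested !== -1) return requested; // explicit limit wins
  return faceEnabled ? perFace * faceCount : 1; // autodetect: scale by faces, else fall back to 1
}

resolveMaxDetected(-1, 1, true, 2); // body: 1 per face x 2 faces = 2
resolveMaxDetected(-1, 2, true, 2); // hand: 2 per face x 2 faces = 4
resolveMaxDetected(-1, 2, false, 0); // face detection disabled: falls back to 1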

View File

@@ -4,7 +4,7 @@
* Based on: [**MoveNet**](https://blog.tensorflow.org/2021/05/next-generation-pose-detection-with-movenet-and-tensorflowjs.html)
*/
-import { log, join } from '../util';
+import { log, join, scaleBox } from '../util';
import * as tf from '../../dist/tfjs.esm.js';
import type { BodyResult } from '../result';
import type { GraphModel, Tensor } from '../tfjs/types';
@@ -13,15 +13,17 @@ import { fakeOps } from '../tfjs/backend';
import { env } from '../env';
let model: GraphModel | null;
+let inputSize = 0;
+const cachedBoxes: Array<[number, number, number, number]> = [];
type Keypoints = { score: number, part: string, position: [number, number], positionRaw: [number, number] };
-const keypoints: Array<Keypoints> = [];
-type Person = { id: number, score: number, box: [number, number, number, number], boxRaw: [number, number, number, number], keypoints: Array<Keypoints> }
+type Body = { id: number, score: number, box: [number, number, number, number], boxRaw: [number, number, number, number], keypoints: Array<Keypoints> }
let box: [number, number, number, number] = [0, 0, 0, 0];
let boxRaw: [number, number, number, number] = [0, 0, 0, 0];
let score = 0;
let skipped = Number.MAX_SAFE_INTEGER;
+const keypoints: Array<Keypoints> = [];
const bodyParts = ['nose', 'leftEye', 'rightEye', 'leftEar', 'rightEar', 'leftShoulder', 'rightShoulder', 'leftElbow', 'rightElbow', 'leftWrist', 'rightWrist', 'leftHip', 'rightHip', 'leftKnee', 'rightKnee', 'leftAnkle', 'rightAnkle'];
@@ -33,25 +35,28 @@ export async function load(config: Config): Promise<GraphModel> {
if (!model || !model['modelUrl']) log('load model failed:', config.body.modelPath);
else if (config.debug) log('load model:', model['modelUrl']);
} else if (config.debug) log('cached model:', model['modelUrl']);
+inputSize = model.inputs[0].shape ? model.inputs[0].shape[2] : 0;
+if (inputSize === -1) inputSize = 256;
return model;
}
-async function parseSinglePose(res, config, image) {
-keypoints.length = 0;
+async function parseSinglePose(res, config, image, inputBox) {
const kpt = res[0][0];
+keypoints.length = 0;
for (let id = 0; id < kpt.length; id++) {
score = kpt[id][2];
if (score > config.body.minConfidence) {
+const positionRaw: [number, number] = [
+(inputBox[3] - inputBox[1]) * kpt[id][1] + inputBox[1],
+(inputBox[2] - inputBox[0]) * kpt[id][0] + inputBox[0],
+];
keypoints.push({
score: Math.round(100 * score) / 100,
part: bodyParts[id],
-positionRaw: [ // normalized to 0..1
-kpt[id][1],
-kpt[id][0],
-],
+positionRaw,
position: [ // normalized to input image size
-Math.round((image.shape[2] || 0) * kpt[id][1]),
-Math.round((image.shape[1] || 0) * kpt[id][0]),
+Math.round((image.shape[2] || 0) * positionRaw[0]),
+Math.round((image.shape[1] || 0) * positionRaw[1]),
],
});
}
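
The new inputBox parameter is what makes cropped inference possible: the model returns keypoints normalized to whatever crop it was fed, and inputBox ([y1, x1, y2, x2] in normalized frame coordinates, [0, 0, 1, 1] for a full-frame pass) maps them back into frame space. The same mapping as a standalone sketch:

// remap a crop-relative keypoint back to normalized frame coordinates
function remapKeypoint(x: number, y: number, inputBox: [number, number, number, number]): [number, number] {
  return [
    (inputBox[3] - inputBox[1]) * x + inputBox[1], // scale by crop width, shift by its left edge
    (inputBox[2] - inputBox[0]) * y + inputBox[0], // scale by crop height, shift by its top edge
  ];
}

remapKeypoint(0.5, 0.5, [0, 0, 1, 1]); // full frame: unchanged -> [0.5, 0.5]
remapKeypoint(0, 0, [0.25, 0.25, 0.75, 0.75]); // crop origin -> [0.25, 0.25] in the frame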
@@ -73,13 +78,13 @@
Math.max(...xRaw) - Math.min(...xRaw),
Math.max(...yRaw) - Math.min(...yRaw),
];
-const persons: Array<Person> = [];
-persons.push({ id: 0, score, box, boxRaw, keypoints });
-return persons;
+const bodies: Array<Body> = [];
+bodies.push({ id: 0, score, box, boxRaw, keypoints });
+return bodies;
}
-async function parseMultiPose(res, config, image) {
-const persons: Array<Person> = [];
+async function parseMultiPose(res, config, image, inputBox) {
+const bodies: Array<Body> = [];
for (let id = 0; id < res[0].length; id++) {
const kpt = res[0][id];
score = Math.round(100 * kpt[51 + 4]) / 100;
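
The magic numbers here come from the multipose output layout: each detection row holds 56 floats, 17 keypoints as (y, x, score) triplets followed by a bounding box [ymin, xmin, ymax, xmax, score] at offsets 51..55, hence the overall score read at kpt[51 + 4]. Decoded as a sketch:

// decode one 56-float multipose detection row (layout as indexed in the surrounding hunks)
function decodeRow(kpt: number[]) {
  const keypoints = Array.from({ length: 17 }, (_, i) => ({
    y: kpt[3 * i + 0], x: kpt[3 * i + 1], score: kpt[3 * i + 2], // normalized 0..1
  }));
  const box = { ymin: kpt[51], xmin: kpt[52], ymax: kpt[53], xmax: kpt[54], score: kpt[55] };
  return { keypoints, box };
}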
@@ -89,16 +94,20 @@
for (let i = 0; i < 17; i++) {
const partScore = Math.round(100 * kpt[3 * i + 2]) / 100;
if (partScore > config.body.minConfidence) {
+const positionRaw: [number, number] = [
+(inputBox[3] - inputBox[1]) * kpt[3 * i + 1] + inputBox[1],
+(inputBox[2] - inputBox[0]) * kpt[3 * i + 0] + inputBox[0],
+];
keypoints.push({
part: bodyParts[i],
score: partScore,
-positionRaw: [kpt[3 * i + 1], kpt[3 * i + 0]],
-position: [Math.trunc(kpt[3 * i + 1] * (image.shape[2] || 0)), Math.trunc(kpt[3 * i + 0] * (image.shape[1] || 0))],
+positionRaw,
+position: [Math.trunc(positionRaw[0] * (image.shape[2] || 0)), Math.trunc(positionRaw[1] * (image.shape[1] || 0))],
});
}
}
boxRaw = [kpt[51 + 1], kpt[51 + 0], kpt[51 + 3] - kpt[51 + 1], kpt[51 + 2] - kpt[51 + 0]];
-persons.push({
+bodies.push({
id,
score,
boxRaw,
@@ -111,36 +120,50 @@
keypoints: [...keypoints],
});
}
-return persons;
+return bodies;
}
-export async function predict(image: Tensor, config: Config): Promise<BodyResult[]> {
-if ((skipped < (config.body.skipFrames || 0)) && config.skipFrame && Object.keys(keypoints).length > 0) {
-skipped++;
-return [{ id: 0, score, box, boxRaw, keypoints }];
-}
-skipped = 0;
+export async function predict(input: Tensor, config: Config): Promise<BodyResult[]> {
+if (!model || !model?.inputs[0].shape) return [];
return new Promise(async (resolve) => {
-const tensor = tf.tidy(() => {
-if (!model?.inputs[0].shape) return null;
-let inputSize = model.inputs[0].shape[2];
-if (inputSize === -1) inputSize = 256;
-const resize = tf.image.resizeBilinear(image, [inputSize, inputSize], false);
-const cast = tf.cast(resize, 'int32');
-return cast;
-});
+const t: Record<string, Tensor> = {};
-let resT;
-if (config.body.enabled) resT = await model?.predict(tensor);
-tf.dispose(tensor);
+let bodies: Array<Body> = [];
-if (!resT) resolve([]);
-const res = await resT.array();
-let body;
-if (resT.shape[2] === 17) body = await parseSinglePose(res, config, image);
-else if (resT.shape[2] === 56) body = await parseMultiPose(res, config, image);
-tf.dispose(resT);
+if (!config.skipFrame) cachedBoxes.length = 0; // allowed to use cache or not
+skipped++;
-resolve(body);
+for (let i = 0; i < cachedBoxes.length; i++) { // run detection based on cached boxes
+t.crop = tf.image.cropAndResize(input, [cachedBoxes[i]], [0], [inputSize, inputSize], 'bilinear');
+t.cast = tf.cast(t.crop, 'int32');
+t.res = await model?.predict(t.cast) as Tensor;
+const res = await t.res.array();
+const newBodies = (t.res.shape[2] === 17) ? await parseSinglePose(res, config, input, cachedBoxes[i]) : await parseMultiPose(res, config, input, cachedBoxes[i]);
+bodies = bodies.concat(newBodies);
+Object.keys(t).forEach((tensor) => tf.dispose(t[tensor]));
+}
+if ((bodies.length !== config.body.maxDetected) && (skipped > (config.body.skipFrames || 0))) { // run detection on full frame
+t.resized = tf.image.resizeBilinear(input, [inputSize, inputSize], false);
+t.cast = tf.cast(t.resized, 'int32');
+t.res = await model?.predict(t.cast) as Tensor;
+const res = await t.res.array();
+bodies = (t.res.shape[2] === 17) ? await parseSinglePose(res, config, input, [0, 0, 1, 1]) : await parseMultiPose(res, config, input, [0, 0, 1, 1]);
+Object.keys(t).forEach((tensor) => tf.dispose(t[tensor]));
+cachedBoxes.length = 0; // reset cache
+skipped = 0;
+}
+if (config.skipFrame) { // create box cache based on last detections
+cachedBoxes.length = 0;
+for (let i = 0; i < bodies.length; i++) {
+if (bodies[i].keypoints.length > 10) { // only update cache if we detected a sufficient number of keypoints
+const kpts = bodies[i].keypoints.map((kpt) => kpt.position);
+const newBox = scaleBox(kpts, 1.5, [input.shape[2], input.shape[1]]);
+cachedBoxes.push([...newBox.yxBox]);
+}
+}
+}
+resolve(bodies);
});
}
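
One detail in predict worth calling out: tf.image.cropAndResize takes boxes as normalized [y1, x1, y2, x2], which is why the cache stores scaleBox's yxBox output rather than the x/y-ordered box. The crop step as a self-contained sketch (assuming the standalone @tensorflow/tfjs package in place of the bundled dist/tfjs.esm.js):

import * as tf from '@tensorflow/tfjs';

// crop one cached region from a [1, height, width, channels] frame and resize it to the model input size
function cropForModel(input: tf.Tensor4D, yxBox: [number, number, number, number], size: number): tf.Tensor4D {
  return tf.tidy(() => {
    const crop = tf.image.cropAndResize(input, [yxBox], [0], [size, size], 'bilinear'); // boxes are [y1, x1, y2, x2] in 0..1
    return tf.cast(crop, 'int32'); // this movenet graph expects int32 input
  });
}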

View File

@@ -69,3 +69,30 @@ export async function wait(time) {
const waiting = new Promise((resolve) => setTimeout(() => resolve(true), time));
await waiting;
}
+// helper function: find box around keypoints, square it and scale it
+export function scaleBox(keypoints, boxScaleFact, outputSize) {
+const coords = [keypoints.map((pt) => pt[0]), keypoints.map((pt) => pt[1])]; // all x/y coords
+const maxmin = [Math.max(...coords[0]), Math.min(...coords[0]), Math.max(...coords[1]), Math.min(...coords[1])]; // find min/max x/y coordinates
+const center = [(maxmin[0] + maxmin[1]) / 2, (maxmin[2] + maxmin[3]) / 2]; // find center x and y coord of all keypoints
+const diff = Math.max(center[0] - maxmin[1], center[1] - maxmin[3], -center[0] + maxmin[0], -center[1] + maxmin[2]) * boxScaleFact; // largest distance from center in any direction
+const box = [
+Math.trunc(center[0] - diff),
+Math.trunc(center[1] - diff),
+Math.trunc(2 * diff),
+Math.trunc(2 * diff),
+] as [number, number, number, number];
+const boxRaw = [ // work backwards
+box[0] / outputSize[0],
+box[1] / outputSize[1],
+box[2] / outputSize[0],
+box[3] / outputSize[1],
+] as [number, number, number, number];
+const yxBox = [ // work backwards
+boxRaw[1],
+boxRaw[0],
+boxRaw[3] + boxRaw[1],
+boxRaw[2] + boxRaw[0],
+] as [number, number, number, number];
+return { box, boxRaw, yxBox };
+}
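
A worked example of the helper: four keypoints spanning x 100..200, y 100..150 in a 640x480 frame, with scale factor 1.5 as used by the movenet cache above:

const kpts: [number, number][] = [[100, 100], [200, 100], [200, 150], [100, 150]];
const { box, boxRaw, yxBox } = scaleBox(kpts, 1.5, [640, 480]);
// center = [150, 125]; largest distance from center = 50; scaled: 50 * 1.5 = 75
// box    = [75, 50, 150, 150]           -> square [x, y, width, height] in pixels
// boxRaw ≈ [0.117, 0.104, 0.234, 0.313] -> same box normalized by outputSize
// yxBox  ≈ [0.104, 0.117, 0.417, 0.352] -> [y1, x1, y2, x2], ready for tf.image.cropAndResize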