human/src/body/movenetfix.ts

108 lines
5.1 KiB
TypeScript
Raw Normal View History

2021-10-14 18:26:59 +02:00
import type { BodyKeypoint, BodyResult } from '../result';
import * as box from '../util/box';
import * as coords from './movenetcoords';
import * as tf from '../../dist/tfjs.esm.js';
import type { Tensor } from '../tfjs/types';
/** Largest per-axis change in `positionRaw` (0.5% of the normalized coordinate range) that `jitter` treats as noise. */
const maxJitter = 5e-3;

/** Module-level state shared between successive calls of the helpers in this file. */
type FixCache = {
  /** last accepted keypoint for each array slot; read and written by `jitter` */
  keypoints: Array<BodyKeypoint>;
  /** [before, after] padding per tensor axis from the last `padInput` call; read by `rescaleBody` */
  padding: [number, number][];
};

const cache: FixCache = {
  keypoints: [],
  padding: [[0, 0], [0, 0], [0, 0], [0, 0]],
};
/**
 * Corrects common MoveNet keypoint mistakes; mutates `body.keypoints` in place.
 *
 * Three passes, each driven by part-name tables from `./movenetcoords`:
 * 1. `coords.horizontal` ([left, right] pairs): if the left part's x is smaller than the right part's x,
 *    the two keypoints exchange array slots.
 * 2. `coords.vertical` ([lower, higher] pairs): if the lower part's y is smaller than the higher part's y
 *    (presumably image coordinates with y growing downward — confirm), the lower keypoint is removed.
 * 3. `coords.relative` ([pair, compare] entries): if a pair member is farther (in x) from its own-side compare
 *    point than from the opposite one, the pair exchanges array slots. Skipped unless both compare points exist;
 *    the exchange only happens when both pair members exist.
 *
 * NOTE(review): the swaps move keypoint objects between array slots but leave their `part` labels unchanged —
 * confirm this is what downstream consumers expect.
 *
 * @param body detection result whose `keypoints` array is modified in place
 */
export function bodyParts(body: BodyResult) { // model sometimes mixes up left vs right keypoints so we fix them
  // index of the first keypoint labelled `part`, or -1; tolerates holes (undefined entries) in the array
  const findPart = (part: string): number => body.keypoints.findIndex((kp) => kp && kp.part === part);
  // exchange the keypoints stored at two valid (>= 0) array slots
  const swap = (a: number, b: number): void => {
    const tmp = body.keypoints[a];
    body.keypoints[a] = body.keypoints[b];
    body.keypoints[b] = tmp;
  };

  for (const pair of coords.horizontal) { // fix body parts left vs right
    const left = findPart(pair[0]);
    const right = findPart(pair[1]);
    if (left >= 0 && right >= 0 && body.keypoints[left].position[0] < body.keypoints[right].position[0]) swap(left, right);
  }

  for (const pair of coords.vertical) { // remove body parts with improbable vertical position
    // indices are looked up fresh on every iteration, so earlier splices cannot leave them stale
    const lower = findPart(pair[0]);
    const higher = findPart(pair[1]);
    if (lower >= 0 && higher >= 0 && body.keypoints[lower].position[1] < body.keypoints[higher].position[1]) {
      body.keypoints.splice(lower, 1);
    }
  }

  for (const [pair, compare] of coords.relative) { // rearrange body parts according to their relative position
    const left = findPart(pair[0]);
    const right = findPart(pair[1]);
    const leftTo = findPart(compare[0]);
    const rightTo = findPart(compare[1]);
    if (leftTo < 0 || rightTo < 0) continue; // only if we have both compare points
    const leftToX = body.keypoints[leftTo].position[0];
    const rightToX = body.keypoints[rightTo].position[0];
    // [distance to own-side compare point, distance to opposite compare point]; [0, 0] when the part is missing
    const distanceLeft = left >= 0
      ? [Math.abs(leftToX - body.keypoints[left].position[0]), Math.abs(rightToX - body.keypoints[left].position[0])]
      : [0, 0];
    const distanceRight = right >= 0
      ? [Math.abs(rightToX - body.keypoints[right].position[0]), Math.abs(leftToX - body.keypoints[right].position[0])]
      : [0, 0];
    const shouldFlip = distanceLeft[0] > distanceLeft[1] || distanceRight[0] > distanceRight[1];
    // both members must exist: swapping with index -1 would set a bogus "-1" array property and
    // leave an undefined hole where the surviving keypoint was, silently dropping it
    if (shouldFlip && left >= 0 && right >= 0) swap(left, right);
  }
}
/**
 * Suppresses small frame-to-frame movement of keypoints.
 *
 * For each array slot, if the cached keypoint for that slot has the same `part` and both of its `positionRaw`
 * coordinates differ from the new ones by less than `maxJitter`, the cached keypoint object replaces the new one.
 * Otherwise the new keypoint becomes the cached value for that slot.
 *
 * @param keypoints keypoints of the current frame; modified in place
 * @returns the same `keypoints` array
 */
export function jitter(keypoints: Array<BodyKeypoint>): Array<BodyKeypoint> {
  for (let i = 0; i < keypoints.length; i++) {
    const current = keypoints[i];
    const cached = cache.keypoints[i];
    // slots can shift between frames (keypoints get spliced/filtered elsewhere), so only a cached entry
    // for the same body part is comparable; otherwise a different part's keypoint could be substituted
    if (current && cached && cached.part === current.part) {
      const dx = Math.abs(current.positionRaw[0] - cached.positionRaw[0]);
      const dy = Math.abs(current.positionRaw[1] - cached.positionRaw[1]);
      if (dx < maxJitter && dy < maxJitter) {
        keypoints[i] = cached; // below jitter so replace keypoint
      } else {
        cache.keypoints[i] = current; // above jitter so update cache
      }
    } else {
      cache.keypoints[i] = current; // no comparable cache entry for this slot so (re)create it here
    }
  }
  return keypoints;
}
/**
 * Pads an image tensor to a square, resizes it to `inputSize` x `inputSize` and casts it to int32.
 *
 * Axis layout follows the padding table below: [batch, height, width, channels]. The shorter spatial axis
 * is padded; the padding used is stored in `cache.padding` so that `rescaleBody` can undo it.
 * Intermediate tensors are disposed; `input` itself is not.
 *
 * @param input image tensor to prepare
 * @param inputSize target side length of the square output
 * @returns a new int32 tensor, or `input` itself (not a copy) when its height or width is missing/zero
 */
export function padInput(input: Tensor, inputSize: number): Tensor {
  const t: Record<string, Tensor> = {};
  if (!input.shape || !input.shape[1] || !input.shape[2]) {
    cache.padding = [[0, 0], [0, 0], [0, 0], [0, 0]]; // nothing is padded, so rescaleBody must not apply a stale offset
    return input;
  }
  const height = input.shape[1];
  const width = input.shape[2];
  const padHeight = Math.max(width - height, 0);
  const padWidth = Math.max(height - width, 0);
  // split the difference across both sides; an odd remainder goes to the trailing side so the result is exactly
  // square (rescaleBody reads before and after separately, so asymmetric padding is undone correctly)
  const padTop = Math.trunc(padHeight / 2);
  const padLeft = Math.trunc(padWidth / 2);
  cache.padding = [
    [0, 0], // dont touch batch
    [padTop, padHeight - padTop], // height before&after
    [padLeft, padWidth - padLeft], // width before&after
    [0, 0], // dont touch rbg
  ];
  t.pad = tf.pad(input, cache.padding);
  t.resize = tf.image.resizeBilinear(t.pad, [inputSize, inputSize]);
  const final = tf.cast(t.resize, 'int32');
  Object.keys(t).forEach((tensor) => tf.dispose(t[tensor]));
  return final;
}
/**
 * Maps keypoints from the padded square model input back onto the original image and recomputes the body box.
 *
 * Drops falsy keypoints and keypoints without `position`. Uses the padding recorded by the last `padInput` call
 * (`cache.padding[2]` for x, `cache.padding[1]` for y) to remove the padding offset. Then refreshes each
 * `positionRaw` as position divided by `outputSize`, and sets `box`/`boxRaw` from `box.calc`.
 *
 * @param body detection result; modified in place
 * @param outputSize [width, height] used for scaling and normalization
 * @returns the same `body` object
 */
export function rescaleBody(body: BodyResult, outputSize: [number, number]): BodyResult {
  body.keypoints = body.keypoints.filter((kpt) => kpt && kpt.position); // filter invalid keypoints
  const [width, height] = outputSize;
  const [padLeft, padRight] = cache.padding[2];
  const [padTop, padBottom] = cache.padding[1];
  body.keypoints.forEach((kpt) => {
    const x = kpt.position[0] * (width + padLeft + padRight) / width - padLeft;
    const y = kpt.position[1] * (height + padTop + padBottom) / height - padTop;
    kpt.position = [x, y];
    kpt.positionRaw = [x / width, y / height];
  });
  const { box: pixelBox, boxRaw } = box.calc(body.keypoints.map((kpt) => kpt.position), outputSize);
  body.box = pixelBox;
  body.boxRaw = boxRaw;
  return body;
}