human/src/body/blazepose.ts

245 lines
12 KiB
TypeScript
Raw Normal View History

2021-09-27 19:58:13 +02:00
/**
* BlazePose model implementation
*/
2021-10-04 23:03:36 +02:00
import * as tf from '../../dist/tfjs.esm.js';
2022-01-16 15:49:55 +01:00
import { loadModel } from '../tfjs/load';
2021-11-17 02:16:49 +01:00
import { constants } from '../tfjs/constants';
2022-01-17 17:03:21 +01:00
import { log, now } from '../util/util';
2021-12-15 15:26:32 +01:00
import type { BodyKeypoint, BodyResult, BodyLandmark, Box, Point, BodyAnnotation } from '../result';
2021-09-27 19:58:13 +02:00
import type { GraphModel, Tensor } from '../tfjs/types';
import type { Config } from '../config';
import * as coords from './blazeposecoords';
2021-11-22 20:33:40 +01:00
import * as detect from './blazeposedetector';
import * as box from '../util/box';
2021-09-27 19:58:13 +02:00
/** Shape shared by all per-model bookkeeping: one entry for the detector and one for the landmarks model */
type PerModel<T> = { detector: T, landmarks: T };

/** Module-local flag; while `initial` is true the loaders drop any cached model before loading */
const env = { initial: true };

/** Loaded graph models, `null` until the matching loader has run */
const models: PerModel<GraphModel | null> = { detector: null, landmarks: null };

/** Model input resolution; starts with defaults and is overwritten from the model signature on load */
const inputSize: PerModel<[number, number]> = { detector: [224, 224], landmarks: [256, 256] };

/** Names of the output nodes requested when executing each model */
const outputNodes: PerModel<string[]> = {
  landmarks: ['ld_3d', 'activation_segmentation', 'activation_heatmap', 'world_3d', 'output_poseflag'],
  detector: [],
};

/** Number of frames answered from cache since the last real inference; starts at max so the first frame is never skipped */
let skipped = Number.MAX_SAFE_INTEGER;

/** Timestamp (as returned by `now()`) of the last real inference */
let lastTime = 0;

/** Result of the last real inference, reused while frames are being skipped */
let cache: BodyResult | null = null;

/** Optional crop region applied before padding; used as [y0, x0, ...] offsets when rescaling keypoints */
let cropBox: Box | undefined;

/** Padding applied by the last image preparation, one [before, after] pair per tensor dimension */
let padding: [number, number][] = [[0, 0], [0, 0], [0, 0], [0, 0]];
2021-09-27 19:58:13 +02:00
2021-11-22 20:33:40 +01:00
/**
 * Logistic sigmoid used to turn raw model logits into a 0..1 score
 * Evaluated as `1 - 1 / (1 + e^x)`, which saturates to exactly 1 for large positive x and exactly 0 for large negative x
 * @param x raw logit
 * @returns value in range 0..1
 */
const sigmoid = (x: number): number => (1 - (1 / (1 + Math.exp(x))));
2021-09-27 19:58:13 +02:00
/**
 * Loads the optional body detector model and prepares its anchors
 *
 * The model is loaded only when none is cached and `config.body.detector.modelPath` is set;
 * its input resolution is then read from the model signature into `inputSize.detector`.
 * `detect.createAnchors()` is awaited on every call, whether or not a model was loaded.
 *
 * NOTE(review): `env.initial` is a module-local constant that is never cleared in this file,
 * so any cached detector is dropped at the start of every call — confirm this is intended.
 *
 * @param config library configuration
 * @returns the detector model; the value is `null` at runtime when no detector is configured, despite the declared type
 */
export async function loadDetect(config: Config): Promise<GraphModel> {
  if (env.initial) models.detector = null;
  // the original condition ended in `|| ''`, which due to operator precedence only ever substituted a falsy value for a falsy value
  const detectorPath = config.body['detector']?.modelPath;
  if (!models.detector && detectorPath) {
    models.detector = await loadModel(detectorPath);
    const inputs = Object.values(models.detector.modelSignature['inputs']);
    inputSize.detector[0] = Array.isArray(inputs) ? parseInt(inputs[0].tensorShape.dim[1].size) : 0;
    inputSize.detector[1] = Array.isArray(inputs) ? parseInt(inputs[0].tensorShape.dim[2].size) : 0;
  } else if (config.debug && models.detector) {
    log('cached model:', models.detector['modelUrl']);
  }
  await detect.createAnchors();
  return models.detector as GraphModel;
}
/**
 * Loads the pose landmarks model from `config.body.modelPath` and caches it in `models.landmarks`
 *
 * After loading, the input resolution is read from the model signature into `inputSize.landmarks`.
 * When a model is already cached it is returned as-is (and reported via `log` in debug mode).
 *
 * NOTE(review): `env.initial` is a module-local constant that is never cleared in this file,
 * so the cached model is dropped at the start of every call — confirm this is intended.
 *
 * @param config library configuration
 * @returns the landmarks model
 */
export async function loadPose(config: Config): Promise<GraphModel> {
  if (env.initial) models.landmarks = null;
  if (models.landmarks) { // already loaded: nothing to do besides optional logging
    if (config.debug) log('cached model:', models.landmarks['modelUrl']);
    return models.landmarks;
  }
  const loaded = await loadModel(config.body.modelPath);
  models.landmarks = loaded;
  const signatureInputs = Object.values(loaded.modelSignature['inputs']);
  const haveInputs = Array.isArray(signatureInputs);
  inputSize.landmarks[0] = haveInputs ? parseInt(signatureInputs[0].tensorShape.dim[1].size) : 0;
  inputSize.landmarks[1] = haveInputs ? parseInt(signatureInputs[0].tensorShape.dim[2].size) : 0;
  return loaded;
}
/**
 * Ensures both BlazePose models have been requested, detector first and landmarks second
 * A loader is only invoked when its model slot is still empty
 * @param config library configuration
 * @returns tuple of [detector, landmarks]; an entry is `null` when that model is not available
 */
export async function load(config: Config): Promise<[GraphModel | null, GraphModel | null]> {
  if (!models.detector) {
    await loadDetect(config);
  }
  if (!models.landmarks) {
    await loadPose(config);
  }
  const loaded: [GraphModel | null, GraphModel | null] = [models.detector, models.landmarks];
  return loaded;
}
/**
 * Prepares an image tensor for a model that takes a square input: optional crop, pad to square, resize and normalize
 *
 * - when `cropBox` is set the input is first cropped to it (the code that assigns `cropBox` is currently disabled in this file)
 * - the shorter side is zero-padded so the image becomes square; the padding is stored in module-level `padding`
 *   so that `rescaleKeypoints` can map results back to the unpadded image
 * - the image is resized to `size` x `size` when needed and divided by `constants.tf255`
 *
 * @param input image tensor indexed as [batch, height, width, channels]
 * @param size edge length of the square model input
 * @returns a new tensor that the caller owns and must dispose
 */
async function prepareImage(input: Tensor, size: number): Promise<Tensor> {
  // start every frame with "no padding": previously a square frame inherited the padding recorded for an earlier non-square frame
  padding = [[0, 0], [0, 0], [0, 0], [0, 0]];
  // the caller disposes whatever is returned, so hand back a clone instead of the caller's own tensor
  if (!input.shape || !input.shape[1] || !input.shape[2]) return tf.clone(input);
  const height = input.shape[1];
  const width = input.shape[2];
  const t: Record<string, Tensor> = {};
  if (cropBox) {
    t.cropped = tf.image.cropAndResize(input, [cropBox], [0], [height, width]); // if we have cached box use it to crop input
  }
  const source = t.cropped || input; // use cropped box if it exists
  let final: Tensor;
  if (height !== width) { // only pad if width different than height
    const padHeight = Math.max(width - height, 0); // total rows to add when image is wider than tall
    const padWidth = Math.max(height - width, 0); // total columns to add when image is taller than wide
    const padTop = Math.trunc(padHeight / 2);
    const padLeft = Math.trunc(padWidth / 2);
    padding = [
      [0, 0], // dont touch batch
      [padTop, padHeight - padTop], // height before&after; remainder goes after so odd differences still give an exact square
      [padLeft, padWidth - padLeft], // width before&after
      [0, 0], // dont touch rgb
    ];
    t.pad = tf.pad(source, padding);
    t.resize = tf.image.resizeBilinear(t.pad, [size, size]);
    final = tf.div(t.resize, constants.tf255);
  } else if (height !== size) { // if input needs resizing
    t.resize = tf.image.resizeBilinear(source, [size, size]);
    final = tf.div(t.resize, constants.tf255);
  } else { // if input is already in a correct resolution just normalize it
    final = tf.div(source, constants.tf255);
  }
  Object.keys(t).forEach((tensor) => tf.dispose(t[tensor])); // intermediates only; `final` is not tracked in `t`
  return final;
}
2022-08-21 19:34:51 +02:00
/**
 * Maps keypoints from the padded (and optionally cropped) model image back onto the original image
 *
 * Keypoints are modified in place: `position` is rewritten in output pixels and `positionRaw` as a fraction of the output size.
 * Uses the module-level `padding` recorded during image preparation and, when set, `cropBox`.
 *
 * @param keypoints keypoints with `position` expressed relative to the padded image
 * @param outputSize [width, height] of the original image
 * @returns the same array instance that was passed in
 */
function rescaleKeypoints(keypoints: BodyKeypoint[], outputSize: [number, number]): BodyKeypoint[] {
  const [outWidth, outHeight] = outputSize;
  const [padTop, padBottom] = padding[1];
  const [padLeft, padRight] = padding[2];
  const paddedWidth = outWidth + padLeft + padRight;
  const paddedHeight = outHeight + padTop + padBottom;
  const crop = cropBox;
  keypoints.forEach((kpt) => {
    // step 1: undo padding by stretching to the padded extent and shifting by the leading pad
    const x = Math.trunc(kpt.position[0] * paddedWidth / outWidth - padLeft);
    const y = Math.trunc(kpt.position[1] * paddedHeight / outHeight - padTop);
    const z = kpt.position[2] as number;
    kpt.position = [x, y, z];
    kpt.positionRaw = [x / outWidth, y / outHeight, 2 * z / (outWidth + outHeight)];
    if (!crop) return;
    // step 2: undo cropping by adding the crop origin (crop[1] is the x offset, crop[0] the y offset) and recomputing pixels
    const rawX = kpt.positionRaw[0] + crop[1];
    const rawY = kpt.positionRaw[1] + crop[0];
    const rawZ = kpt.positionRaw[2] as number;
    kpt.positionRaw = [rawX, rawY, rawZ];
    kpt.position = [Math.trunc(rawX * outWidth), Math.trunc(rawY * outHeight), rawZ];
  });
  return keypoints;
}
2021-10-15 15:34:40 +02:00
2022-08-21 19:34:51 +02:00
/**
 * Replaces the palm z-coordinate with the average of the wrist and index z-coordinates, for both hands
 *
 * The palm z-coord is incorrect around near-zero so it is approximated from its neighbours.
 * Keypoints are modified in place. A hand is left untouched when its palm, wrist or index keypoint is not present,
 * instead of throwing inside an async function whose caller may not be awaiting it.
 *
 * @param keypoints body keypoints to patch
 */
async function fixKeypoints(keypoints: BodyKeypoint[]): Promise<void> {
  const byPart = (part: string) => keypoints.find((k) => k.part === part);
  const approximatePalm = (palmPart: string, wristPart: string, indexPart: string): void => {
    const palm = byPart(palmPart);
    const wrist = byPart(wristPart);
    const index = byPart(indexPart);
    if (!palm || !wrist || !index) return; // incomplete hand: nothing to average
    palm.position[2] = ((wrist.position[2] || 0) + (index.position[2] || 0)) / 2;
  };
  approximatePalm('leftPalm', 'leftWrist', 'leftIndex');
  approximatePalm('rightPalm', 'rightWrist', 'rightIndex');
}
2021-11-22 20:33:40 +01:00
/**
 * Runs the landmarks model on a prepared image and converts its outputs into a body result
 *
 * @param input prepared (padded, resized, normalized) image tensor
 * @param config library configuration; `config.body.minConfidence` is the minimum accepted pose score
 * @param outputSize [width, height] of the original image, used to express keypoints in pixels
 * @returns detected body, or `null` when the pose score is below the configured minimum
 */
async function detectLandmarks(input: Tensor, config: Config, outputSize: [number, number]): Promise<BodyResult | null> {
  /**
   * t.ld: 39 keypoints [x,y,z,score,presence] normalized to input size
   * t.segmentation:
   * t.heatmap:
   * t.world: 39 keypoints [x,y,z] normalized to -1..1
   * t.poseflag: body score
   */
  const t: Record<string, Tensor> = {};
  [t.ld/* 1,195(39*5) */, t.segmentation/* 1,256,256,1 */, t.heatmap/* 1,64,64,39 */, t.world/* 1,117(39*3) */, t.poseflag/* 1,1 */] = models.landmarks?.execute(input, outputNodes.landmarks) as Tensor[]; // run model
  const poseScore = (await t.poseflag.data())[0];
  const points = await t.ld.data();
  const distances = await t.world.data();
  Object.keys(t).forEach((tensor) => tf.dispose(t[tensor])); // dont need tensors after this
  if (poseScore < (config.body.minConfidence || 0)) return null; // reject early, before doing any per-keypoint work
  const keypointsRelative: BodyKeypoint[] = [];
  const depth = 5; // each point in t.ld has x,y,z,visibility,presence
  const worldDepth = 3; // each point in t.world has only x,y,z so it needs its own stride
  for (let i = 0; i < points.length / depth; i++) {
    const score = sigmoid(points[depth * i + 3]);
    const presence = sigmoid(points[depth * i + 4]);
    const adjScore = Math.trunc(100 * score * presence * poseScore) / 100;
    const positionRaw: Point = [points[depth * i + 0] / inputSize.landmarks[0], points[depth * i + 1] / inputSize.landmarks[1], points[depth * i + 2] + 0];
    const position: Point = [Math.trunc(outputSize[0] * positionRaw[0]), Math.trunc(outputSize[1] * positionRaw[1]), positionRaw[2] as number];
    const distance: Point = [distances[worldDepth * i + 0], distances[worldDepth * i + 1], distances[worldDepth * i + 2] + 0];
    keypointsRelative.push({ part: coords.kpt[i] as BodyLandmark, positionRaw, position, distance, score: adjScore });
  }
  await fixKeypoints(keypointsRelative); // awaited so that a failure surfaces here instead of as an unhandled rejection
  const keypoints: BodyKeypoint[] = rescaleKeypoints(keypointsRelative, outputSize); // keypoints were relative to input image which is padded
  const kpts = keypoints.map((k) => k.position);
  const boxes = box.calc(kpts, [outputSize[0], outputSize[1]]); // now find boxes based on rescaled keypoints
  const annotations: Record<BodyAnnotation, Point[][]> = {} as Record<BodyAnnotation, Point[][]>;
  for (const [name, indexes] of Object.entries(coords.connected)) {
    const pt: Point[][] = [];
    for (let i = 0; i < indexes.length - 1; i++) { // one segment per pair of consecutive parts
      const pt0 = keypoints.find((kpt) => kpt.part === indexes[i]);
      const pt1 = keypoints.find((kpt) => kpt.part === indexes[i + 1]);
      if (pt0 && pt1) pt.push([pt0.position, pt1.position]);
    }
    annotations[name] = pt;
  }
  const body = { id: 0, score: Math.trunc(100 * poseScore) / 100, box: boxes.box, boxRaw: boxes.boxRaw, keypoints, annotations };
  return body;
}
/*
interface DetectedBox { box: Box, boxRaw: Box, score: number }
function rescaleBoxes(boxes: Array<DetectedBox>, outputSize: [number, number]): Array<DetectedBox> {
for (const b of boxes) {
b.box = [
Math.trunc(b.box[0] * (outputSize[0] + padding[2][0] + padding[2][1]) / outputSize[0]),
Math.trunc(b.box[1] * (outputSize[1] + padding[1][0] + padding[1][1]) / outputSize[1]),
Math.trunc(b.box[2] * (outputSize[0] + padding[2][0] + padding[2][1]) / outputSize[0]),
Math.trunc(b.box[3] * (outputSize[1] + padding[1][0] + padding[1][1]) / outputSize[1]),
];
b.boxRaw = [b.box[0] / outputSize[0], b.box[1] / outputSize[1], b.box[2] / outputSize[0], b.box[3] / outputSize[1]];
}
return boxes;
}
2021-11-22 20:33:40 +01:00
async function detectBoxes(input: Tensor, config: Config, outputSize: [number, number]) {
const t: Record<string, Tensor> = {};
t.res = models.detector?.execute(input, ['Identity']) as Tensor; //
t.logitsRaw = tf.slice(t.res, [0, 0, 0], [1, -1, 1]);
t.boxesRaw = tf.slice(t.res, [0, 0, 1], [1, -1, -1]);
t.logits = tf.squeeze(t.logitsRaw);
t.boxes = tf.squeeze(t.boxesRaw);
const boxes = await detect.decode(t.boxes, t.logits, config, outputSize);
rescaleBoxes(boxes, outputSize);
Object.keys(t).forEach((tensor) => tf.dispose(t[tensor]));
return boxes;
}
*/
2021-11-22 20:33:40 +01:00
2021-09-27 19:58:13 +02:00
/**
 * Detects a body pose in the input image, reusing the previous result while frame skipping applies
 *
 * The cached result is returned without running the model only when skipping is allowed by config,
 * less than `config.body.skipTime` has elapsed since the last inference, fewer than `config.body.skipFrames`
 * frames have been skipped, and a cached result exists.
 *
 * @param input image tensor indexed as [batch, height, width, channels]
 * @param config library configuration
 * @returns array with the single detected body, or an empty array when nothing was detected
 */
export async function predict(input: Tensor, config: Config): Promise<BodyResult[]> {
  const outputSize: [number, number] = [input.shape[2] || 0, input.shape[1] || 0];
  const withinTime = (config.body.skipTime || 0) > (now() - lastTime);
  const withinFrames = skipped < (config.body.skipFrames || 0);
  const reuseCache = config.skipAllowed && withinTime && withinFrames && cache !== null;
  if (reuseCache) {
    skipped++;
    return cache ? [cache] : [];
  }
  const t: Record<string, Tensor> = {};
  /*
  // detector stage is currently disabled
  if (config.body['detector'] && config.body['detector']['enabled']) {
    t.detector = await prepareImage(input, 224);
    const boxes = await detectBoxes(t.detector, config, outputSize);
  }
  */
  t.landmarks = await prepareImage(input, 256); // padded and resized; resolution is hard-coded here
  cache = await detectLandmarks(t.landmarks, config, outputSize);
  /*
  // crop tracking is currently disabled
  cropBox = [0, 0, 1, 1]; // reset crop coordinates
  if (cache?.boxRaw && config.skipAllowed) {
    const cx = (2.0 * cache.boxRaw[0] + cache.boxRaw[2]) / 2;
    const cy = (2.0 * cache.boxRaw[1] + cache.boxRaw[3]) / 2;
    let size = cache.boxRaw[2] > cache.boxRaw[3] ? cache.boxRaw[2] : cache.boxRaw[3];
    size = (size * 1.0) / 2; // enlarge and half it
    if (cx > 0.1 && cx < 0.9 && cy > 0.1 && cy < 0.9 && size > 0.1) { // only update if box is sane
      const y = 0; // cy - size;
      const x = cx - size;
      cropBox = [y, x, y + 1, x + 1]; // [y0,x0,y1,x1] used for cropping but width/height are not yet implemented so we only reposition image to center of body
    }
  }
  */
  for (const name of Object.keys(t)) tf.dispose(t[name]);
  lastTime = now();
  skipped = 0;
  return cache ? [cache] : [];
}