human/src/body/efficientpose.ts

127 lines
5.0 KiB
TypeScript
Raw Normal View History

2021-05-25 14:58:20 +02:00
/**
 * EfficientPose model implementation
 *
 * Based on: [**EfficientPose**](https://github.com/daniegr/EfficientPose)
 */
2021-09-27 19:58:13 +02:00
import { log, join } from '../util/util';
2021-03-26 23:50:19 +01:00
import * as tf from '../../dist/tfjs.esm.js';
import * as coords from './efficientposecoords';
import type { BodyKeypoint, BodyResult, Box, Point } from '../result';
2021-09-13 19:28:35 +02:00
import type { GraphModel, Tensor } from '../tfjs/types';
import type { Config } from '../config';
2021-09-27 19:58:13 +02:00
import { env } from '../util/env';
2021-03-26 23:50:19 +01:00
2021-09-17 17:23:00 +02:00
/** Cached EfficientPose graph model; assigned by `load`. */
let model: GraphModel | null;

/** Keypoints from the most recent inference; returned again while frames are being skipped. */
const keypoints: BodyKeypoint[] = [];

/** Bounding box of the last result in input-image pixels. */
let box: Box = [0, 0, 0, 0];
/** Bounding box of the last result in normalized coordinates. */
let boxRaw: Box = [0, 0, 0, 0];
/** Highest keypoint score of the last result. */
let score = 0;

/** Frames skipped since the last full inference; starts at the maximum so the first call always runs the model. */
let skipped = Number.MAX_SAFE_INTEGER;
2021-06-03 15:41:53 +02:00
/**
 * Loads the EfficientPose graph model from `config.modelBasePath` joined with `config.body.modelPath`.
 *
 * The model is cached at module level and returned directly on later calls; the cache is discarded
 * whenever `env.initial` is set. A loaded model without a `modelUrl` is only logged as a failure,
 * not thrown; errors raised by `tf.loadGraphModel` itself propagate to the caller.
 *
 * @param config library configuration; reads `modelBasePath`, `body.modelPath` and `debug`
 * @returns the freshly loaded or cached graph model
 */
export async function load(config: Config): Promise<GraphModel> {
  if (env.initial) model = null; // environment was (re)initialized: drop any cached model

  if (model) {
    if (config.debug) log('cached model:', model['modelUrl']);
    return model;
  }

  const url = join(config.modelBasePath, config.body.modelPath || '');
  model = await tf.loadGraphModel(url) as unknown as GraphModel;
  const loaded = Boolean(model && model['modelUrl']);
  if (!loaded) log('load model failed:', config.body.modelPath);
  else if (config.debug) log('load model:', model['modelUrl']);
  return model;
}
/**
 * Finds the peak of a single 2D heatmap (argmax + max).
 *
 * @param inputs 2D heatmap tensor; its shape is `[rows, cols]`, i.e. `[height, width]`
 * @param minScore the peak must exceed this value before its coordinates are computed
 * @returns `[x, y, score]`, where `x` is the column and `y` the row of the peak in heatmap pixels;
 *   `[0, 0, score]` when the peak does not exceed `minScore`
 */
function max2d(inputs: Tensor, minScore: number): [number, number, number] {
  // a 2d tensor shape is [rows, cols]: rows run along y (height), columns along x (width)
  const [height, width] = inputs.shape;
  return tf.tidy(() => {
    const reshaped = tf.reshape(inputs, [height * width]); // flatten in row-major order
    const newScore = tf.max(reshaped, 0).dataSync()[0]; // highest score; intermediates are released by tf.tidy
    if (newScore > minScore) { // skip coordinate calculation if score is too low
      const index = tf.argMax(reshaped, 0).dataSync()[0]; // flat index = row * width + column
      return [index % width, Math.floor(index / width), newScore];
    }
    return [0, 0, newScore];
  });
}
2021-09-12 05:54:35 +02:00
/** Returns the `[x, y, width, height]` box enclosing the given coordinates, or `[0, 0, 0, 0]` when there are none. */
function boundingBox(xs: number[], ys: number[]): Box {
  if (xs.length === 0 || ys.length === 0) return [0, 0, 0, 0]; // Math.min()/Math.max() of nothing is +/-Infinity
  const minX = Math.min(...xs);
  const minY = Math.min(...ys);
  return [minX, minY, Math.max(...xs) - minX, Math.max(...ys) - minY];
}

/**
 * Runs EfficientPose on an image and returns a single body result.
 *
 * When `config.skipFrame` is set, fewer than `config.body.skipFrames` frames have been skipped since the
 * last inference and keypoints are cached, the cached result is returned (with empty annotations)
 * without running the model. Inference runs only when the model is loaded and `config.body.enabled` is set;
 * otherwise the previously stored keypoints are used to build the result.
 *
 * @param image input image tensor; `shape[1]` is read as height and `shape[2]` as width
 * @param config library configuration; reads `skipFrame`, `body.skipFrames`, `body.enabled`, `body.minConfidence`
 * @returns one-element array with keypoints, overall score, pixel/normalized boxes and skeleton annotations
 */
export async function predict(image: Tensor, config: Config): Promise<BodyResult[]> {
  if ((skipped < (config.body?.skipFrames || 0)) && config.skipFrame && keypoints.length > 0) {
    skipped++;
    return [{ id: 0, score, box, boxRaw, keypoints, annotations: {} }];
  }
  skipped = 0;

  const minConfidence = config.body?.minConfidence || 0;
  // model input is presumably NHWC: [batch, height, width, channels] — TODO confirm against the model
  const inputShape = model?.inputs[0].shape;

  if (model && inputShape && config.body.enabled) {
    const tensor = tf.tidy(() => {
      const resize = tf.image.resizeBilinear(image, [inputShape[1], inputShape[2]], false); // size is [height, width]
      return tf.sub(tf.mul(resize, 2), 1); // rescale pixel values as x * 2 - 1
    });
    const resT = await model.predict(tensor) as Tensor;
    tf.dispose(tensor);

    keypoints.length = 0;
    const squeeze = tf.squeeze(resT);
    tf.dispose(resT);
    // body parts are a stack of 2d heatmaps along the last axis
    const heatmaps = tf.unstack(squeeze, 2);
    tf.dispose(squeeze);
    for (let id = 0; id < heatmaps.length; id++) {
      const [rows, cols] = heatmaps[id].shape;
      const [x, y, partScore] = max2d(heatmaps[id], minConfidence);
      if (partScore > minConfidence) { // filter on this part's own score, not the previous frame's overall score
        keypoints.push({
          score: Math.round(100 * partScore) / 100,
          part: coords.kpt[id],
          positionRaw: [x / cols, y / rows], // normalized to 0..1 using the heatmap the peak came from
          position: [Math.round(image.shape[2] * x / cols), Math.round(image.shape[1] * y / rows)], // scaled to input image size
        });
      }
    }
    tf.dispose(heatmaps);
  }

  score = keypoints.reduce((prev, curr) => (curr.score > prev ? curr.score : prev), 0);
  box = boundingBox(keypoints.map((kpt) => kpt.position[0]), keypoints.map((kpt) => kpt.position[1]));
  boxRaw = boundingBox(keypoints.map((kpt) => kpt.positionRaw[0]), keypoints.map((kpt) => kpt.positionRaw[1]));

  // connect consecutive parts listed in coords.connected when both ends are confident enough
  const annotations: Record<string, Point[][]> = {};
  for (const [name, indexes] of Object.entries(coords.connected)) {
    const pt: Point[][] = [];
    for (let i = 0; i < indexes.length - 1; i++) {
      const pt0 = keypoints.find((kpt) => kpt.part === indexes[i]);
      const pt1 = keypoints.find((kpt) => kpt.part === indexes[i + 1]);
      if (pt0 && pt1 && pt0.score > minConfidence && pt1.score > minConfidence) pt.push([pt0.position, pt1.position]);
    }
    annotations[name] = pt;
  }
  return [{ id: 0, score, box, boxRaw, keypoints, annotations }];
}