2021-05-25 14:58:20 +02:00
|
|
|
/**
 * EfficientPose Module: body keypoint detection using the EfficientPose graph model
 */
|
|
|
|
|
2021-04-09 14:07:58 +02:00
|
|
|
import { log, join } from '../helpers';
|
2021-03-26 23:50:19 +01:00
|
|
|
import * as tf from '../../dist/tfjs.esm.js';
|
2021-05-22 18:33:19 +02:00
|
|
|
import { Body } from '../result';
|
2021-06-03 15:41:53 +02:00
|
|
|
import { GraphModel, Tensor } from '../tfjs/types';
|
|
|
|
import { Config } from '../config';
|
2021-03-26 23:50:19 +01:00
|
|
|
|
2021-05-23 03:47:59 +02:00
|
|
|
// graph model instance; populated lazily by load()
let model: GraphModel;

/** A single detected body keypoint. */
interface Keypoints {
  score: number,
  part: string,
  position: [number, number],
  positionRaw: [number, number],
}

// results of the most recent inference, returned again while frames are being skipped
const keypoints: Keypoints[] = [];
let box: [number, number, number, number] = [0, 0, 0, 0];
let boxRaw: [number, number, number, number] = [0, 0, 0, 0];
let score = 0;

// frames served from cache since the last full inference; starts maxed so the first call always runs the model
let skipped = Number.MAX_SAFE_INTEGER;

// part name for each output channel index, in channel order
const bodyParts = [
  'head', 'neck',
  'rightShoulder', 'rightElbow', 'rightWrist',
  'chest',
  'leftShoulder', 'leftElbow', 'leftWrist',
  'pelvis',
  'rightHip', 'rightKnee', 'rightAnkle',
  'leftHip', 'leftKnee', 'leftAnkle',
];
|
|
|
|
|
2021-06-03 15:41:53 +02:00
|
|
|
/**
 * Loads the EfficientPose graph model once and caches it at module level.
 *
 * @param config reads `modelBasePath`, `body.modelPath` and `debug`
 * @returns the cached model instance (may be unusable if loading failed; a message is logged in that case)
 */
export async function load(config: Config): Promise<GraphModel> {
  // already loaded: reuse the cached instance
  if (model) {
    if (config.debug) log('cached model:', model['modelUrl']);
    return model;
  }
  // @ts-ignore type mismatch on GraphModel
  model = await tf.loadGraphModel(join(config.modelBasePath, config.body.modelPath));
  const loaded = Boolean(model && model['modelUrl']);
  if (!loaded) log('load model failed:', config.body.modelPath);
  else if (config.debug) log('load model:', model['modelUrl']);
  return model;
}
|
|
|
|
|
|
|
|
/**
 * Finds the location and value of the maximum of a 2D heatmap tensor.
 *
 * @param inputs 2D tensor with shape [rows, cols]
 * @param minScore coordinates are only computed when the maximum exceeds this value
 * @returns `[x, y, score]` where `x` is the column and `y` the row of the maximum;
 *          `[0, 0, score]` when the maximum does not exceed `minScore`
 */
function max2d(inputs: Tensor, minScore: number) {
  // a 2D tensor's shape is [rows, cols], i.e. [height, width]
  const [height, width] = inputs.shape;
  return tf.tidy(() => {
    // flatten row-major so a single argmax covers the whole heatmap
    const reshaped = tf.reshape(inputs, [height * width]);
    // get highest score (dataSync inside tf.tidy so intermediates are released)
    const newScore = tf.max(reshaped, 0).dataSync()[0];
    // skip coordinate calculation if score is too low
    if (newScore <= minScore) return [0, 0, newScore];
    // flat row-major index -> (column, row)
    const index = tf.argMax(reshaped, 0).dataSync()[0];
    const x = index % width;
    const y = Math.floor(index / width);
    return [x, y, newScore];
  });
}
|
|
|
|
|
2021-06-03 15:41:53 +02:00
|
|
|
/**
 * Computes an axis-aligned bounding box `[x, y, width, height]` from keypoint coordinates.
 * Returns `[0, 0, 0, 0]` when there are no points (avoids Infinity from Math.min/max of nothing).
 */
function getBoundingBox(xs: number[], ys: number[]): [number, number, number, number] {
  if (xs.length === 0 || ys.length === 0) return [0, 0, 0, 0];
  const minX = Math.min(...xs);
  const minY = Math.min(...ys);
  return [minX, minY, Math.max(...xs) - minX, Math.max(...ys) - minY];
}

/**
 * Runs EfficientPose on an image and returns a single detected body.
 *
 * While `config.skipFrame` is set, fewer than `config.body.skipFrames` frames have been skipped,
 * and a previous result exists, the cached result is returned without running the model.
 *
 * @param image input image tensor; x/y scaling below treats it as [batch, height, width, channels] (TODO confirm)
 * @param config reads `skipFrame`, `body.enabled`, `body.skipFrames` and `body.minConfidence`
 * @returns one body containing every keypoint whose own score exceeds `body.minConfidence`
 */
export async function predict(image: Tensor, config: Config): Promise<Body[]> {
  if ((skipped < config.body.skipFrames) && config.skipFrame && keypoints.length > 0) {
    skipped++;
    return [{ id: 0, score, box, boxRaw, keypoints }];
  }
  skipped = 0;

  const tensor = tf.tidy(() => {
    if (!model.inputs[0].shape) return null;
    // resizeBilinear expects [newHeight, newWidth]; model input shape is treated as [batch, height, width, channels]
    const resize = tf.image.resizeBilinear(image, [model.inputs[0].shape[1], model.inputs[0].shape[2]], false);
    // scale input from 0..1 to -1..1
    const enhance = tf.mul(resize, 2);
    const norm = enhance.sub(1);
    return norm;
  });

  let resT;
  // never feed a null tensor to the model (model input shape was unknown)
  if (config.body.enabled && tensor) resT = await model.predict(tensor);
  tf.dispose(tensor);

  if (resT) {
    keypoints.length = 0;
    const squeeze = resT.squeeze();
    tf.dispose(resT);
    // body parts are basically just a stack of 2d tensors along axis 2
    const stack = squeeze.unstack(2);
    tf.dispose(squeeze);
    // process each unstacked tensor as a separate body part
    for (let id = 0; id < stack.length; id++) {
      const [x, y, partScore] = max2d(stack[id], config.body.minConfidence);
      // threshold on this part's own score, not the previous frame's overall score
      if (partScore > config.body.minConfidence) {
        keypoints.push({
          score: Math.round(100 * partScore) / 100,
          part: bodyParts[id],
          positionRaw: [ // normalized to 0..1
            // @ts-ignore model is not undefined here
            x / model.inputs[0].shape[2], y / model.inputs[0].shape[1],
          ],
          position: [ // normalized to input image size
            // @ts-ignore model is not undefined here
            Math.round(image.shape[2] * x / model.inputs[0].shape[2]), Math.round(image.shape[1] * y / model.inputs[0].shape[1]),
          ],
        });
      }
    }
    stack.forEach((s) => tf.dispose(s));
  }

  // overall score is the best keypoint score
  score = keypoints.reduce((prev, curr) => (curr.score > prev ? curr.score : prev), 0);
  box = getBoundingBox(keypoints.map((a) => a.position[0]), keypoints.map((a) => a.position[1]));
  boxRaw = getBoundingBox(keypoints.map((a) => a.positionRaw[0]), keypoints.map((a) => a.positionRaw[1]));
  return [{ id: 0, score, box, boxRaw, keypoints }];
}
|