// human/src/blazepose/blazepose.ts
import { log } from '../log';
import * as tf from '../../dist/tfjs.esm.js';
import * as profile from '../profile';
import * as annotations from './annotations';
let model;

/**
 * Parse one `tensorShape.dim[].size` entry from a model signature.
 * Sizes may be serialized as strings; dynamic (-1), missing or unparsable sizes yield undefined
 * so callers can fall back to a configured input size.
 */
function staticDim(size): number | undefined {
  const value = parseInt(size, 10);
  return value > 0 ? value : undefined;
}

/**
 * Short model name for log output: the text between the first '/' and the last '.' of the path,
 * or the whole path when it does not contain that pattern.
 */
function modelName(modelPath: string): string {
  const match = modelPath.match(/\/(.*)\./);
  return match ? match[1] : modelPath;
}

/**
 * Load the BlazePose graph model once and cache it in module scope.
 *
 * The model's expected input resolution is read from its input signature and stored as
 * `model.width` / `model.height`. When the signature does not declare a usable static size,
 * those fields stay undefined and `predict` falls back to `config.body.inputSize`.
 *
 * @param config configuration object; reads `config.body.modelPath` and `config.debug`
 * @returns the cached graph model
 */
export async function load(config) {
  if (!model) {
    model = await tf.loadGraphModel(config.body.modelPath);
    // blazepose inputSize is 256x256px, but we can find that out dynamically;
    // the input dims are indexed as [batch, height, width, channels]
    const dims = model.signature?.inputs?.['input_1:0']?.tensorShape?.dim;
    model.width = staticDim(dims?.[2]?.size);
    model.height = staticDim(dims?.[1]?.size);
    if (config.debug) log(`load model: ${modelName(config.body.modelPath)}`);
  }
  return model;
}
/**
 * Convert a raw model logit to a probability with a logistic sigmoid, truncated to two decimals.
 * 100 / (1 + e^v) equals 100 * sigmoid(-v), so subtracting it from 100 gives 100 * sigmoid(v).
 */
function sigmoidScore(logit: number): number {
  return (100 - Math.trunc(100 / (1 + Math.exp(logit)))) / 100;
}

/**
 * Read the landmark values out of the model outputs, then dispose every output tensor.
 * Output order may change between models, so the landmark tensor is matched by size:
 * the full model emits 195 values (39 keypoints x 5) and the upper-body model 155 (31 x 5).
 * The segmentation output, if any, is ignored.
 * @returns the landmark values, or undefined when no output has a recognised size
 */
function extractLandmarks(outputs) {
  const tensors = Array.isArray(outputs) ? outputs : [outputs];
  try {
    const landmarks = tensors.find((t) => (t.size === 195 || t.size === 155));
    return landmarks ? landmarks.dataSync() : undefined; // dataSync copies, so disposing afterwards is safe
  } finally {
    tensors.forEach((t) => t.dispose());
  }
}

/**
 * Run BlazePose on an image tensor and decode its keypoints.
 *
 * @param image input image tensor; shape[1] and shape[2] are read as height and width
 *   (NOTE(review): assumes an NHWC batch with 0..255 pixel values — confirm with callers)
 * @param config configuration object; reads `config.body.enabled`, `config.body.inputSize` and `config.profile`
 * @returns `[{ keypoints }]` with x/y in original-image pixels, or null when the model is not loaded,
 *   body detection is disabled, or the model produced no recognisable landmark tensor
 */
export async function predict(image, config) {
  if (!model) return null;
  if (!config.body.enabled) return null;
  const imgSize = { width: image.shape[2], height: image.shape[1] };
  const inputSize = { width: model.width || config.body.inputSize, height: model.height || config.body.inputSize };
  // tf.image.resizeBilinear expects the target size as [newHeight, newWidth]
  const resize = tf.image.resizeBilinear(image, [inputSize.height, inputSize.width], false);
  const normalize = tf.div(resize, [255.0]);
  resize.dispose();
  let points;
  try {
    if (!config.profile) { // run through profiler or just execute
      const resT = await model.predict(normalize);
      points = extractLandmarks(resT);
    } else {
      const profileData = await tf.profile(() => model.predict(normalize));
      points = extractLandmarks(profileData.result);
      profile.run('blazepose', profileData);
    }
  } finally {
    normalize.dispose(); // released even when inference throws
  }
  if (!points) {
    log('blazepose: model returned no landmark tensor with 195 or 155 values');
    return null;
  }
  const keypoints: Array<{ id, part, position: { x, y, z }, score, presence }> = [];
  const labels = points.length === 195 ? annotations.full : annotations.upper; // full model has 39 keypoints, upper has 31 keypoints
  const depth = 5; // each point has x, y, z, visibility, presence
  for (let i = 0; i < points.length / depth; i++) {
    const offset = depth * i;
    keypoints.push({
      id: i,
      part: labels[i],
      position: {
        // assumes x/y are emitted in model-input pixels (0..inputSize), as in MediaPipe's landmark
        // decoding; rescale them to pixels of the original image
        x: Math.trunc(imgSize.width * points[offset + 0] / inputSize.width),
        y: Math.trunc(imgSize.height * points[offset + 1] / inputSize.height),
        z: Math.trunc(points[offset + 2]) + 0, // fix negative zero
      },
      score: sigmoidScore(points[offset + 3]),
      presence: sigmoidScore(points[offset + 4]),
    });
  }
  return [{ keypoints }];
}
/*
Model card:
- https://drive.google.com/file/d/10IU-DRP2ioSNjKFdiGbmmQX81xAYj88s/view
Download:
- https://github.com/PINTO0309/PINTO_model_zoo/tree/main/058_BlazePose_Full_Keypoints/10_new_256x256/saved_model/tfjs_model_float16
- https://github.com/PINTO0309/PINTO_model_zoo/tree/main/053_BlazePose/20_new_256x256/saved_model/tfjs_model_float16
*/