2021-05-25 14:58:20 +02:00
|
|
|
/**
|
|
|
|
* PoseNet module entry point
|
|
|
|
*/
|
|
|
|
|
2021-04-09 14:07:58 +02:00
|
|
|
import { log, join } from '../helpers';
|
2020-11-18 14:26:28 +01:00
|
|
|
import * as tf from '../../dist/tfjs.esm.js';
|
2021-04-24 22:04:49 +02:00
|
|
|
import * as poses from './poses';
|
2021-04-24 17:49:26 +02:00
|
|
|
import * as util from './utils';
|
2021-05-22 18:33:19 +02:00
|
|
|
import { Body } from '../result';
|
2021-05-23 03:47:59 +02:00
|
|
|
import { Tensor, GraphModel } from '../tfjs/types';
|
2020-10-12 01:22:43 +02:00
|
|
|
|
2021-05-23 03:47:59 +02:00
|
|
|
/** PoseNet graph model instance; assigned and cached by `load()`. */
let model: GraphModel;

/**
 * Output node names requested from `model.execute()`.
 * `predict()` passes the resulting buffers to `poses.decode` in this same order.
 */
const poseNetOutputs = [
  'MobilenetV1/offset_2/BiasAdd', // offsets
  'MobilenetV1/heatmap_2/BiasAdd', // heatmapScores
  'MobilenetV1/displacement_fwd_2/BiasAdd', // displacementFwd
  'MobilenetV1/displacement_bwd_2/BiasAdd', // displacementBwd
];
|
|
|
|
|
2021-05-22 18:33:19 +02:00
|
|
|
/**
 * Runs PoseNet on an image tensor and returns the detected bodies.
 *
 * Pipeline:
 * 1. Resize the input to the model's input resolution.
 * 2. Normalize to [-1, 1] and execute the graph.
 * 3. Download the four outputs listed in `poseNetOutputs`.
 * 4. Decode them with `poses.decode`.
 * 5. Rescale the poses back to the input image size with `util.scalePoses`.
 *
 * @param input image tensor; presumably NHWC with batch size 1 (only
 *   `input.shape[1]` / `input.shape[2]` are read here) — TODO confirm
 * @param config reads `config.body.maxDetected` and `config.body.minConfidence`
 * @returns decoded bodies, or an empty array when the model is not loaded or
 *   its input shape is unknown
 */
export async function predict(input, config): Promise<Body[]> {
  // Validate before any tensor work. Previously a missing shape made tidy()
  // return [] and decode() was then called with undefined buffers.
  if (!model) return [];
  const inputShape = model.inputs[0].shape;
  if (!inputShape) return [];

  const res = tf.tidy(() => {
    // NOTE(review): size is passed as [shape[2], shape[1]]. For an NHWC model this is
    // [width, height], while resizeBilinear expects [height, width]. That only matters
    // for non-square models; it is kept as-is because scalePoses below uses the same ordering.
    const resized = input.resizeBilinear([inputShape[2], inputShape[1]]);
    // Maps [0, 255] to [-1, 1]; assumes 8-bit pixel values — TODO confirm.
    const normalized = resized.toFloat().div(127.5).sub(1.0);
    const results: Array<Tensor> = model.execute(normalized, poseNetOutputs) as Array<Tensor>;
    const results3d = results.map((y) => y.squeeze([0])); // drop the batch dimension
    results3d[1] = results3d[1].sigmoid(); // apply sigmoid on scores
    return results3d;
  });

  // Download the outputs, then release the tensors that tidy() returned to us.
  const buffers = await Promise.all(res.map((tensor) => tensor.buffer()));
  for (const t of res) t.dispose();

  const decoded = await poses.decode(buffers[0], buffers[1], buffers[2], buffers[3], config.body.maxDetected, config.body.minConfidence);
  const scaled = util.scalePoses(decoded, [input.shape[1], input.shape[2]], [inputShape[2], inputShape[1]]) as Body[];
  return scaled;
}
|
2020-11-08 18:26:45 +01:00
|
|
|
|
2021-02-08 17:39:09 +01:00
|
|
|
/**
 * Loads the PoseNet graph model on first call and caches it at module level.
 *
 * The model is fetched from `join(config.modelBasePath, config.body.modelPath)`.
 * Later calls return the cached instance without reloading. A loaded model that
 * lacks `modelUrl` is logged as a failed load, but it is still cached and returned.
 *
 * @param config reads `modelBasePath`, `body.modelPath` and `debug`
 * @returns the loaded or cached graph model
 */
export async function load(config) {
  if (model) {
    if (config.debug) log('cached model:', model['modelUrl']);
    return model;
  }

  // @ts-ignore type mismatch for GraphModel
  model = await tf.loadGraphModel(join(config.modelBasePath, config.body.modelPath));

  const url = model ? model['modelUrl'] : undefined;
  if (!url) {
    log('load model failed:', config.body.modelPath);
  } else if (config.debug) {
    log('load model:', url);
  }
  return model;
}
|