2021-04-09 14:07:58 +02:00
|
|
|
import { log, join } from '../helpers';
|
2020-11-18 14:26:28 +01:00
|
|
|
import * as tf from '../../dist/tfjs.esm.js';
|
2021-04-24 17:49:26 +02:00
|
|
|
import * as posenetModel from './posenetModel';
|
2020-11-10 02:13:38 +01:00
|
|
|
import * as decodeMultiple from './decodeMultiple';
|
2021-04-24 17:49:26 +02:00
|
|
|
import * as decodeSingle from './decodeSingle';
|
|
|
|
import * as util from './utils';
|
2020-10-12 01:22:43 +02:00
|
|
|
|
2021-04-12 14:29:52 +02:00
|
|
|
// Module-level cache of the loaded graph model: assigned by load() on its first call
// and reused (not reloaded) by subsequent load() calls.
let model;
|
|
|
|
|
2021-03-11 16:26:14 +01:00
|
|
|
/**
 * Decodes up to `config.body.maxDetections` poses from the PoseNet head outputs and rescales
 * them from the model's (square) input resolution to the spatial size of `input`.
 *
 * Implemented as a plain async function rather than `new Promise(async (resolve) => ...)`:
 * with an async executor, an error thrown while reading buffers or decoding was never
 * forwarded to the returned promise, which then stayed pending forever. Here any such
 * error rejects the returned promise so callers can handle it.
 *
 * @param input input image tensor; dims [1] and [2] are used as the target size
 *   (assumed [batch, height, width, channels] — TODO confirm)
 * @param res base-model outputs: heatmapScores, offsets, displacementFwd, displacementBwd
 * @param config reads body.nmsRadius, body.maxDetections and body.scoreThreshold
 * @param inputSize model input resolution, used for both source dimensions when scaling
 * @returns the result of util.scalePoses for the decoded poses
 */
async function estimateMultiple(input, res, config, inputSize) {
  // Read all four output tensors into CPU-side buffers in parallel.
  const [scoresBuffer, offsetsBuffer, displacementsFwdBuffer, displacementsBwdBuffer] = await Promise.all(
    [res.heatmapScores, res.offsets, res.displacementFwd, res.displacementBwd].map((tensor) => tensor.buffer()),
  );
  const poses = await decodeMultiple.decodeMultiplePoses(
    scoresBuffer,
    offsetsBuffer,
    displacementsFwdBuffer,
    displacementsBwdBuffer,
    config.body.nmsRadius,
    config.body.maxDetections,
    config.body.scoreThreshold,
  );
  return util.scalePoses(poses, [input.shape[1], input.shape[2]], [inputSize, inputSize]);
}
|
|
|
|
|
2021-03-11 16:26:14 +01:00
|
|
|
/**
 * Decodes a single pose from the heatmap and offset outputs and rescales it from the model's
 * (square) input resolution to the spatial size of `input`.
 *
 * Implemented as a plain async function rather than `new Promise(async (resolve) => ...)`:
 * with an async executor, an error from decoding never reached the returned promise, which
 * then stayed pending forever. Here such an error rejects the returned promise.
 *
 * @param input input image tensor; dims [1] and [2] are used as the target size
 *   (assumed [batch, height, width, channels] — TODO confirm)
 * @param res base-model outputs; only heatmapScores and offsets are read here
 * @param config reads body.scoreThreshold
 * @param inputSize model input resolution, used for both source dimensions when scaling
 * @returns the result of util.scalePoses for a one-element list holding the decoded pose
 */
async function estimateSingle(input, res, config, inputSize) {
  const pose = await decodeSingle.decodeSinglePose(res.heatmapScores, res.offsets, config.body.scoreThreshold);
  // scalePoses operates on a list of poses, so wrap the single result.
  return util.scalePoses([pose], [input.shape[1], input.shape[2]], [inputSize, inputSize]);
}
|
|
|
|
|
2021-02-08 17:39:09 +01:00
|
|
|
/**
 * Pose estimator that runs a wrapped base model and decodes its outputs into poses.
 *
 * NOTE(review): `inputSize` is read from `baseModel.model.inputs[0].shape[1]` and is used for
 * both dimensions when rescaling poses, i.e. the model input is assumed to be square — confirm
 * this holds for any new model.
 */
export class PoseNet {
  // Wrapped base model; this class calls its `predict(input, config)`, `dispose()` and reads `model.inputs`.
  baseModel: any;

  // Model input resolution, taken from the base model's first input shape (index 1).
  inputSize: number;

  constructor(baseModel) {
    this.baseModel = baseModel;
    this.inputSize = baseModel.model.inputs[0].shape[1];
  }

  /**
   * Runs the base model on `input` and decodes the result into poses.
   *
   * Uses single-pose decoding when `config.body.maxDetections < 2`, otherwise multi-pose
   * decoding. The four output tensors returned by `predict` are disposed in a `finally`, so
   * they are released even when decoding throws (previously they leaked on that path).
   *
   * @param input image tensor passed to the base model
   * @param config configuration; `config.body.maxDetections` selects the decoder and the whole
   *   object is forwarded to `predict` and the decoders
   * @returns the decoded, rescaled poses
   */
  async estimatePoses(input, config) {
    const res = this.baseModel.predict(input, config);
    try {
      // `await` inside `try` so the tensors are released only after decoding has finished.
      return config.body.maxDetections < 2
        ? await estimateSingle(input, res, config, this.inputSize)
        : await estimateMultiple(input, res, config, this.inputSize);
    } finally {
      res.heatmapScores.dispose();
      res.offsets.dispose();
      res.displacementFwd.dispose();
      res.displacementBwd.dispose();
    }
  }

  /** Releases the wrapped base model. */
  dispose() {
    this.baseModel.dispose();
  }
}
|
2020-11-08 18:26:45 +01:00
|
|
|
|
2021-02-08 17:39:09 +01:00
|
|
|
/**
 * Returns a new PoseNet wrapper around the module-level graph model.
 *
 * On the first call the graph model is loaded from `join(config.modelBasePath, config.body.modelPath)`
 * and cached; later calls reuse the cached model. Each call builds a fresh
 * `posenetModel.BaseModel` and `PoseNet` around it.
 *
 * NOTE(review): a failed load (falsy model or missing `modelUrl`) is only logged — a PoseNet is
 * still constructed around whatever was returned, and that value stays cached.
 */
export async function load(config) {
  if (model) {
    if (config.debug) log('cached model:', model.modelUrl);
  } else {
    model = await tf.loadGraphModel(join(config.modelBasePath, config.body.modelPath));
    const loadedOk = model && model.modelUrl;
    if (!loadedOk) log('load model failed:', config.body.modelPath);
    else if (config.debug) log('load model:', model.modelUrl);
  }
  return new PoseNet(new posenetModel.BaseModel(model));
}
|