human/src/body/modelPoseNet.js

48 lines
1.9 KiB
JavaScript
Raw Normal View History

2020-11-18 14:26:28 +01:00
import * as tf from '../../dist/tfjs.esm.js';
2020-11-10 02:13:38 +01:00
import * as modelMobileNet from './modelMobileNet';
import * as decodeMultiple from './decodeMultiple';
import * as util from './util';
2020-10-12 01:22:43 +02:00
class PoseNet {
constructor(net) {
2020-10-12 01:22:43 +02:00
this.baseModel = net;
2020-11-08 18:26:45 +01:00
this.outputStride = 16;
2020-10-12 01:22:43 +02:00
}
2020-10-14 17:43:33 +02:00
async estimatePoses(input, config) {
2020-11-06 17:39:39 +01:00
return new Promise(async (resolve) => {
const height = input.shape[1];
const width = input.shape[2];
2020-11-08 18:32:31 +01:00
const resized = util.resizeTo(input, [config.body.inputSize, config.body.inputSize]);
2020-11-08 18:26:45 +01:00
const res = this.baseModel.predict(resized);
const allTensorBuffers = await util.toTensorBuffers3D([res.heatmapScores, res.offsets, res.displacementFwd, res.displacementBwd]);
2020-11-06 17:39:39 +01:00
const scoresBuffer = allTensorBuffers[0];
const offsetsBuffer = allTensorBuffers[1];
const displacementsFwdBuffer = allTensorBuffers[2];
const displacementsBwdBuffer = allTensorBuffers[3];
2020-11-08 18:26:45 +01:00
const poses = await decodeMultiple.decodeMultiplePoses(scoresBuffer, offsetsBuffer, displacementsFwdBuffer, displacementsBwdBuffer, this.outputStride, config.body.maxDetections, config.body.scoreThreshold, config.body.nmsRadius);
2020-11-08 18:32:31 +01:00
const resultPoses = util.scaleAndFlipPoses(poses, [height, width], [config.body.inputSize, config.body.inputSize]);
2020-11-08 18:26:45 +01:00
res.heatmapScores.dispose();
res.offsets.dispose();
res.displacementFwd.dispose();
res.displacementBwd.dispose();
2020-11-06 17:39:39 +01:00
resized.dispose();
resolve(resultPoses);
});
2020-10-12 01:22:43 +02:00
}
dispose() {
this.baseModel.dispose();
}
}
exports.PoseNet = PoseNet;
2020-11-08 18:26:45 +01:00
async function load(config) {
2020-11-17 16:18:15 +01:00
const graphModel = await tf.loadGraphModel(config.body.modelPath);
2020-11-08 18:26:45 +01:00
const mobilenet = new modelMobileNet.MobileNet(graphModel, this.outputStride);
2020-11-07 16:37:19 +01:00
// eslint-disable-next-line no-console
2020-11-08 18:26:45 +01:00
console.log(`Human: load model: ${config.body.modelPath.match(/\/(.*)\./)[1]}`);
return new PoseNet(mobilenet);
2020-10-12 01:22:43 +02:00
}
exports.load = load;