// human/src/body/modelPoseNet.js (48 lines, 1.9 KiB, JavaScript)
// blame: 2020-10-12 01:22:43 +02:00
const tf = require('@tensorflow/tfjs');
const modelMobileNet = require('./modelMobileNet');
const decodeMultiple = require('./decodeMultiple');
const util = require('./util');
/**
 * Multi-person pose estimator built on a PoseNet base model.
 * Runs the base model on a resized copy of the input, decodes its four output
 * tensors into poses, and rescales those poses to the input's dimensions.
 */
class PoseNet {
  /**
   * @param {object} net - base model; must expose `predict(tensor)` returning
   *   `{ heatmapScores, offsets, displacementFwd, displacementBwd }` tensors, and `dispose()`.
   */
  constructor(net) {
    this.baseModel = net;
    // Stride handed to the pose decoder; assumed to match the base model's output stride — TODO confirm.
    this.outputStride = 16;
  }

  /**
   * Estimates poses for a single image tensor.
   *
   * @param {object} input - tensor whose shape[1] is read as height and shape[2] as width
   *   (presumably [batch, height, width, channels] — TODO confirm).
   * @param {object} config - reads `config.body.inputResolution`, `maxDetections`,
   *   `scoreThreshold` and `nmsRadius`.
   * @returns {Promise<Array>} poses from `util.scaleAndFlipPoses`, scaled back to the input size.
   *   Rejects if any stage (resize, predict, buffer read, decode, scale) throws.
   */
  async estimatePoses(input, config) {
    const height = input.shape[1];
    const width = input.shape[2];
    const { inputResolution, maxDetections, scoreThreshold, nmsRadius } = config.body;
    const resized = util.resizeTo(input, [inputResolution, inputResolution]);
    let res;
    // A plain async method (rather than `new Promise(async ...)`) turns any failure into a
    // rejection the caller sees; `finally` releases the intermediate tensors on every path.
    try {
      res = this.baseModel.predict(resized);
      const [scoresBuffer, offsetsBuffer, displacementsFwdBuffer, displacementsBwdBuffer] = await util.toTensorBuffers3D([res.heatmapScores, res.offsets, res.displacementFwd, res.displacementBwd]);
      const poses = await decodeMultiple.decodeMultiplePoses(scoresBuffer, offsetsBuffer, displacementsFwdBuffer, displacementsBwdBuffer, this.outputStride, maxDetections, scoreThreshold, nmsRadius);
      return util.scaleAndFlipPoses(poses, [height, width], [inputResolution, inputResolution]);
    } finally {
      if (res) {
        for (const tensor of [res.heatmapScores, res.offsets, res.displacementFwd, res.displacementBwd]) {
          if (tensor) tensor.dispose();
        }
      }
      resized.dispose();
    }
  }

  /** Releases the wrapped base model. */
  dispose() {
    this.baseModel.dispose();
  }
}
exports.PoseNet = PoseNet;

/**
 * Loads the PoseNet graph model from `config.body.modelPath` and wraps it in a PoseNet.
 *
 * @param {object} config - reads `config.body.modelPath`.
 * @returns {Promise<PoseNet>} estimator wrapping a MobileNet base model built on the loaded graph.
 */
async function load(config) {
  // Must agree with the stride PoseNet hands to the decoder (its `outputStride` field).
  // `this` inside a module-level function is not a PoseNet, so `this.outputStride` never held it.
  const outputStride = 16;
  const graphModel = await tf.loadGraphModel(config.body.modelPath);
  const mobilenet = new modelMobileNet.MobileNet(graphModel, outputStride);
  // Log a short model name. Fall back to the full path when it has no "/name." segment,
  // instead of throwing on a null match result.
  const nameMatch = config.body.modelPath.match(/\/(.*)\./);
  const modelName = nameMatch ? nameMatch[1] : config.body.modelPath;
  // eslint-disable-next-line no-console
  console.log(`Human: load model: ${modelName}`);
  return new PoseNet(mobilenet);
}
exports.load = load;