// human/src/body/modelPoseNet.js
const tf = require('@tensorflow/tfjs');
const modelMobileNet = require('./modelMobileNet');
const decodeMultiple = require('./decodeMultiple');
const util = require('./util');
class PoseNet {
  /**
   * @param net Backbone model wrapper (e.g. MobileNet) exposing predict() and dispose().
   */
  constructor(net) {
    this.baseModel = net;
  }

  /**
   * Infers through PoseNet and estimates multiple poses from the outputs.
   * This does standard ImageNet pre-processing before inferring through the
   * model. The image pixels should have values [0-255]. It detects
   * multiple poses and finds their parts from part scores and displacement
   * vectors using a fast greedy decoding algorithm. It returns up to
   * `config.maxDetections` object instance detections in decreasing root
   * score order.
   *
   * @param input 4D input tensor of shape [batch, height, width, channels]
   *   (height/width are read from input.shape[1]/input.shape[2]).
   * @param config MultiPoseEstimationConfig object that contains parameters
   *   for the PoseNet inference using multiple pose estimation
   *   (outputStride, inputResolution, maxDetections, scoreThreshold, nmsRadius).
   * @returns An array of poses and their scores, each containing keypoints and
   *   the corresponding keypoint scores. The positions of the keypoints are
   *   in the same scale as the original image.
   */
  async estimatePoses(input, config) {
    const outputStride = config.outputStride;
    const height = input.shape[1];
    const width = input.shape[2];
    const resized = util.resizeTo(input, [config.inputResolution, config.inputResolution]);
    const { heatmapScores, offsets, displacementFwd, displacementBwd } = this.baseModel.predict(resized);
    let poses;
    try {
      const [scoresBuffer, offsetsBuffer, displacementsFwdBuffer, displacementsBwdBuffer] =
        await util.toTensorBuffers3D([heatmapScores, offsets, displacementFwd, displacementBwd]);
      poses = await decodeMultiple.decodeMultiplePoses(scoresBuffer, offsetsBuffer, displacementsFwdBuffer, displacementsBwdBuffer, outputStride, config.maxDetections, config.scoreThreshold, config.nmsRadius);
    } finally {
      // Dispose intermediate tensors even if decoding throws, to avoid a tensor memory leak.
      heatmapScores.dispose();
      offsets.dispose();
      displacementFwd.dispose();
      displacementBwd.dispose();
      resized.dispose();
    }
    // Map keypoints back from the model's input resolution to the original image scale.
    return util.scaleAndFlipPoses(poses, [height, width], [config.inputResolution, config.inputResolution]);
  }

  /** Releases the underlying backbone model's resources. */
  dispose() {
    this.baseModel.dispose();
  }
}
// CommonJS export of the PoseNet wrapper class.
exports.PoseNet = PoseNet;
/**
 * Loads the graph model from config.modelPath, wraps it in a MobileNet
 * backbone, and returns a PoseNet instance built on top of it.
 *
 * @param config ModelConfig with at least `modelPath` and `outputStride`.
 * @returns {Promise<PoseNet>} PoseNet wrapping the loaded MobileNet backbone.
 */
async function loadMobileNet(config) {
  const graphModel = await tf.loadGraphModel(config.modelPath);
  const mobilenet = new modelMobileNet.MobileNet(graphModel, config.outputStride);
  // Log the bare model name; fall back to the full path when it does not
  // match the `/name.ext` pattern (the previous unguarded match(...)[1] threw).
  const nameMatch = config.modelPath.match(/\/(.*)\./);
  // eslint-disable-next-line no-console
  console.log(`Human: load model: ${nameMatch ? nameMatch[1] : config.modelPath}`);
  return new PoseNet(mobilenet);
}
/**
 * Loads a PoseNet model instance from a checkpoint, using the MobileNet
 * architecture. The model to load is configurable via the ModelConfig
 * dictionary; see the ModelConfig documentation for details.
 *
 * @param config ModelConfig dictionary that contains parameters for
 * the PoseNet loading process. Please find more details of each parameter
 * in the documentation of the ModelConfig interface. The predefined
 * `MOBILENET_V1_CONFIG` and `RESNET_CONFIG` can also be used as references
 * for defining your customized config.
 * @returns {Promise<PoseNet>} the loaded PoseNet instance.
 */
async function load(config) {
  const model = await loadMobileNet(config);
  return model;
}
// CommonJS export of the model-loading entry point.
exports.load = load;