human/src/body/modelPoseNet.js

67 lines
2.3 KiB
JavaScript
Raw Normal View History

2020-12-08 15:00:44 +01:00
import { log } from '../log.js';
2020-11-18 14:26:28 +01:00
import * as tf from '../../dist/tfjs.esm.js';
2020-12-17 00:36:24 +01:00
import * as modelBase from './modelBase';
2020-11-10 02:13:38 +01:00
import * as decodeMultiple from './decodeMultiple';
2020-12-17 00:36:24 +01:00
import * as decodePose from './decodePose';
2020-11-10 02:13:38 +01:00
import * as util from './util';
2020-10-12 01:22:43 +02:00
2020-12-17 00:36:24 +01:00
/**
 * Decodes multiple poses from the raw model outputs and rescales them to the input size.
 *
 * The four output tensors are read into buffers via `util.toTensorBuffers3D`, decoded with
 * `decodeMultiple.decodeMultiplePoses`, and the result is passed to `util.scaleAndFlipPoses`
 * to map from the `config.body.inputSize` square back to the input's height/width.
 *
 * Written as a plain async function (no `new Promise` wrapper) so that any error thrown while
 * reading buffers or decoding rejects the returned promise instead of leaving it pending.
 *
 * @param {object} input - original input tensor; `shape[1]` and `shape[2]` are used as height and width
 * @param {object} res - model outputs: `heatmapScores`, `offsets`, `displacementFwd`, `displacementBwd`
 * @param {object} config - configuration forwarded to the decoder; `config.body.inputSize` is read here
 * @returns {Promise<*>} result of `util.scaleAndFlipPoses` for the decoded poses
 */
async function estimateMultiple(input, res, config) {
  const [, height, width] = input.shape;
  const [scoresBuffer, offsetsBuffer, displacementsFwdBuffer, displacementsBwdBuffer] = await util.toTensorBuffers3D([
    res.heatmapScores,
    res.offsets,
    res.displacementFwd,
    res.displacementBwd,
  ]);
  const poses = await decodeMultiple.decodeMultiplePoses(scoresBuffer, offsetsBuffer, displacementsFwdBuffer, displacementsBwdBuffer, config);
  return util.scaleAndFlipPoses(poses, [height, width], [config.body.inputSize, config.body.inputSize]);
}
/**
 * Decodes a single pose from the raw model outputs and rescales it to the input size.
 *
 * The pose is decoded with `decodePose.decodeSinglePose` from the heatmap and offset
 * tensors. It is then passed (wrapped in an array) to `util.scaleAndFlipPoses`, which maps it
 * from the `config.body.inputSize` square back to the input's height/width.
 *
 * Written as a plain async function (no `new Promise` wrapper) so that a decoding error
 * rejects the returned promise instead of leaving it pending.
 *
 * @param {object} input - original input tensor; `shape[1]` and `shape[2]` are used as height and width
 * @param {object} res - model outputs; only `heatmapScores` and `offsets` are read here
 * @param {object} config - configuration forwarded to the decoder; `config.body.inputSize` is read here
 * @returns {Promise<*>} result of `util.scaleAndFlipPoses` for the one decoded pose
 */
async function estimateSingle(input, res, config) {
  const [, height, width] = input.shape;
  const pose = await decodePose.decodeSinglePose(res.heatmapScores, res.offsets, config);
  return util.scaleAndFlipPoses([pose], [height, width], [config.body.inputSize, config.body.inputSize]);
}
2020-10-12 01:22:43 +02:00
/**
 * PoseNet body-pose estimator wrapping a base model.
 */
export class PoseNet {
  /**
   * @param {object} model - base model; must provide `predict(input, config)` and `dispose()`
   */
  constructor(model) {
    this.baseModel = model;
  }

  /**
   * Resizes `input` to `config.body.inputSize` x `config.body.inputSize`, runs the base model
   * and decodes poses from its outputs.
   *
   * Single-pose decoding is used when `config.body.maxDetections < 2`, multi-pose decoding otherwise.
   * The resized tensor and the four model output tensors are always disposed, including when
   * prediction or decoding throws.
   *
   * @param {object} input - input image tensor (also forwarded to the decoder for rescaling)
   * @param {object} config - configuration; reads `config.body.inputSize` and `config.body.maxDetections`
   * @returns {Promise<*>} poses as returned by the selected decoder
   */
  async estimatePoses(input, config) {
    const resized = util.resizeTo(input, [config.body.inputSize, config.body.inputSize]);
    try {
      const res = this.baseModel.predict(resized, config);
      try {
        // `return await` keeps the output tensors alive until decoding has finished
        return (config.body.maxDetections < 2)
          ? await estimateSingle(input, res, config)
          : await estimateMultiple(input, res, config);
      } finally {
        res.heatmapScores.dispose();
        res.offsets.dispose();
        res.displacementFwd.dispose();
        res.displacementBwd.dispose();
      }
    } finally {
      resized.dispose();
    }
  }

  /** Releases the underlying base model. */
  dispose() {
    this.baseModel.dispose();
  }
}

/**
 * Loads the graph model from `config.body.modelPath` and wraps it in a `PoseNet`.
 *
 * @param {object} config - configuration; reads `config.body.modelPath`
 * @returns {Promise<PoseNet>} estimator backed by a `modelBase.BaseModel`
 */
export async function load(config) {
  const model = await tf.loadGraphModel(config.body.modelPath);
  const mobilenet = new modelBase.BaseModel(model);
  // Log a short name: the text between the first '/' and the last '.' of the path.
  // Paths without such a segment are logged in full rather than throwing on a null match.
  const nameMatch = config.body.modelPath.match(/\/(.*)\./);
  const modelName = nameMatch ? nameMatch[1] : config.body.modelPath;
  log(`load model: ${modelName}`);
  return new PoseNet(mobilenet);
}