// Mirror of https://github.com/vladmandic/human (67 lines, 2.3 KiB, JavaScript)
import { log } from '../log.js';
|
|
import * as tf from '../../dist/tfjs.esm.js';
|
|
import * as modelBase from './modelBase';
|
|
import * as decodeMultiple from './decodeMultiple';
|
|
import * as decodePose from './decodePose';
|
|
import * as util from './util';
|
|
|
|
/**
 * Decodes multiple poses from the PoseNet model outputs and scales them to the input image size.
 *
 * Written as a plain async body rather than `new Promise(async (resolve) => …)`: with that wrapper an
 * exception from any helper below would never reject the returned promise, leaving callers hanging.
 *
 * @param {object} input - image fed to the model; only `input.shape[1]` (height) and `input.shape[2]` (width) are read.
 * @param {object} res - model outputs `heatmapScores`, `offsets`, `displacementFwd`, `displacementBwd`.
 * @param {object} config - reads `config.body.inputSize`; the whole object is forwarded to the decoder.
 * @returns {Promise<*>} whatever `util.scaleAndFlipPoses` returns for the decoded poses.
 */
async function estimateMultiple(input, res, config) {
  const [, height, width] = input.shape;
  const [scoresBuffer, offsetsBuffer, displacementsFwdBuffer, displacementsBwdBuffer] = await util.toTensorBuffers3D([
    res.heatmapScores,
    res.offsets,
    res.displacementFwd,
    res.displacementBwd,
  ]);
  const poses = await decodeMultiple.decodeMultiplePoses(scoresBuffer, offsetsBuffer, displacementsFwdBuffer, displacementsBwdBuffer, config);
  // Map from the model's square input size back to the original [height, width].
  return util.scaleAndFlipPoses(poses, [height, width], [config.body.inputSize, config.body.inputSize]);
}
|
|
|
|
/**
 * Decodes a single pose from the PoseNet model outputs and scales it to the input image size.
 *
 * Written as a plain async body rather than `new Promise(async (resolve) => …)`: with that wrapper an
 * exception from the decoder would never reject the returned promise, leaving callers hanging.
 *
 * @param {object} input - image fed to the model; only `input.shape[1]` (height) and `input.shape[2]` (width) are read.
 * @param {object} res - model outputs; only `heatmapScores` and `offsets` are used here.
 * @param {object} config - reads `config.body.inputSize`; the whole object is forwarded to the decoder.
 * @returns {Promise<*>} whatever `util.scaleAndFlipPoses` returns for the one-element pose list.
 */
async function estimateSingle(input, res, config) {
  const [, height, width] = input.shape;
  const pose = await decodePose.decodeSinglePose(res.heatmapScores, res.offsets, config);
  // Map from the model's square input size back to the original [height, width].
  return util.scaleAndFlipPoses([pose], [height, width], [config.body.inputSize, config.body.inputSize]);
}
|
|
|
|
/**
 * Runs a PoseNet base model on an image and decodes the resulting poses.
 */
class PoseNet {
  /**
   * @param {object} model - base model wrapper; must provide `predict(input, config)` and `dispose()`.
   */
  constructor(model) {
    this.baseModel = model;
  }

  /**
   * Resizes `input` to `config.body.inputSize` square, runs the model and decodes poses.
   * Uses single-pose decoding when `config.body.maxDetections < 2`, multi-pose decoding otherwise.
   * The resized image and all four model outputs are disposed whether decoding succeeds or throws.
   *
   * @param {object} input - source image (presumably a tf tensor); its own shape is used to scale poses back.
   * @param {object} config - reads `config.body.inputSize` and `config.body.maxDetections`.
   * @returns {Promise<*>} the decoded, rescaled poses.
   */
  async estimatePoses(input, config) {
    const resized = util.resizeTo(input, [config.body.inputSize, config.body.inputSize]);
    let res;
    try {
      res = this.baseModel.predict(resized, config);
      // Await inside the try so `finally` releases the outputs only after decoding is done with them.
      const poses = (config.body.maxDetections < 2)
        ? await estimateSingle(input, res, config)
        : await estimateMultiple(input, res, config);
      return poses;
    } finally {
      // Release intermediate results on both the success and the failure path.
      if (res) {
        res.heatmapScores.dispose();
        res.offsets.dispose();
        res.displacementFwd.dispose();
        res.displacementBwd.dispose();
      }
      resized.dispose();
    }
  }

  /** Releases the wrapped base model. */
  dispose() {
    this.baseModel.dispose();
  }
}
|
|
// Publish PoseNet on the CommonJS `exports` object.
// NOTE(review): this file also uses ES `import` syntax, and `exports` is not defined inside a native
// ES module; this presumably relies on the bundler accepting mixed module syntax — confirm before
// running the file unbundled.
exports.PoseNet = PoseNet;
|
|
|
|
/**
 * Loads the PoseNet graph model from `config.body.modelPath` and wraps it for pose estimation.
 *
 * @param {object} config - reads `config.body.modelPath`, passed to `tf.loadGraphModel`.
 * @returns {Promise<PoseNet>} a PoseNet around a `modelBase.BaseModel` built from the loaded graph model.
 */
async function load(config) {
  const { modelPath } = config.body;
  const model = await tf.loadGraphModel(modelPath);
  const mobilenet = new modelBase.BaseModel(model);
  // Log a short model name: the text from the first '/' up to the last '.' after it. Paths without
  // such a span make `match` return null, so fall back to the full path instead of throwing a
  // TypeError after the model has already loaded successfully.
  const nameMatch = modelPath.match(/\/(.*)\./);
  log(`load model: ${nameMatch ? nameMatch[1] : modelPath}`);
  return new PoseNet(mobilenet);
}
|
|
// Publish the `load` factory on the CommonJS `exports` object.
// NOTE(review): as with the PoseNet export, `exports` is undefined in a native ES module; this
// presumably depends on bundler handling of the mixed import/exports syntax — confirm.
exports.load = load;
|