human/src/posenet/posenet.ts

65 lines
2.4 KiB
TypeScript
Raw Normal View History

2021-02-08 17:39:09 +01:00
import { log } from '../log';
2020-11-18 14:26:28 +01:00
import * as tf from '../../dist/tfjs.esm.js';
2020-12-17 00:36:24 +01:00
import * as modelBase from './modelBase';
2020-11-10 02:13:38 +01:00
import * as decodeMultiple from './decodeMultiple';
2020-12-17 00:36:24 +01:00
import * as decodePose from './decodePose';
2020-11-10 02:13:38 +01:00
import * as util from './util';
2020-10-12 01:22:43 +02:00
2021-03-11 16:26:14 +01:00
/**
 * Decode multiple poses from the raw PoseNet outputs and rescale them to the input's size.
 *
 * The four output tensors are read into buffers via `util.toTensorBuffers3D`. They are then
 * decoded by `decodeMultiple.decodeMultiplePoses` using `config.body.nmsRadius`,
 * `config.body.maxDetections` and `config.body.scoreThreshold`. Finally the result is passed
 * through `util.scaleAndFlipPoses` from `inputSize`x`inputSize` to `input.shape[1]`x`input.shape[2]`.
 *
 * This is a plain `async` function on purpose. The previous `new Promise(async (resolve) => …)`
 * wrapper never rejected: any error thrown while reading or decoding left the returned promise
 * pending forever and surfaced only as an unhandled rejection.
 *
 * @param input - original input tensor; shape[1]/shape[2] are the rescale target
 *   (presumably height/width of an NHWC batch — TODO confirm).
 * @param res - model outputs: `heatmapScores`, `offsets`, `displacementFwd`, `displacementBwd`.
 * @param config - reads `body.nmsRadius`, `body.maxDetections`, `body.scoreThreshold`.
 * @param inputSize - square edge length the input was resized to before inference.
 * @returns whatever `util.scaleAndFlipPoses` returns for the decoded poses.
 */
async function estimateMultiple(input, res, config, inputSize) {
  const [scoresBuffer, offsetsBuffer, displacementsFwdBuffer, displacementsBwdBuffer] = await util.toTensorBuffers3D([
    res.heatmapScores,
    res.offsets,
    res.displacementFwd,
    res.displacementBwd,
  ]);
  const poses = await decodeMultiple.decodeMultiplePoses(
    scoresBuffer,
    offsetsBuffer,
    displacementsFwdBuffer,
    displacementsBwdBuffer,
    config.body.nmsRadius,
    config.body.maxDetections,
    config.body.scoreThreshold,
  );
  return util.scaleAndFlipPoses(poses, [input.shape[1], input.shape[2]], [inputSize, inputSize]);
}
2021-03-11 16:26:14 +01:00
/**
 * Decode a single pose from the raw PoseNet outputs and rescale it to the input's size.
 *
 * Calls `decodePose.decodeSinglePose` with the heatmap and offset tensors and
 * `config.body.scoreThreshold`. The resulting pose is wrapped in an array and passed through
 * `util.scaleAndFlipPoses` from `inputSize`x`inputSize` to `input.shape[1]`x`input.shape[2]`.
 *
 * This is a plain `async` function on purpose. The previous `new Promise(async (resolve) => …)`
 * wrapper never rejected, so a decoding error left callers awaiting forever.
 *
 * @param input - original input tensor; shape[1]/shape[2] are the rescale target
 *   (presumably height/width of an NHWC batch — TODO confirm).
 * @param res - model outputs; only `heatmapScores` and `offsets` are read here.
 * @param config - reads `body.scoreThreshold`.
 * @param inputSize - square edge length the input was resized to before inference.
 * @returns whatever `util.scaleAndFlipPoses` returns for the one-element pose list.
 */
async function estimateSingle(input, res, config, inputSize) {
  const pose = await decodePose.decodeSinglePose(res.heatmapScores, res.offsets, config.body.scoreThreshold);
  return util.scaleAndFlipPoses([pose], [input.shape[1], input.shape[2]], [inputSize, inputSize]);
}
2021-02-08 17:39:09 +01:00
/**
 * PoseNet wrapper around a loaded base model.
 *
 * Resizes the input to the model's square input size, runs the model, and decodes either a
 * single pose or multiple poses depending on `config.body.maxDetections`.
 */
export class PoseNet {
  /** Base model wrapper; its `predict(resized, config)` yields the four PoseNet output tensors. */
  baseModel: any;
  /** Square edge length the model expects, read from its first input's shape. */
  inputSize: number;

  /**
   * @param model - base model wrapper; `model.model.inputs[0].shape[1]` is used as the input size.
   */
  constructor(model) {
    this.baseModel = model;
    // assumes an NHWC input where shape[1] is the (square) height/width — TODO confirm
    this.inputSize = model.model.inputs[0].shape[1];
  }

  /**
   * Estimate poses for `input`.
   *
   * Uses the single-pose decoder when `config.body.maxDetections < 2`, otherwise the
   * multi-pose decoder. The resized input and the four model output tensors are disposed in
   * `finally` blocks, so they are released even when prediction or decoding throws. Previously
   * they leaked on any failure.
   *
   * @param input - input tensor; its shape[1]/shape[2] define the output coordinate space.
   * @param config - reads `body.maxDetections`, plus whatever the decoders read.
   * @returns the decoded and rescaled poses.
   */
  async estimatePoses(input, config) {
    const resized = util.resizeTo(input, [this.inputSize, this.inputSize]);
    try {
      const res = this.baseModel.predict(resized, config);
      try {
        // `await` (not a bare return) so the finally block runs only after decoding completes
        return (config.body.maxDetections < 2)
          ? await estimateSingle(input, res, config, this.inputSize)
          : await estimateMultiple(input, res, config, this.inputSize);
      } finally {
        res.heatmapScores.dispose();
        res.offsets.dispose();
        res.displacementFwd.dispose();
        res.displacementBwd.dispose();
      }
    } finally {
      resized.dispose();
    }
  }

  /** Release the underlying base model. */
  dispose() {
    this.baseModel.dispose();
  }
}
2020-11-08 18:26:45 +01:00
2021-02-08 17:39:09 +01:00
/**
 * Load the PoseNet graph model from `config.body.modelPath` and wrap it in a `PoseNet`.
 *
 * When `config.debug` is set, logs `load model: <name>`. The name is the text between the first
 * `/` and the last `.` of the path. If the path has no such segment (e.g. `posenet.json`), the
 * full path is logged instead. Previously that case threw a TypeError and aborted loading.
 *
 * @param config - reads `debug` and `body.modelPath`.
 * @returns a `PoseNet` built on a `modelBase.BaseModel` around the loaded graph model.
 */
export async function load(config) {
  const model = await tf.loadGraphModel(config.body.modelPath);
  const mobilenet = new modelBase.BaseModel(model);
  if (config.debug) {
    const name = config.body.modelPath.match(/\/(.*)\./)?.[1] ?? config.body.modelPath;
    log(`load model: ${name}`);
  }
  return new PoseNet(mobilenet);
}