human/src/posenet/posenet.ts

67 lines
2.5 KiB
TypeScript
Raw Normal View History

2021-04-09 14:07:58 +02:00
import { log, join } from '../helpers';
2020-11-18 14:26:28 +01:00
import * as tf from '../../dist/tfjs.esm.js';
2020-12-17 00:36:24 +01:00
import * as modelBase from './modelBase';
2020-11-10 02:13:38 +01:00
import * as decodeMultiple from './decodeMultiple';
2020-12-17 00:36:24 +01:00
import * as decodePose from './decodePose';
2020-11-10 02:13:38 +01:00
import * as util from './util';
2020-10-12 01:22:43 +02:00
2021-03-11 16:26:14 +01:00
/**
 * Decodes multiple poses from the raw PoseNet outputs and rescales them from
 * model-input space back to the size of the original input.
 *
 * The work is awaited directly rather than wrapped in
 * `new Promise(async (resolve) => ...)`. With the wrapper, any error thrown
 * while extracting buffers, decoding or scaling was lost as an unhandled
 * rejection and the returned promise never settled. Now failures propagate
 * to the caller.
 *
 * @param input original input tensor; `input.shape[1]` and `input.shape[2]` are
 *   used as the output size (presumably height/width of an NHWC tensor — TODO confirm)
 * @param res model outputs; reads `heatmapScores`, `offsets`, `displacementFwd`, `displacementBwd`
 * @param config reads `body.nmsRadius`, `body.maxDetections`, `body.scoreThreshold`
 * @param inputSize square size the input was resized to before inference
 * @returns the poses produced by `util.scaleAndFlipPoses`
 */
async function estimateMultiple(input, res, config, inputSize) {
  const [scoresBuffer, offsetsBuffer, displacementsFwdBuffer, displacementsBwdBuffer] = await util.toTensorBuffers3D([
    res.heatmapScores,
    res.offsets,
    res.displacementFwd,
    res.displacementBwd,
  ]);
  const poses = await decodeMultiple.decodeMultiplePoses(
    scoresBuffer,
    offsetsBuffer,
    displacementsFwdBuffer,
    displacementsBwdBuffer,
    config.body.nmsRadius,
    config.body.maxDetections,
    config.body.scoreThreshold,
  );
  return util.scaleAndFlipPoses(poses, [input.shape[1], input.shape[2]], [inputSize, inputSize]);
}
2021-03-11 16:26:14 +01:00
/**
 * Decodes a single pose from the heatmap/offset outputs and rescales it from
 * model-input space back to the size of the original input.
 *
 * The work is awaited directly instead of being wrapped in an async Promise
 * executor. That wrapper swallowed decoding errors and left the returned
 * promise pending forever.
 *
 * @param input original input tensor; `input.shape[1]` and `input.shape[2]` are
 *   used as the output size (presumably height/width of an NHWC tensor — TODO confirm)
 * @param res model outputs; only `heatmapScores` and `offsets` are read
 * @param config reads `body.scoreThreshold`
 * @param inputSize square size the input was resized to before inference
 * @returns the result of `util.scaleAndFlipPoses` for a one-element pose list
 */
async function estimateSingle(input, res, config, inputSize) {
  const pose = await decodePose.decodeSinglePose(res.heatmapScores, res.offsets, config.body.scoreThreshold);
  return util.scaleAndFlipPoses([pose], [input.shape[1], input.shape[2]], [inputSize, inputSize]);
}
2021-02-08 17:39:09 +01:00
/** Wraps a PoseNet base model: resizes input, runs inference and decodes poses. */
export class PoseNet {
  baseModel: any;
  inputSize: number;

  /**
   * @param model base model wrapper; its `model.inputs[0].shape[1]` is read as the
   *   square input size the graph expects.
   */
  constructor(model) {
    this.baseModel = model;
    const declaredSize = model.model.inputs[0].shape[1];
    // Fall back to 257 for small or dynamic sizes (e.g. -1). Also fall back when
    // the dimension is missing or non-numeric: previously `undefined < 128` was
    // false, which left inputSize undefined.
    this.inputSize = (Number.isFinite(declaredSize) && declaredSize >= 128) ? declaredSize : 257;
  }

  /**
   * Runs pose estimation on `input`.
   * Uses single-pose decoding when `config.body.maxDetections < 2`, otherwise multi-pose decoding.
   * Every intermediate tensor (the resized input and the four model outputs) is
   * disposed even if prediction or decoding throws.
   *
   * @param input input tensor passed to `util.resizeTo`
   * @param config configuration; reads `config.body.*` and is forwarded to `baseModel.predict`
   * @returns the decoded, rescaled poses
   */
  async estimatePoses(input, config) {
    const resized = util.resizeTo(input, [this.inputSize, this.inputSize]);
    try {
      const res = this.baseModel.predict(resized, config);
      try {
        // `await` inside the try so disposal in `finally` happens after decoding completes.
        return (config.body.maxDetections < 2)
          ? await estimateSingle(input, res, config, this.inputSize)
          : await estimateMultiple(input, res, config, this.inputSize);
      } finally {
        res.heatmapScores.dispose();
        res.offsets.dispose();
        res.displacementFwd.dispose();
        res.displacementBwd.dispose();
      }
    } finally {
      resized.dispose();
    }
  }

  /** Releases the underlying base model. */
  dispose() {
    this.baseModel.dispose();
  }
}
2020-11-08 18:26:45 +01:00
2021-02-08 17:39:09 +01:00
/**
 * Loads the PoseNet graph model from `config.modelBasePath` + `config.body.modelPath`,
 * wraps it in a `modelBase.BaseModel`, and returns a `PoseNet` around it.
 *
 * Logs a failure message when the loaded model is falsy or has no `modelUrl`.
 * Otherwise, when `config.debug` is set, logs the loaded URL.
 * NOTE(review): a `PoseNet` is constructed even after a logged failure. Its
 * constructor reads the model's input shape, so it may throw in that case — confirm intended.
 */
export async function load(config) {
  const url = join(config.modelBasePath, config.body.modelPath);
  const graph = await tf.loadGraphModel(url);
  const wrapped = new modelBase.BaseModel(graph);
  const loaded = graph && graph.modelUrl;
  if (loaded) {
    if (config.debug) log('load model:', graph.modelUrl);
  } else {
    log('load model failed:', config.body.modelPath);
  }
  return new PoseNet(wrapped);
}