human/src/posenet/modelBase.ts

41 lines
1.3 KiB
TypeScript
Raw Normal View History

2020-11-18 14:26:28 +01:00
import * as tf from '../../dist/tfjs.esm.js';
2020-12-17 00:36:24 +01:00
// Per-channel ImageNet mean pixel values, stored negated so that `predict`
// can mean-center ResNet input with a plain `add` (channel order presumably RGB — confirm).
const imageNetMean = [123.15, 115.90, 103.06].map((channelMean) => -channelMean);
/**
 * Labels the MobileNet model outputs, which arrive in the order
 * [offsets, heatmap, displacementFwd, displacementBwd].
 */
function nameOutputResultsMobileNet(results) {
  return {
    offsets: results[0],
    heatmap: results[1],
    displacementFwd: results[2],
    displacementBwd: results[3],
  };
}
/**
 * Labels the ResNet model outputs, which arrive in the order
 * [displacementFwd, displacementBwd, offsets, heatmap].
 */
function nameOutputResultsResNet(results) {
  const offsets = results[2];
  const heatmap = results[3];
  const displacementFwd = results[0];
  const displacementBwd = results[1];
  return { offsets, heatmap, displacementFwd, displacementBwd };
}
2021-02-08 17:39:09 +01:00
/**
 * Thin wrapper around a loaded PoseNet model: normalizes the input image,
 * runs inference and labels the raw output tensors.
 */
export class BaseModel {
  // Loaded model handle; must expose `predict(tensor)` and `dispose()`.
  model: any;

  /**
   * @param model loaded model exposing `predict` and `dispose`
   */
  constructor(model) {
    this.model = model;
  }

  /**
   * Runs the model on a single image.
   *
   * @param input image tensor without a batch dimension
   *   (presumably [height, width, 3] with 0..255 pixel values — TODO confirm against callers)
   * @param config configuration object; only `config.body.modelType` is read here
   * @returns `heatmapScores` (sigmoid of the raw heatmap), `offsets`,
   *   `displacementFwd` and `displacementBwd`, each with the batch dimension removed.
   *   Intermediates are released by `tf.tidy`; the returned tensors are kept
   *   alive, so the caller is responsible for disposing them.
   */
  predict(input, config) {
    const isResNet = config.body.modelType === 'posenet-resnet';
    return tf.tidy(() => {
      // `Tensor.toFloat()` is deprecated in tfjs; `tf.cast` is the supported equivalent.
      const asFloat = tf.cast(input, 'float32');
      // ResNet: add the negated ImageNet means (i.e. mean-center each channel).
      // MobileNet: scale into [-1, 1] via x / 127.5 - 1.
      const normalized = isResNet ? asFloat.add(imageNetMean) : asFloat.div(127.5).sub(1.0);
      const asBatch = normalized.expandDims(0);
      const results = this.model.predict(asBatch);
      // Drop the batch dimension from every output tensor.
      const results3d = results.map((y) => y.squeeze([0]));
      // The two architectures emit their outputs in different orders.
      const namedResults = isResNet ? nameOutputResultsResNet(results3d) : nameOutputResultsMobileNet(results3d);
      return {
        heatmapScores: namedResults.heatmap.sigmoid(),
        offsets: namedResults.offsets,
        displacementFwd: namedResults.displacementFwd,
        displacementBwd: namedResults.displacementBwd,
      };
    });
  }

  /** Releases the underlying model's resources via `model.dispose()`. */
  dispose() {
    this.model.dispose();
  }
}