mirror of https://github.com/vladmandic/human
41 lines · 1.3 KiB · TypeScript
![]() |
import * as tf from '../../dist/tfjs.esm.js';
|
||
![]() |
|
||
![]() |
// Per-channel mean pixel values (channel order presumably RGB — confirm against the model),
// stored negated so that `tensor.add(imageNetMean)` subtracts the mean from each channel.
const imageNetMean = [123.15, 115.90, 103.06].map((channelMean) => -channelMean);
|
||
|
|
||
|
function nameOutputResultsMobileNet(results) {
|
||
|
const [offsets, heatmap, displacementFwd, displacementBwd] = results;
|
||
|
return { offsets, heatmap, displacementFwd, displacementBwd };
|
||
|
}
|
||
|
|
||
|
/**
 * Labels the ResNet model's output tensors by their position.
 * Unlike MobileNet, the ResNet variant emits its outputs in the order:
 * forward displacement, backward displacement, offsets, heatmap.
 * @param results - per-output tensors in model output order
 * @returns the same tensors keyed by output name
 */
function nameOutputResultsResNet(results) {
  return {
    offsets: results[2],
    heatmap: results[3],
    displacementFwd: results[0],
    displacementBwd: results[1],
  };
}
|
||
|
|
||
![]() |
/**
 * Wraps a loaded tfjs model: normalizes the input image tensor according to the
 * configured backbone, runs inference, and returns the outputs under stable names.
 */
export class BaseModel {
  model: any;

  /**
   * @param model - loaded model; this class calls its `predict()` and `dispose()` methods
   */
  constructor(model) {
    this.model = model;
  }

  /**
   * Runs a single, unbatched input through the model.
   *
   * Any `config.body.modelType` other than exactly 'ResNet' takes the MobileNet path.
   *
   * @param input - image tensor without a batch dimension (a batch axis is added at 0)
   * @param config - configuration object; only `config.body.modelType` is read here
   * @returns `{ heatmapScores, offsets, displacementFwd, displacementBwd }` with the batch
   *   axis squeezed away; `heatmapScores` is the sigmoid of the raw heatmap output
   */
  predict(input, config) {
    // tf.tidy releases every intermediate tensor created in the callback except the
    // ones reachable from the returned object.
    return tf.tidy(() => {
      const isResNet = config.body.modelType === 'ResNet';
      const floatInput = input.toFloat();
      // ResNet: add the negated channel means (mean subtraction).
      // MobileNet: linear map 0 -> -1, 255 -> 1 (assumes 0..255 pixel values — confirm).
      const normalized = isResNet
        ? floatInput.add(imageNetMean)
        : floatInput.div(127.5).sub(1.0);

      const batched = normalized.expandDims(0);
      // Strip the batch axis from every output the model returns.
      const squeezed = this.model.predict(batched).map((tensor) => tensor.squeeze([0]));

      const {
        offsets, heatmap, displacementFwd, displacementBwd,
      } = isResNet
        ? nameOutputResultsResNet(squeezed)
        : nameOutputResultsMobileNet(squeezed);

      return {
        heatmapScores: heatmap.sigmoid(),
        offsets,
        displacementFwd,
        displacementBwd,
      };
    });
  }

  /** Releases the wrapped model by delegating to its `dispose()` method. */
  dispose() {
    this.model.dispose();
  }
}
|