human/src/body/modelBase.js

46 lines
1.5 KiB
JavaScript
Raw Normal View History

2020-10-12 01:22:43 +02:00
const tf = require('@tensorflow/tfjs');
2020-10-12 01:22:43 +02:00
/**
 * Thin wrapper around a pose-estimation backbone model.
 *
 * This class does not define `preprocessInput` or `nameOutputResults`;
 * `predict` calls both on `this`, so they are presumably supplied by
 * subclasses (verify against the concrete model classes).
 */
class BaseModel {
  /**
   * @param model The underlying model; must expose `predict` and `dispose`.
   * @param outputStride Stored as-is on the instance; not used by this class.
   */
  constructor(model, outputStride) {
    Object.assign(this, { model, outputStride });
  }

  /**
   * Runs the wrapped model on a single image and returns its named outputs.
   *
   * @param input The input RGB image of the base model.
   * A Tensor of shape: [`inputResolution`, `inputResolution`, 3].
   *
   * @return A dictionary of the base model's intermediate predictions,
   * containing the following elements:
   * heatmapScores: A Tensor3D holding the heatmap output after a sigmoid.
   * offsets: A Tensor3D that represents the offsets.
   * displacementFwd: A Tensor3D that represents the forward displacement.
   * displacementBwd: A Tensor3D that represents the backward displacement.
   */
  predict(input) {
    const runInference = () => {
      const floatImage = input.toFloat();
      // The model is fed a batch, so prepend a leading axis to the single image.
      const batched = this.preprocessInput(floatImage).expandDims(0);
      const rawOutputs = this.model.predict(batched);
      // Strip the leading batch axis from every output again.
      const unbatched = rawOutputs.map((tensor) => tensor.squeeze([0]));
      const {
        heatmap,
        offsets,
        displacementFwd,
        displacementBwd,
      } = this.nameOutputResults(unbatched);
      return {
        heatmapScores: heatmap.sigmoid(),
        offsets,
        displacementFwd,
        displacementBwd,
      };
    };
    // Everything runs inside tf.tidy; only the returned object is handed back.
    return tf.tidy(runInference);
  }

  /**
   * Disposes the wrapped model, releasing the memory it holds.
   */
  dispose() {
    this.model.dispose();
  }
}
// CommonJS export: expose the BaseModel class under the name `BaseModel`.
exports.BaseModel = BaseModel;