face-api/build/faceExpressionNet/FaceExpressionNet.js

45 lines
1.7 KiB
JavaScript

"use strict";
// Define the `__esModule` flag (value true) on this module's exports object.
Object.defineProperty(exports, "__esModule", { value: true });
// Start the FaceExpressionNet export as undefined; the class itself is assigned
// to it after the class definition further down in this file.
exports.FaceExpressionNet = void 0;
const tf = require("@tensorflow/tfjs-core");
const dom_1 = require("../dom");
const FaceFeatureExtractor_1 = require("../faceFeatureExtractor/FaceFeatureExtractor");
const FaceProcessor_1 = require("../faceProcessor/FaceProcessor");
const FaceExpressions_1 = require("./FaceExpressions");
/**
 * Network that maps face inputs to facial-expression probabilities.
 *
 * Extends FaceProcessor and supplies the model name plus the classifier's
 * input/output channel counts. The actual network evaluation is delegated to
 * `this.runNet`, which is not defined in this class (presumably inherited
 * from FaceProcessor — confirm there).
 */
class FaceExpressionNet extends FaceProcessor_1.FaceProcessor {
    /**
     * @param {FaceFeatureExtractor} [faceFeatureExtractor] - Feature extractor
     *   passed on to the FaceProcessor constructor. A new FaceFeatureExtractor
     *   is created when the argument is omitted.
     */
    constructor(faceFeatureExtractor = new FaceFeatureExtractor_1.FaceFeatureExtractor()) {
        super('FaceExpressionNet', faceFeatureExtractor);
    }

    /**
     * Runs the network on an already prepared net input and applies softmax to
     * the result.
     *
     * The computation is wrapped in `tf.tidy`, so intermediate tensors are
     * released; the returned tensor is owned by the caller, who must dispose it.
     *
     * @param {*} input - Net input as produced by `toNetInput`.
     * @returns {tf.Tensor} Softmax of the network output.
     */
    forwardInput(input) {
        return tf.tidy(() => tf.softmax(this.runNet(input)));
    }

    /**
     * Converts `input` with `toNetInput` and forwards it through the network.
     *
     * @param {*} input - Any value accepted by `toNetInput`.
     * @returns {Promise<tf.Tensor>} Softmax output; the caller must dispose it.
     */
    async forward(input) {
        return this.forwardInput(await dom_1.toNetInput(input));
    }

    /**
     * Predicts expressions for `input` and wraps each result row in a
     * FaceExpressions object.
     *
     * All tensors created here are disposed before the promise settles, on
     * both the success and the failure path.
     *
     * @param {*} input - Any value accepted by `toNetInput`.
     * @returns {Promise<FaceExpressions|FaceExpressions[]>} An array with one
     *   entry per unstacked output slice when `netInput.isBatchInput` is
     *   truthy, otherwise only the first entry.
     */
    async predictExpressions(input) {
        const netInput = await dom_1.toNetInput(input);
        // forwardInput is synchronous (tf.tidy returns its callback's result), so no await is needed.
        const out = this.forwardInput(netInput);

        let probabilitiesByBatch;
        try {
            probabilitiesByBatch = await Promise.all(tf.unstack(out).map(async (t) => {
                try {
                    return await t.data();
                }
                finally {
                    // Release the slice even when reading its data fails.
                    t.dispose();
                }
            }));
        }
        finally {
            // Release the network output even when unstacking or a read fails.
            out.dispose();
        }

        const predictionsByBatch = probabilitiesByBatch
            .map((probabilities) => new FaceExpressions_1.FaceExpressions(probabilities));
        return netInput.isBatchInput
            ? predictionsByBatch
            : predictionsByBatch[0];
    }

    /**
     * @returns {string} Default model name, 'face_expression_model'.
     */
    getDefaultModelName() {
        return 'face_expression_model';
    }

    /**
     * @returns {number} Number of input channels of the classifier (256).
     */
    getClassifierChannelsIn() {
        return 256;
    }

    /**
     * @returns {number} Number of output channels of the classifier (7) —
     *   presumably one per expression class; verify against FaceExpressions.
     */
    getClassifierChannelsOut() {
        return 7;
    }
}
// Publish the class as this module's FaceExpressionNet export.
exports.FaceExpressionNet = FaceExpressionNet;
//# sourceMappingURL=FaceExpressionNet.js.map