face-api/build/faceProcessor/FaceProcessor.js

55 lines
2.4 KiB
JavaScript
Raw Normal View History

2020-08-20 02:10:42 +02:00
import * as tf from '@tensorflow/tfjs-core';
import { fullyConnectedLayer } from '../common/fullyConnectedLayer';
import { NetInput } from '../dom';
import { NeuralNetwork } from '../NeuralNetwork';
import { extractParams } from './extractParams';
import { extractParamsFromWeigthMap } from './extractParamsFromWeigthMap';
import { seperateWeightMaps } from './util';
/**
 * Network composed of a face feature extractor followed by a single fully
 * connected classifier layer (`params.fc`).
 *
 * The classifier's input/output sizes come from `getClassifierChannelsIn()`
 * and `getClassifierChannelsOut()`, which this class calls but does not
 * define — presumably every concrete subclass supplies them.
 */
export class FaceProcessor extends NeuralNetwork {
    /**
     * @param {string} _name - network name; forwarded to the NeuralNetwork
     *   base and used as the prefix of this class's error messages.
     * @param faceFeatureExtractor - network whose output (bottleneck
     *   features) is fed into the classifier layer.
     */
    constructor(_name, faceFeatureExtractor) {
        super(_name);
        this._faceFeatureExtractor = faceFeatureExtractor;
    }
    /** The wrapped feature extractor network. */
    get faceFeatureExtractor() {
        return this._faceFeatureExtractor;
    }
    /**
     * Runs the classifier head.
     *
     * @param input - either a NetInput, which is first passed through the
     *   feature extractor's `forwardInput`, or an already computed
     *   bottleneck-feature tensor, which is used as-is.
     * @returns the output of the fully connected layer; the computation is
     *   wrapped in `tf.tidy`.
     * @throws {Error} if the classifier params have not been loaded yet.
     */
    runNet(input) {
        const { params } = this;
        if (!params) {
            throw new Error(`${this._name} - load model before inference`);
        }
        return tf.tidy(() => {
            const features = input instanceof NetInput
                ? this.faceFeatureExtractor.forwardInput(input)
                : input;
            // Keep the leading dimension and flatten everything else so the
            // features can be fed into the dense layer.
            const flattened = features.as2D(features.shape[0], -1);
            return fullyConnectedLayer(flattened, params.fc);
        });
    }
    /**
     * Disposes the feature extractor, then this network's own resources.
     *
     * @param {boolean} [throwOnRedispose=true] - forwarded to both dispose calls.
     */
    dispose(throwOnRedispose = true) {
        this.faceFeatureExtractor.dispose(throwOnRedispose);
        super.dispose(throwOnRedispose);
    }
    /**
     * Decodes classifier-only weights and stores the result on this network
     * (`_params` / `_paramMappings`). The feature extractor is not touched.
     *
     * @param weights - flat weights for the classifier layer only.
     */
    loadClassifierParams(weights) {
        const { params, paramMappings } = this.extractClassifierParams(weights);
        this._params = params;
        this._paramMappings = paramMappings;
    }
    /**
     * Decodes classifier params from a flat weight array, using this
     * network's classifier channel counts.
     *
     * @returns {{ params, paramMappings }} as produced by the module-level
     *   `extractParams` helper.
     */
    extractClassifierParams(weights) {
        return extractParams(weights, this.getClassifierChannelsIn(), this.getClassifierChannelsOut());
    }
    /**
     * Splits a named weight map into feature-extractor and classifier parts,
     * loads the former into the feature extractor and returns the decoded
     * classifier params. The (misspelled) method name is kept unchanged;
     * presumably it is the hook name the NeuralNetwork base calls.
     */
    extractParamsFromWeigthMap(weightMap) {
        const { featureExtractorMap, classifierMap } = seperateWeightMaps(weightMap);
        this.faceFeatureExtractor.loadFromWeightMap(featureExtractorMap);
        // This is the module-level helper of the same name, not a recursive
        // call to this method.
        return extractParamsFromWeigthMap(classifierMap);
    }
    /**
     * Decodes a flat weight array that holds the feature extractor's weights
     * followed by the classifier's weights.
     *
     * The classifier occupies the trailing `cOut * cIn + cOut` values
     * (presumably the dense weight matrix plus bias). Everything before that
     * is handed to the feature extractor via `extractWeights`.
     *
     * @returns the classifier params decoded by `extractClassifierParams`.
     * @throws {Error} if `weights` is shorter than the classifier alone needs.
     */
    extractParams(weights) {
        const cIn = this.getClassifierChannelsIn();
        const cOut = this.getClassifierChannelsOut();
        const classifierWeightSize = (cOut * cIn) + cOut;
        // Without this check a negative split index would make slice() count
        // from the end and silently hand both sub-networks garbled weights.
        if (weights.length < classifierWeightSize) {
            throw new Error(`${this._name} - expected at least ${classifierWeightSize} weights for the classifier, got ${weights.length}`);
        }
        const splitAt = weights.length - classifierWeightSize;
        const featureExtractorWeights = weights.slice(0, splitAt);
        const classifierWeights = weights.slice(splitAt);
        this.faceFeatureExtractor.extractWeights(featureExtractorWeights);
        return this.extractClassifierParams(classifierWeights);
    }
}
//# sourceMappingURL=FaceProcessor.js.map