2021-04-09 14:07:58 +02:00
|
|
|
import { log, join } from '../helpers';
|
2020-11-18 14:26:28 +01:00
|
|
|
import * as tf from '../../dist/tfjs.esm.js';
|
2021-04-25 22:56:10 +02:00
|
|
|
import * as box from './box';
|
|
|
|
import * as util from './util';
|
2020-10-12 01:22:43 +02:00
|
|
|
|
2021-04-25 22:56:10 +02:00
|
|
|
// Number of facial landmark points per detection; each detection's landmark slice is reshaped to [keypointsCount, -1].
const keypointsCount = 6;
|
2020-10-13 04:01:35 +02:00
|
|
|
|
2020-10-12 01:22:43 +02:00
|
|
|
/**
 * Turns raw detector box regressions into corner-form boxes.
 *
 * Columns 1-2 of every row are offsets that are added to the matching anchor row to give the box
 * centre; columns 3-4 are the box size. The result has one row per input row laid out as
 * `[start0, start1, end0, end1]`, where start = centre - size / 2 and end = centre + size / 2
 * (presumably x/y order — confirm against the model export).
 *
 * @param boxOutputs 2D prediction tensor (column 0 is not read here)
 * @param anchors anchor tensor broadcast-added to the centre offsets
 * @param inputSize divisor/multiplier broadcast over the two coordinate columns
 * @returns 2D tensor of concatenated start and end coordinates
 */
function decodeBounds(boxOutputs, anchors, inputSize) {
  const centre = boxOutputs.slice([0, 1], [-1, 2]).add(anchors);
  const size = boxOutputs.slice([0, 3], [-1, 2]);
  // Values are normalised by inputSize and scaled straight back up again; the round trip is
  // numerically redundant but kept so the float results match exactly.
  const centreNorm = centre.div(inputSize);
  const half = size.div(inputSize).div(2);
  const topLeft = centreNorm.sub(half).mul(inputSize);
  const bottomRight = centreNorm.add(half).mul(inputSize);
  return tf.concat2d([topLeft, bottomRight], 1);
}
|
2020-10-13 04:01:35 +02:00
|
|
|
|
2021-02-08 17:39:09 +01:00
|
|
|
/**
 * BlazeFace face detector.
 *
 * Holds a loaded detector graph model plus its pre-computed anchors, and runs detection followed
 * by non-max suppression on an input image tensor.
 * NOTE(review): the `anchors` tensor created in the constructor is never disposed by this class.
 */
export class BlazeFaceModel {
  model: any; // loaded detector graph model
  anchorsData: any; // anchors from util.generateAnchors; indexed per detection and also converted to `anchors`
  anchors: any; // `anchorsData` as a 2D tensor, added to the box centre offsets in decodeBounds
  inputSize: number; // the input image is resized to inputSize x inputSize before inference
  config: any; // library configuration; `face.detector.*` thresholds are read during detection

  /**
   * @param model loaded detector graph model; `model.inputs[0].shape[1]` is passed to
   *   util.generateAnchors and `model.inputs[0].shape[2]` becomes `inputSize`
   * @param config library configuration object
   */
  constructor(model, config) {
    this.model = model;
    this.anchorsData = util.generateAnchors(model.inputs[0].shape[1]);
    this.anchors = tf.tensor2d(this.anchorsData);
    this.inputSize = model.inputs[0].shape[2];
    this.config = config;
  }

  /**
   * Runs the detector on `inputImage` and returns the faces that survive non-max suppression.
   *
   * @param inputImage rank-4 image tensor; dims 1 and 2 are treated as height and width
   *   (NOTE(review): assumes NHWC with pixel values in 0..255 — confirm against callers)
   * @returns `null` when the input is missing, disposed, not rank 4 or has an empty spatial dim;
   *   otherwise `{ boxes, scaleFactor }`. Each entry of `boxes` holds the created box, a landmarks
   *   tensor that the caller must dispose, the matching anchor and the sigmoid confidence.
   *   `scaleFactor` is `[dim2 / inputSize, dim1 / inputSize]`, mapping model-input coordinates back
   *   to the original image.
   */
  async getBoundingBoxes(inputImage) {
    // sanity check on input
    if ((!inputImage) || (inputImage.isDisposedInternal) || (inputImage.shape.length !== 4) || (inputImage.shape[1] < 1) || (inputImage.shape[2] < 1)) return null;
    const [batch, boxes, scores] = tf.tidy(() => {
      const resizedImage = inputImage.resizeBilinear([this.inputSize, this.inputSize]);
      // NOTE(review): x / 127.5 - 0.5 maps 0..255 to -0.5..1.5, whereas the older formula
      // (x / 255 - 0.5) * 2 mapped it to -1..1. Kept unchanged because detections depend on it.
      const normalizedImage = resizedImage.div(127.5).sub(0.5);
      const batchedPrediction = this.model.predict(normalizedImage);
      let batchOut;
      // are we using tfhub or pinto converted model?
      if (Array.isArray(batchedPrediction)) {
        // The pinto-converted model returns separate output tensors. Ordering them by element count
        // lets them be paired by position and stitched into one [rows, 1 + 16] tensor.
        const sorted = batchedPrediction.sort((a, b) => a.size - b.size);
        const concat384 = tf.concat([sorted[0], sorted[2]], 2); // dim: 384, 1 + 16
        const concat512 = tf.concat([sorted[1], sorted[3]], 2); // dim: 512, 1 + 16
        const concat = tf.concat([concat512, concat384], 1);
        batchOut = concat.squeeze(0);
      } else {
        batchOut = batchedPrediction.squeeze(); // when using tfhub model
      }
      const boxesOut = decodeBounds(batchOut, this.anchors, [this.inputSize, this.inputSize]);
      const logits = tf.slice(batchOut, [0, 0], [-1, 1]);
      const scoresOut = tf.sigmoid(logits).squeeze();
      return [batchOut, boxesOut, scoresOut];
    });
    // Tensors returned from tidy outlive it, so they are released on every exit path, including
    // when NMS or a download throws.
    try {
      const { maxDetected, iouThreshold, minConfidence } = this.config.face.detector;
      const boxIndicesTensor = await tf.image.nonMaxSuppressionAsync(boxes, scores, maxDetected, iouThreshold, minConfidence);
      const boxIndices = await boxIndicesTensor.array();
      boxIndicesTensor.dispose();
      // One bulk asynchronous download replaces a blocking readback per selected box.
      const [boxesVal, scoresVal] = await Promise.all([boxes.array() as Promise<number[][]>, scores.data()]);
      const annotatedBoxes: Array<{ box: any, landmarks: any, anchor: any, confidence: number }> = [];
      for (const boxIndex of boxIndices) {
        const confidence = scoresVal[boxIndex];
        if (confidence > minConfidence) {
          // createBox receives the [1, 4] nested-array form produced by reading back a single box row.
          const localBox = box.createBox([boxesVal[boxIndex]]);
          const anchor = this.anchorsData[boxIndex];
          // Each row is [score, 4 box values, landmark values...]. keypointsCount - 1 (5) is the
          // first landmark column, and the remaining values are reshaped to one row per keypoint.
          const landmarks = tf.tidy(() => tf.slice(batch, [boxIndex, keypointsCount - 1], [1, -1]).squeeze().reshape([keypointsCount, -1]));
          annotatedBoxes.push({ box: localBox, landmarks, anchor, confidence });
        }
      }
      return {
        boxes: annotatedBoxes,
        scaleFactor: [inputImage.shape[2] / this.inputSize, inputImage.shape[1] / this.inputSize],
      };
    } finally {
      batch.dispose();
      boxes.dispose();
      scores.dispose();
    }
  }
}
|
2020-10-13 04:01:35 +02:00
|
|
|
|
2021-02-08 17:39:09 +01:00
|
|
|
/**
 * Loads the BlazeFace detector graph model and wraps it in a `BlazeFaceModel`.
 *
 * `config.face.detector.modelPath` is joined onto `config.modelBasePath`. Paths containing
 * `tfhub.dev` are loaded with the TF Hub option enabled.
 *
 * @param config library configuration; `modelBasePath`, `debug` and `face.detector.modelPath` are
 *   read here, and the whole object is handed to the detector
 * @returns the constructed detector wrapper
 */
export async function load(config) {
  const modelPath = config.face.detector.modelPath;
  const model = await tf.loadGraphModel(join(config.modelBasePath, modelPath), { fromTFHub: modelPath.includes('tfhub.dev') });
  // Report the load status before constructing the wrapper. The constructor dereferences
  // model.inputs, so with a missing model the failure message would otherwise never be logged.
  if (!model || !model.modelUrl) log('load model failed:', modelPath);
  else if (config.debug) log('load model:', model.modelUrl);
  return new BlazeFaceModel(model, config);
}
|