human/src/blazeface/blazeface.ts

138 lines
5.6 KiB
TypeScript
Raw Normal View History

2021-02-08 17:39:09 +01:00
import { log } from '../log';
2020-11-18 14:26:28 +01:00
import * as tf from '../../dist/tfjs.esm.js';
2020-10-12 01:22:43 +02:00
// Number of facial keypoints per detection: getBoundingBoxes reshapes each detection's
// landmark values into NUM_LANDMARKS rows.
const NUM_LANDMARKS = 6;
2020-10-16 16:48:10 +02:00
/**
 * Builds the fixed anchor centers used to decode detector output.
 *
 * Two square feature-map layers are produced, with strides `inputSize / 16` and `inputSize / 8`.
 * Every cell of the first layer contributes 2 identical anchors and every cell of the second
 * layer contributes 6. Each anchor is the cell center `[x, y]` in input-pixel units. Ordering is
 * layer, then row, then column, then copy.
 * @param inputSize side length of the (square) detector input
 * @returns flat list of `[x, y]` anchor centers
 */
function generateAnchors(inputSize) {
  const strides = [inputSize / 16, inputSize / 8];
  const anchorsPerCell = [2, 6];
  const anchors: Array<[number, number]> = [];
  strides.forEach((stride, layer) => {
    // rows and columns use the same count because the input is square
    const cellsPerSide = Math.floor((inputSize + stride - 1) / stride);
    for (let row = 0; row < cellsPerSide; row++) {
      const centerY = stride * (row + 0.5);
      for (let col = 0; col < cellsPerSide; col++) {
        const centerX = stride * (col + 0.5);
        for (let copy = 0; copy < anchorsPerCell[layer]; copy++) anchors.push([centerX, centerY]);
      }
    }
  });
  return anchors;
}
2020-10-13 04:01:35 +02:00
2021-02-08 17:39:09 +01:00
/**
 * Releases the tensors held by a box produced by `createBox`.
 *
 * `startEndTensor` is whatever value was passed to `createBox`. Inside `getBoundingBoxes` that is
 * a plain nested array read back from the `boxes` tensor, not a tensor. Calling `.dispose()` on
 * it unconditionally would throw a TypeError, so each member is disposed only when it actually
 * exposes a `dispose` method. Members are disposed in the order start/end source, start point,
 * end point.
 * @param box object with `startEndTensor`, `startPoint` and `endPoint` members
 */
export const disposeBox = (box) => {
  for (const member of [box.startEndTensor, box.startPoint, box.endPoint]) {
    if (member && typeof member.dispose === 'function') member.dispose();
  }
};
/**
 * Wraps a start/end box value and slices it into its two corner halves.
 *
 * `startEndTensor` is sliced as a 2-D value: columns 0-1 become `startPoint` and columns 2-3
 * become `endPoint`, keeping all rows. It may be a tensor or any tensor-like value `tf.slice`
 * accepts (`getBoundingBoxes` passes a nested array).
 * @param startEndTensor 2-D value with at least 4 columns
 * @returns `{ startEndTensor, startPoint, endPoint }`; the two points are new tensors
 */
const createBox = (startEndTensor) => {
  const startPoint = tf.slice(startEndTensor, [0, 0], [-1, 2]);
  const endPoint = tf.slice(startEndTensor, [0, 2], [-1, 2]);
  return { startEndTensor, startPoint, endPoint };
};
2020-10-12 01:22:43 +02:00
/**
 * Converts raw box regressions into corner coordinates.
 *
 * Columns 1-2 of `boxOutputs` are offsets added to `anchors` to obtain box centers. Columns 3-4
 * are box sizes. Centers and sizes are divided by `inputSize`, the corners are taken as
 * center minus / plus half the size, and the corners are multiplied back by `inputSize`.
 * @param boxOutputs 2-D prediction tensor (one row per anchor)
 * @param anchors anchor centers, one `[x, y]` row per prediction row
 * @param inputSize divisor/multiplier broadcast over the `[x, y]` columns
 * @returns 2-D tensor with the start pair in columns 0-1 and the end pair in columns 2-3
 */
function decodeBounds(boxOutputs, anchors, inputSize) {
  const centers = tf.add(tf.slice(boxOutputs, [0, 1], [-1, 2]), anchors);
  const sizes = tf.slice(boxOutputs, [0, 3], [-1, 2]);
  const halfSize = tf.div(tf.div(sizes, inputSize), 2);
  const scaledCenters = tf.div(centers, inputSize);
  const corners = [
    tf.sub(scaledCenters, halfSize), // start corner
    tf.add(scaledCenters, halfSize), // end corner
  ].map((corner) => tf.mul(corner, inputSize));
  return tf.concat2d(corners, 1);
}
2020-10-13 04:01:35 +02:00
2021-02-08 17:39:09 +01:00
export class BlazeFaceModel {
  blazeFaceModel: any;
  width: number;
  height: number;
  anchorsData: any;
  anchors: any;
  // tensor1d [width, height]; previously declared `number`, which did not match the assigned value
  inputSize: any;
  config: any;
  // NOTE(review): not read anywhere in this class — presumably consumed by callers; confirm
  scaleFaces: number;

  /**
   * @param model loaded graph model (as produced by `load`)
   * @param config human configuration; `face.detector.inputSize` sets the square detector input,
   *   other `face.detector` fields are read in `getBoundingBoxes`
   */
  constructor(model, config) {
    this.blazeFaceModel = model;
    this.width = config.face.detector.inputSize;
    this.height = config.face.detector.inputSize;
    this.anchorsData = generateAnchors(config.face.detector.inputSize);
    this.anchors = tf.tensor2d(this.anchorsData);
    this.inputSize = tf.tensor1d([this.width, this.height]);
    this.config = config;
    this.scaleFaces = 0.8;
  }

  /**
   * Runs face detection on an image tensor.
   *
   * The image is resized to the detector input, scaled to [-1, 1], and run through the model.
   * Candidate boxes are filtered with non-max suppression using `face.detector.maxFaces`,
   * `iouThreshold` and `scoreThreshold`. Survivors are kept only if their score exceeds
   * `face.detector.minConfidence`. All intermediate tensors are released even if a step throws.
   * @param inputImage 4-D image tensor; assumed [batch, height, width, channels] (tfjs image layout) — TODO confirm batch is always 1
   * @returns `null` when the input fails the sanity checks. Otherwise `{ boxes, scaleFactor }`,
   *   where each box entry owns tensors (`box.startPoint`, `box.endPoint`, `landmarks`) the
   *   caller must dispose, and `scaleFactor` maps detector coordinates back to the input image.
   */
  async getBoundingBoxes(inputImage) {
    // sanity check on input
    if ((!inputImage) || (inputImage.isDisposedInternal) || (inputImage.shape.length !== 4) || (inputImage.shape[1] < 1) || (inputImage.shape[2] < 1)) return null;
    const [detectedOutputs, boxes, scores] = tf.tidy(() => {
      // resizeBilinear expects [newHeight, newWidth]
      const resizedImage = inputImage.resizeBilinear([this.height, this.width]);
      // map pixel values from [0, 255] to [-1, 1]: (x / 255 - 0.5) * 2 === x / 127.5 - 1
      const normalizedImage = resizedImage.div(127.5).sub(1);
      const batchedPrediction = this.blazeFaceModel.predict(normalizedImage);
      let prediction;
      // are we using tfhub or pinto converted model?
      if (Array.isArray(batchedPrediction)) {
        // pinto model returns separate output tensors; order them by element count before pairing
        const sorted = batchedPrediction.sort((a, b) => a.size - b.size);
        const concat384 = tf.concat([sorted[0], sorted[2]], 2); // dim: 384, 1 + 16
        const concat512 = tf.concat([sorted[1], sorted[3]], 2); // dim: 512, 1 + 16
        const concat = tf.concat([concat512, concat384], 1);
        prediction = concat.squeeze(0);
      } else {
        prediction = batchedPrediction.squeeze(); // when using tfhub model
      }
      const decodedBounds = decodeBounds(prediction, this.anchors, this.inputSize);
      // column 0 holds the per-anchor score logit
      const logits = tf.slice(prediction, [0, 0], [-1, 1]);
      const scoresOut = tf.sigmoid(logits).squeeze();
      return [prediction, decodedBounds, scoresOut];
    });
    const detector = this.config.face.detector;
    const annotatedBoxes: Array<{ box: any, landmarks: any, anchor: any, confidence: number }> = [];
    try {
      const boxIndicesTensor = await tf.image.nonMaxSuppressionAsync(boxes, scores, detector.maxFaces, detector.iouThreshold, detector.scoreThreshold);
      const boxIndices = await boxIndicesTensor.array();
      boxIndicesTensor.dispose();
      const scoresVal = await scores.data();
      for (const boxIndex of boxIndices) {
        const confidence = scoresVal[boxIndex];
        if (confidence > detector.minConfidence) {
          // read this single box row back as a nested [[x1, y1, x2, y2]] array
          const boxTensor = tf.slice(boxes, [boxIndex, 0], [1, -1]);
          const startEnd = await boxTensor.array();
          boxTensor.dispose();
          const box = createBox(startEnd);
          const anchor = this.anchorsData[boxIndex];
          // prediction columns: 0 = score logit, 1-4 = box regression, 5.. = landmark values
          const landmarksStart = 5;
          const landmarks = tf.tidy(() => tf.slice(detectedOutputs, [boxIndex, landmarksStart], [1, -1]).squeeze().reshape([NUM_LANDMARKS, -1]));
          annotatedBoxes.push({ box, landmarks, anchor, confidence });
        }
      }
    } finally {
      // released exactly once, including when NMS or a readback throws
      boxes.dispose();
      scores.dispose();
      detectedOutputs.dispose();
    }
    return {
      boxes: annotatedBoxes,
      // [image width / detector width, image height / detector height]
      scaleFactor: [inputImage.shape[2] / this.width, inputImage.shape[1] / this.height],
    };
  }
}
2020-10-13 04:01:35 +02:00
2021-02-08 17:39:09 +01:00
/**
 * Loads the BlazeFace graph model from `config.face.detector.modelPath` and wraps it in a
 * `BlazeFaceModel`. The loader's `fromTFHub` option is enabled when the path contains
 * `tfhub.dev`.
 * @param config human configuration object
 * @returns the wrapped detector model
 */
export async function load(config) {
  const modelPath = config.face.detector.modelPath;
  const blazeface = await tf.loadGraphModel(modelPath, { fromTFHub: modelPath.includes('tfhub.dev') });
  const model = new BlazeFaceModel(blazeface, config);
  if (config.debug) {
    // Log the text between the first '/' and the last '.' of the path. Fall back to the full
    // path when the pattern does not match, instead of throwing on `null[1]`.
    const match = modelPath.match(/\/(.*)\./);
    log(`load model: ${match ? match[1] : modelPath}`);
  }
  return model;
}