human/src/blazeface/blazeface.ts

96 lines
4.3 KiB
TypeScript
Raw Normal View History

2021-04-09 14:07:58 +02:00
import { log, join } from '../helpers';
2020-11-18 14:26:28 +01:00
import * as tf from '../../dist/tfjs.esm.js';
2021-04-25 22:56:10 +02:00
import * as box from './box';
import * as util from './util';
2021-05-22 20:53:51 +02:00
import { Config } from '../config';
import { Tensor, GraphModel } from '../tfjs/types';
2020-10-12 01:22:43 +02:00
2021-04-25 22:56:10 +02:00
// Number of facial keypoints per detection; landmark values are reshaped into [keypointsCount, -1] rows in getBoundingBoxes.
const keypointsCount = 6;
2020-10-13 04:01:35 +02:00
2020-10-12 01:22:43 +02:00
/**
 * Decodes raw detector box regressions into corner coordinates.
 *
 * Columns read from `boxOutputs`: 1-2 are center offsets, which are added to `anchors`;
 * 3-4 are box sizes. Each box becomes `center - size / 2` (start) and `center + size / 2` (end).
 * The computation goes through a divide-by / multiply-by `inputSize` round trip, so the result
 * is in the same units as the inputs.
 *
 * All intermediates are released via tf.tidy, so the helper does not leak even when called
 * outside an enclosing tidy scope. Only the returned tensor survives.
 *
 * @param boxOutputs 2D tensor of per-anchor detector outputs (at least 5 columns)
 * @param anchors 2D tensor of anchor centers, presumably one row per boxOutputs row — TODO confirm
 * @param inputSize per-axis scale pair (callers pass [inputSize, inputSize])
 * @returns 2D tensor with start corner in columns 0-1 and end corner in columns 2-3
 */
function decodeBounds(boxOutputs, anchors, inputSize) {
  return tf.tidy(() => {
    const boxStarts = tf.slice(boxOutputs, [0, 1], [-1, 2]);
    const centers = tf.add(boxStarts, anchors);
    const boxSizes = tf.slice(boxOutputs, [0, 3], [-1, 2]);
    const boxSizesNormalized = tf.div(boxSizes, inputSize);
    const centersNormalized = tf.div(centers, inputSize);
    const halfBoxSize = tf.div(boxSizesNormalized, 2);
    const startsNormalized = tf.sub(centersNormalized, halfBoxSize);
    const endsNormalized = tf.add(centersNormalized, halfBoxSize);
    // scale back out of the normalized space
    const startsScaled = tf.mul(startsNormalized, inputSize);
    const endsScaled = tf.mul(endsNormalized, inputSize);
    const concatAxis = 1; // join start and end corners side by side per row
    return tf.concat2d([startsScaled, endsScaled], concatAxis);
  });
}
2020-10-13 04:01:35 +02:00
2021-02-08 17:39:09 +01:00
/**
 * BlazeFace face detector wrapper around a loaded GraphModel.
 *
 * Resizes the input to the model's input size and runs the model. It then decodes boxes
 * relative to precomputed anchors and applies non-max suppression using the thresholds
 * from `config.face.detector`.
 */
export class BlazeFaceModel {
  model: GraphModel;
  /** Anchors produced by util.generateAnchors from model.inputs[0].shape[1]. */
  anchorsData: [number, number][];
  /** Tensor copy of anchorsData; added to predicted box offsets in decodeBounds. */
  anchors: Tensor;
  /** Square side length the input image is resized to (model.inputs[0].shape[2]). */
  inputSize: number;
  config: Config;

  constructor(model, config: Config) {
    this.model = model;
    this.anchorsData = util.generateAnchors(model.inputs[0].shape[1]);
    this.anchors = tf.tensor2d(this.anchorsData);
    this.inputSize = model.inputs[0].shape[2];
    this.config = config;
  }

  /**
   * Detects faces in a rank-4 image tensor.
   *
   * @param inputImage rank-4 tensor; shape[1] and shape[2] are used as height and width
   *   (assumed NHWC with a single batch entry — TODO confirm)
   * @returns null if the input is missing, disposed, not rank 4, or has an empty spatial dimension.
   *   Otherwise returns the detections that survive NMS and the confidence filter, each with its
   *   box, landmarks tensor, anchor and confidence, plus an [x, y] scaleFactor from model input
   *   size to input image size. Callers own the returned landmarks tensors.
   */
  async getBoundingBoxes(inputImage: Tensor) {
    // sanity check on input
    if ((!inputImage) || (inputImage.isDisposed) || (inputImage.shape.length !== 4) || (inputImage.shape[1] < 1) || (inputImage.shape[2] < 1)) return null;
    const [batch, boxes, scores] = tf.tidy(() => {
      const resizedImage = tf.image.resizeBilinear(inputImage, [this.inputSize, this.inputSize]);
      // x / 127.5 - 0.5: for 0..255 pixel input this yields -0.5..1.5
      // NOTE(review): confirm this offset matches the normalization the model expects
      const normalizedImage = resizedImage.div(127.5).sub(0.5);
      const res = this.model.execute(normalizedImage);
      let batchOut;
      if (Array.isArray(res)) { // are we using tfhub or pinto converted model?
        // multi-output variant: order outputs by size, then rebuild a single [anchors, 17] result
        const sorted = res.sort((a, b) => a.size - b.size);
        const concat384 = tf.concat([sorted[0], sorted[2]], 2); // dim: 384, 1 + 16
        const concat512 = tf.concat([sorted[1], sorted[3]], 2); // dim: 512, 1 + 16
        const concat = tf.concat([concat512, concat384], 1);
        batchOut = concat.squeeze(0);
      } else {
        batchOut = tf.squeeze(res); // when using tfhub model
      }
      const boxesOut = decodeBounds(batchOut, this.anchors, [this.inputSize, this.inputSize]);
      const logits = tf.slice(batchOut, [0, 0], [-1, 1]); // column 0 holds the score logit
      const scoresOut = tf.sigmoid(logits).squeeze().dataSync();
      return [batchOut, boxesOut, scoresOut];
    });
    // Column 0 is the score logit and columns 1-4 are the box (see decodeBounds); keypoints start after them.
    const landmarksOffset = 5;
    const { maxDetected, iouThreshold, minConfidence } = this.config.face.detector;
    const annotatedBoxes: Array<{ box: { startPoint: Tensor, endPoint: Tensor }, landmarks: Tensor, anchor: number[], confidence: number }> = [];
    try {
      const nmsTensor = await tf.image.nonMaxSuppressionAsync(boxes, scores, maxDetected, iouThreshold, minConfidence);
      const nms = await nmsTensor.array();
      nmsTensor.dispose();
      for (const index of nms) {
        const confidence = scores[index];
        if (confidence > minConfidence) {
          const boundingBox = tf.slice(boxes, [index, 0], [1, -1]);
          const localBox = box.createBox(boundingBox);
          boundingBox.dispose();
          const anchor = this.anchorsData[index];
          const landmarks = tf.tidy(() => tf.slice(batch, [index, landmarksOffset], [1, -1]).squeeze().reshape([keypointsCount, -1]));
          annotatedBoxes.push({ box: localBox, landmarks, anchor, confidence });
        }
      }
    } finally {
      // always release the tidy outputs, even if NMS or post-processing throws
      // (scores is a plain typed array from dataSync and needs no disposal)
      batch.dispose();
      boxes.dispose();
    }
    return {
      boxes: annotatedBoxes,
      scaleFactor: [inputImage.shape[2] / this.inputSize, inputImage.shape[1] / this.inputSize],
    };
  }
}
2020-10-13 04:01:35 +02:00
2021-06-03 15:41:53 +02:00
/**
 * Loads the BlazeFace detector graph model and wraps it in a BlazeFaceModel.
 *
 * The model path is `config.face.detector.modelPath` joined onto `config.modelBasePath`.
 * TF Hub loading is enabled when the path contains 'tfhub.dev'.
 *
 * @param config library configuration providing the model location and debug flag
 * @returns the detector wrapper around the loaded model
 */
export async function load(config: Config) {
  const modelPath = config.face.detector.modelPath;
  const model = await tf.loadGraphModel(join(config.modelBasePath, modelPath), { fromTFHub: modelPath.includes('tfhub.dev') });
  // Report the load result before constructing the wrapper: the constructor dereferences
  // model.inputs, so a missing model would otherwise throw before this failure is logged.
  if (!model || !model.modelUrl) log('load model failed:', modelPath);
  else if (config.debug) log('load model:', model.modelUrl);
  return new BlazeFaceModel(model, config);
}