2021-05-25 14:58:20 +02:00
|
|
|
/**
|
|
|
|
* CenterNet object detection module
|
|
|
|
*/
|
|
|
|
|
2021-05-19 14:27:28 +02:00
|
|
|
import { log, join } from '../helpers';
|
|
|
|
import * as tf from '../../dist/tfjs.esm.js';
|
|
|
|
import { labels } from './labels';
|
2021-05-22 18:33:19 +02:00
|
|
|
import { Item } from '../result';
|
2021-05-19 14:27:28 +02:00
|
|
|
|
|
|
|
// Cached graph model instance; assigned by load() on first use and read by predict().
let model;
|
2021-05-22 18:33:19 +02:00
|
|
|
// Results of the most recent full detection run; predict() returns these while frames are being skipped.
let last: Item[] = [];
|
2021-05-19 14:27:28 +02:00
|
|
|
// Number of consecutive calls served from `last`; starts at the maximum so the
// `skipped < skipFrames` check in predict() fails on the first call and a real detection runs.
let skipped = Number.MAX_SAFE_INTEGER;
|
|
|
|
|
|
|
|
/**
 * Loads the object detection graph model once and caches it at module level.
 *
 * The square input size is read from the third dimension of the first entry in
 * the model signature inputs and stored on the model as `inputSize`.
 *
 * @param config configuration; reads `modelBasePath`, `object.modelPath` and `debug`
 * @returns the cached or newly loaded model
 * @throws Error when the input size cannot be determined from the model signature
 */
export async function load(config) {
  if (model) {
    if (config.debug) log('cached model:', model.modelUrl);
    return model;
  }
  model = await tf.loadGraphModel(join(config.modelBasePath, config.object.modelPath));
  // Walk the signature defensively so a missing or empty signature ends in the
  // descriptive error below instead of a TypeError on an undefined property.
  const hasInputs = model && model.modelSignature && model.modelSignature['inputs'];
  const inputs = hasInputs ? Object.values(model.modelSignature['inputs']) : [];
  const dims = inputs.length > 0 && inputs[0].tensorShape ? inputs[0].tensorShape.dim : null;
  const inputSize = dims && dims[2] ? parseInt(dims[2].size, 10) : NaN;
  if (!inputSize) {
    // Do not keep an unusable model cached; a later load() call should retry.
    model = undefined;
    throw new Error(`Human: Cannot determine model inputSize: ${config.object.modelPath}`);
  }
  model.inputSize = inputSize;
  if (!model.modelUrl) log('load model failed:', config.object.modelPath);
  else if (config.debug) log('load model:', model.modelUrl);
  return model;
}
|
|
|
|
|
|
|
|
/**
 * Converts the raw detections tensor into a list of result items.
 *
 * Runs non-max suppression over the detections and, for each surviving entry,
 * emits its score (truncated to two decimals), class id, label, and box
 * coordinates both normalized by `inputSize` and scaled to `outputShape`.
 *
 * @param res detections tensor with rows of x1, y1, x2, y2, score, class; disposed here. A falsy value yields an empty list.
 * @param inputSize model input edge length used to normalize coordinates
 * @param outputShape two-element array; index 0 scales x coordinates, index 1 scales y coordinates
 * @param config configuration; reads `object.maxDetected`, `object.iouThreshold` and `object.minConfidence`
 * @returns detected items, with ids numbered sequentially from 0
 */
async function process(res, inputSize, outputShape, config) {
  if (!res) return [];
  const results: Array<Item> = [];
  const detections = res.arraySync();
  const squeezeT = tf.squeeze(res);
  res.dispose();
  const arr = tf.split(squeezeT, 6, 1); // x1, y1, x2, y2, score, class
  squeezeT.dispose();
  const stackT = tf.stack([arr[1], arr[0], arr[3], arr[2]], 1); // reorder dims as tf.nms expects y, x
  const boxesT = stackT.squeeze();
  stackT.dispose(); // the intermediate stack was previously never released
  const scoresT = arr[4].squeeze();
  arr.forEach((t) => t.dispose());
  let nms;
  try {
    const nmsT = await tf.image.nonMaxSuppressionAsync(boxesT, scoresT, config.object.maxDetected, config.object.iouThreshold, config.object.minConfidence);
    nms = nmsT.dataSync();
    nmsT.dispose();
  } finally {
    // release the NMS inputs even when suppression rejects
    boxesT.dispose();
    scoresT.dispose();
  }
  let i = 0;
  for (const id of nms) {
    const [x1, y1, x2, y2, rawScore, classVal] = detections[0][id];
    const score = Math.trunc(100 * rawScore) / 100;
    const label = labels[classVal].label;
    const boxRaw = [
      x1 / inputSize,
      y1 / inputSize,
      x2 / inputSize,
      y2 / inputSize,
    ];
    const box = [
      Math.trunc(boxRaw[0] * outputShape[0]),
      Math.trunc(boxRaw[1] * outputShape[1]),
      Math.trunc(boxRaw[2] * outputShape[0]),
      Math.trunc(boxRaw[3] * outputShape[1]),
    ];
    results.push({ id: i++, score, class: classVal, label, box, boxRaw });
  }
  return results;
}
|
|
|
|
|
2021-05-30 18:03:34 +02:00
|
|
|
/**
 * Runs object detection on an input tensor, reusing cached results on skipped frames.
 *
 * While fewer than `config.object.skipFrames` calls have been served from cache,
 * `config.skipFrame` is set and cached results exist, the previous results are
 * returned without running the model. Otherwise the input is resized to the
 * model input size, the `tower_0/detections` output is executed and the
 * detections are post-processed and cached.
 *
 * @param input image tensor; `shape[2]` and `shape[1]` are used as the output x and y extents (presumably NHWC layout; confirm against caller)
 * @param config configuration; reads `skipFrame` and the `object` section
 * @returns detected items; rejects if the model is not loaded or execution fails
 */
export async function predict(input, config): Promise<Item[]> {
  if ((skipped < config.object.skipFrames) && config.skipFrame && (last.length > 0)) {
    skipped++;
    return last;
  }
  skipped = 0;
  // No `new Promise(async ...)` wrapper: inside one, a throw would be lost and
  // the returned promise would never settle. Errors now reject to the caller.
  const outputSize = [input.shape[2], input.shape[1]];
  const resize = tf.image.resizeBilinear(input, [model.inputSize, model.inputSize]);
  let objectT;
  try {
    objectT = config.object.enabled ? model.execute(resize, ['tower_0/detections']) : null;
  } finally {
    // release the resized input even when execution throws
    resize.dispose();
  }
  const obj = await process(objectT, model.inputSize, outputSize, config);
  last = obj;
  return obj;
}
|