2021-05-25 14:58:20 +02:00
|
|
|
/**
 * CenterNet object detection model implementation
 *
 * Based on: [**NanoDet**](https://github.com/RangiLyu/nanodet)
 */
|
|
|
|
|
2022-01-17 17:03:21 +01:00
|
|
|
import { log, now } from '../util/util';
|
2021-05-19 14:27:28 +02:00
|
|
|
import * as tf from '../../dist/tfjs.esm.js';
|
2022-01-16 15:49:55 +01:00
|
|
|
import { loadModel } from '../tfjs/load';
|
2021-05-19 14:27:28 +02:00
|
|
|
import { labels } from './labels';
|
2021-12-15 15:26:32 +01:00
|
|
|
import type { ObjectResult, ObjectType, Box } from '../result';
|
2021-09-13 19:28:35 +02:00
|
|
|
import type { GraphModel, Tensor } from '../tfjs/types';
|
|
|
|
import type { Config } from '../config';
|
2021-09-27 19:58:13 +02:00
|
|
|
import { env } from '../util/env';
|
2021-05-19 14:27:28 +02:00
|
|
|
|
2021-09-17 17:23:00 +02:00
|
|
|
/** Loaded object-detection graph model; `null` until `load()` has run (and reset when `env.initial` is set). */
let model: GraphModel | null = null;
/** Model input resolution read from dimension 2 of the first signature input; 0 when unknown. */
let inputSize = 0;
/** Results of the most recent inference, returned again when a frame is skipped. */
let last: ObjectResult[] = [];
/** Timestamp (from `now()`) of the most recent inference. */
let lastTime = 0;
/** Frames skipped since the last inference; starts at MAX_SAFE_INTEGER so the first call is never treated as skippable. */
let skipped = Number.MAX_SAFE_INTEGER;
|
|
|
|
|
2021-06-03 15:41:53 +02:00
|
|
|
/**
 * Load the object-detection graph model, or return the cached instance.
 *
 * The cache is dropped when `env.initial` is set. After a fresh load, `inputSize` is read from
 * dimension 2 of the first input in the model signature; it falls back to 0 when the signature
 * does not provide that value (instead of throwing or producing NaN).
 *
 * @param config uses `object.modelPath` for loading and `debug` for cache logging
 * @returns the loaded (or cached) graph model
 */
export async function load(config: Config): Promise<GraphModel> {
  if (env.initial) model = null;
  if (!model) {
    // fakeOps(['floormod'], config);
    model = await loadModel(config.object.modelPath);
    const inputs = Object.values(model.modelSignature['inputs']);
    // Object.values always returns an array, so guard on the actual entries rather than the container
    const size = inputs[0]?.tensorShape?.dim?.[2]?.size;
    inputSize = parseInt(size, 10) || 0;
  } else if (config.debug) log('cached model:', model['modelUrl']);
  return model;
}
|
|
|
|
|
2021-12-28 17:39:54 +01:00
|
|
|
/**
 * Decode raw detection output into `ObjectResult` entries.
 *
 * Each detection row is `[x1, y1, x2, y2, score, class]`. Only batch 0 is read, and coordinates
 * are divided by `inputSize` to normalize them. Non-max suppression runs over the boxes; each
 * surviving row becomes a normalized `boxRaw` and an output-pixel `box`, both `[x, y, width, height]`.
 *
 * @param res raw detection tensor, or null when detection did not run; it is disposed here
 * @param outputShape `[width, height]` used to scale normalized boxes to output pixels
 * @param config uses `object.maxDetected`, `object.iouThreshold` and `object.minConfidence`
 * @returns detected objects; empty when `res` is null
 */
async function process(res: Tensor | null, outputShape: [number, number], config: Config) {
  if (!res) return [];
  const t: Record<string, Tensor> = {};
  const results: Array<ObjectResult> = [];
  const detections = await res.array() as number[][][];
  // squeeze only the batch axis so an output with a single detection keeps its [N, 6] rank
  t.squeeze = tf.squeeze(res, [0]);
  const arr = tf.split(t.squeeze, 6, 1) as Tensor[]; // x1, y1, x2, y2, score, class; each [N, 1]
  t.boxes = tf.concat([arr[1], arr[0], arr[3], arr[2]], 1); // [N, 4] reordered to y, x as tf.nms expects
  t.scores = tf.squeeze(arr[4], [1]); // [N]
  tf.dispose([res, ...arr]);
  t.nms = tf.image.nonMaxSuppression(t.boxes, t.scores, config.object.maxDetected, config.object.iouThreshold, (config.object.minConfidence || 0));
  const nms = await t.nms.data();
  // release every tensor before the JS-only decoding loop so a failure there cannot leak them
  Object.keys(t).forEach((tensor) => tf.dispose(t[tensor]));
  let i = 0;
  for (const id of Array.from(nms)) {
    const row = detections[0][id];
    const classVal = row[5];
    const entry = labels[classVal];
    if (!entry) continue; // unknown class index: skip the detection instead of throwing
    const label = entry.label as ObjectType;
    const score = Math.trunc(100 * row[4]) / 100;
    const x = row[0] / inputSize;
    const y = row[1] / inputSize;
    const boxRaw: Box = [x, y, row[2] / inputSize - x, row[3] / inputSize - y];
    const box: Box = [
      Math.trunc(boxRaw[0] * outputShape[0]),
      Math.trunc(boxRaw[1] * outputShape[1]),
      Math.trunc(boxRaw[2] * outputShape[0]),
      Math.trunc(boxRaw[3] * outputShape[1]),
    ];
    results.push({ id: i++, score, class: classVal, label, box, boxRaw });
  }
  return results;
}
|
|
|
|
|
2021-09-12 05:54:35 +02:00
|
|
|
/**
 * Run object detection on an input image tensor, reusing cached results when skipping is allowed.
 *
 * When `config.skipAllowed` is set, less than `object.skipTime` has elapsed since the last run,
 * fewer than `object.skipFrames` frames have been skipped, and cached results exist, the cached
 * results are returned. Otherwise the input is resized to the model resolution and executed.
 * Detection runs only when `object.enabled` is set, a model is loaded, and its input size is known.
 * Errors from inference propagate to the caller as a rejected promise.
 *
 * @param input image tensor; dimension 2 is used as width and dimension 1 as height
 * @param config detection and skip settings
 * @returns detected objects (also stored as the new cached results)
 */
export async function predict(input: Tensor, config: Config): Promise<ObjectResult[]> {
  const skipTime = (config.object.skipTime || 0) > (now() - lastTime);
  const skipFrame = skipped < (config.object.skipFrames || 0);
  if (config.skipAllowed && skipTime && skipFrame && (last.length > 0)) {
    skipped++;
    return last;
  }
  skipped = 0;
  const outputSize: [number, number] = [input.shape[2] || 0, input.shape[1] || 0];
  const graph = model;
  let objectT: Tensor | null = null;
  if (config.object.enabled && graph && inputSize > 0) {
    const resize = tf.image.resizeBilinear(input, [inputSize, inputSize]);
    try {
      objectT = graph.execute(resize, ['tower_0/detections']) as Tensor;
    } finally {
      tf.dispose(resize); // release the resized input even if execution throws
    }
  }
  lastTime = now();
  const obj = await process(objectT, outputSize, config);
  last = obj;
  return obj;
}
|