human/src/object/centernet.ts

101 lines
3.6 KiB
TypeScript
Raw Normal View History

2021-05-25 14:58:20 +02:00
/**
 * CenterNet object detection model implementation
 *
 * Based on: [**NanoDet**](https://github.com/RangiLyu/nanodet)
 */
2021-10-22 22:09:52 +02:00
import { log, join, now } from '../util/util';
import * as tf from '../../dist/tfjs.esm.js';
import { labels } from './labels';
2021-09-27 15:19:43 +02:00
import type { ObjectResult, Box } from '../result';
2021-09-13 19:28:35 +02:00
import type { GraphModel, Tensor } from '../tfjs/types';
import type { Config } from '../config';
2021-09-27 19:58:13 +02:00
import { env } from '../util/env';
2021-09-17 17:23:00 +02:00
/** Cached graph model; `null` until `load()` assigns a loaded instance. */
let model: GraphModel | null = null;
/** Size read from the model's input signature by `load()`; used as both resize dimensions in `predict()`. */
let inputSize = 0;
/** Results of the most recent detection run; returned as-is while frames are being skipped. */
let last: ObjectResult[] = [];
/** Value of `now()` captured at the most recent model execution. */
let lastTime = 0;
/** Consecutive skipped-frame counter; starts at the maximum so the first `skipped < skipFrames` check fails. */
let skipped = Number.MAX_SAFE_INTEGER;
2021-06-03 15:41:53 +02:00
/**
 * Loads the object detection graph model, reusing the cached instance when one exists.
 *
 * When `env.initial` is set, any cached instance is discarded first. After a fresh
 * load, `inputSize` is read from the model's input signature.
 *
 * @param config configuration providing `modelBasePath`, `object.modelPath` and `debug`
 * @returns the loaded (or cached) graph model
 */
export async function load(config: Config): Promise<GraphModel> {
  if (env.initial) model = null;

  // fast path: a model is already cached
  if (model) {
    if (config.debug) log('cached model:', model['modelUrl']);
    return model;
  }

  // fakeOps(['floormod'], config);
  const modelLocation = join(config.modelBasePath, config.object.modelPath || '');
  model = await tf.loadGraphModel(modelLocation) as unknown as GraphModel;

  // input size is taken from dimension index 2 of the first declared input
  const inputs = Object.values(model.modelSignature['inputs']);
  inputSize = Array.isArray(inputs) ? parseInt(inputs[0].tensorShape.dim[2].size) : 0;

  if (!model || !model['modelUrl']) {
    log('load model failed:', config.object.modelPath);
  } else if (config.debug) {
    log('load model:', model['modelUrl']);
  }
  return model;
}
2021-09-17 17:23:00 +02:00
/**
 * Converts the raw detection tensor into a list of object results.
 *
 * Runs non-max suppression over the detections and maps each surviving entry to a
 * labelled result with a normalized box (`boxRaw`) and a box scaled to `outputShape` (`box`).
 * The `res` tensor and all intermediate tensors are disposed here.
 *
 * @param res raw model output, read as `[0][i] = [x1, y1, x2, y2, score, class]`; `null` yields an empty list
 * @param outputShape two-element size; element 0 scales x/width values, element 1 scales y/height values
 * @param config configuration providing `object.maxDetected`, `object.iouThreshold` and `object.minConfidence`
 * @returns detected objects, with `id` numbered sequentially from 0 in NMS output order
 */
async function process(res: Tensor | null, outputShape: number[], config: Config) {
  if (!res) return [];
  const results: Array<ObjectResult> = [];
  // download values before the tensor is disposed; used later to build the results
  const detections = await res.array();
  const squeezeT = tf.squeeze(res);
  tf.dispose(res);
  const arr = tf.split(squeezeT, 6, 1); // x1, y1, x2, y2, score, class
  tf.dispose(squeezeT);
  const stackT = tf.stack([arr[1], arr[0], arr[3], arr[2]], 1); // reorder dims as tf.nms expects y, x
  const boxesT = tf.squeeze(stackT);
  tf.dispose(stackT);
  const scoresT = tf.squeeze(arr[4]);
  // class values are read from `detections` below, so no class tensor is needed
  arr.forEach((t) => tf.dispose(t));
  const nmsT = await tf.image.nonMaxSuppressionAsync(boxesT, scoresT, config.object.maxDetected, config.object.iouThreshold, config.object.minConfidence);
  tf.dispose(boxesT);
  tf.dispose(scoresT);
  const nms = await nmsT.data();
  tf.dispose(nmsT);
  let i = 0;
  for (const id of nms) {
    const det = detections[0][id];
    const score = Math.trunc(100 * det[4]) / 100; // truncate to two decimals
    const classVal = det[5];
    const label = labels[classVal].label;
    // normalize corner coordinates by the model input size
    const [x, y] = [
      det[0] / inputSize,
      det[1] / inputSize,
    ];
    // normalized box as [x, y, width, height]
    const boxRaw: Box = [
      x,
      y,
      det[2] / inputSize - x,
      det[3] / inputSize - y,
    ];
    // same box scaled to the requested output size, truncated to integers
    const box: Box = [
      Math.trunc(boxRaw[0] * outputShape[0]),
      Math.trunc(boxRaw[1] * outputShape[1]),
      Math.trunc(boxRaw[2] * outputShape[0]),
      Math.trunc(boxRaw[3] * outputShape[1]),
    ];
    results.push({ id: i++, score, class: classVal, label, box, boxRaw });
  }
  return results;
}
2021-09-12 05:54:35 +02:00
/**
 * Runs object detection on the input tensor, or returns cached results when skipping is allowed.
 *
 * Cached results are returned only when all of the following hold: `config.skipAllowed` is set,
 * less than `config.object.skipTime` has elapsed since the last execution, fewer than
 * `config.object.skipFrames` frames have been skipped in a row, and the cache is non-empty.
 *
 * Errors raised during model execution or post-processing reject the returned promise.
 *
 * @param input image tensor; `shape[2]` and `shape[1]` are used as the output width and height
 * @param config configuration providing `skipAllowed` and the `object` settings
 * @returns detected objects (empty when `config.object.enabled` is false or no model is loaded)
 */
export async function predict(input: Tensor, config: Config): Promise<ObjectResult[]> {
  const skipTime = (config.object.skipTime || 0) > (now() - lastTime);
  const skipFrame = skipped < (config.object.skipFrames || 0);
  if (config.skipAllowed && skipTime && skipFrame && (last.length > 0)) {
    skipped++;
    return last;
  }
  skipped = 0;
  const outputSize = [input.shape[2], input.shape[1]];
  const resize = tf.image.resizeBilinear(input, [inputSize, inputSize]);
  let objectT: Tensor | null = null;
  try {
    objectT = config.object.enabled ? model?.execute(resize, ['tower_0/detections']) as Tensor : null;
    lastTime = now();
  } finally {
    // release the resized tensor even when model execution throws
    tf.dispose(resize);
  }
  const obj = await process(objectT, outputSize, config);
  last = obj;
  return obj;
}