face-api/src/ssdMobilenetv1/nonMaxSuppression.ts

57 lines
2.1 KiB
TypeScript
Raw Normal View History

2020-12-23 18:58:47 +01:00
import * as tf from '../../dist/tfjs.esm';
2020-08-26 00:24:48 +02:00
2020-12-23 17:26:55 +01:00
/**
 * Intersection-over-union of two boxes laid out as `[y1, x1, y2, x2]`.
 * Corner order within each axis does not matter: min/max are taken per axis.
 * Returns 0 when either box has non-positive area.
 */
function IOU(boxI: number[], boxJ: number[]): number {
  const yminI = Math.min(boxI[0], boxI[2]);
  const xminI = Math.min(boxI[1], boxI[3]);
  const ymaxI = Math.max(boxI[0], boxI[2]);
  const xmaxI = Math.max(boxI[1], boxI[3]);
  const yminJ = Math.min(boxJ[0], boxJ[2]);
  const xminJ = Math.min(boxJ[1], boxJ[3]);
  const ymaxJ = Math.max(boxJ[0], boxJ[2]);
  const xmaxJ = Math.max(boxJ[1], boxJ[3]);
  const areaI = (ymaxI - yminI) * (xmaxI - xminI);
  const areaJ = (ymaxJ - yminJ) * (xmaxJ - xminJ);
  if (areaI <= 0 || areaJ <= 0) return 0.0;

  const intersectionYmin = Math.max(yminI, yminJ);
  const intersectionXmin = Math.max(xminI, xminJ);
  const intersectionYmax = Math.min(ymaxI, ymaxJ);
  const intersectionXmax = Math.min(xmaxI, xmaxJ);
  const intersectionArea = Math.max(intersectionYmax - intersectionYmin, 0.0) * Math.max(intersectionXmax - intersectionXmin, 0.0);
  return intersectionArea / (areaI + areaJ - intersectionArea);
}

/**
 * Greedy (hard) non-maximum suppression.
 *
 * Boxes whose score is strictly greater than `scoreThreshold` are visited in
 * descending score order. A box is kept unless its IOU with an already-kept box
 * exceeds `iouThreshold`. At most `min(maxOutputSize, boxes.shape[0])` boxes are kept.
 *
 * @param boxes - one box per row, read as `[y1, x1, y2, x2]`
 * @param scores - one score per box, indexed like the rows of `boxes`
 * @param maxOutputSize - upper bound on the number of returned indices
 * @param iouThreshold - overlap above which a lower-scored box is suppressed
 * @param scoreThreshold - boxes scoring at or below this value are ignored
 * @returns indices of the kept boxes, highest score first
 */
export function nonMaxSuppression(
  boxes: tf.Tensor2D,
  scores: number[],
  maxOutputSize: number,
  iouThreshold: number,
  scoreThreshold: number,
): number[] {
  const numBoxes = boxes.shape[0];
  const outputSize = Math.min(maxOutputSize, numBoxes);

  const candidates = scores
    .map((score, boxIndex) => ({ score, boxIndex }))
    .filter((c) => c.score > scoreThreshold)
    .sort((c1, c2) => c2.score - c1.score);

  if (outputSize <= 0 || candidates.length === 0) return [];

  // Read the coordinates once. Calling arraySync() per IOU copied the whole
  // tensor for every pairwise comparison.
  const boxesData = boxes.arraySync();
  const selected: number[] = [];

  for (const { boxIndex } of candidates) {
    if (selected.length >= outputSize) break;
    const suppressed = selected.some((keptIndex) => {
      const iou = IOU(boxesData[boxIndex], boxesData[keptIndex]);
      // Non-overlapping boxes never suppress. `!(iou <= threshold)` (rather than
      // `iou > threshold`) keeps a NaN IOU suppressing, as the original did.
      return iou !== 0.0 && !(iou <= iouThreshold);
    });
    // An explicit flag rather than comparing scores before/after scaling: a
    // candidate with score 0 (possible when scoreThreshold < 0) must still be
    // suppressible.
    if (!suppressed) selected.push(boxIndex);
  }
  return selected;
}