face-api/src/ops/nonMaxSuppression.ts

40 lines
915 B
TypeScript
Raw Normal View History

2020-08-18 13:54:53 +02:00
import { Box } from '../classes/Box';
import { iou } from './iou';
/**
 * Greedy non-maximum suppression.
 *
 * Repeatedly takes the highest-scoring remaining candidate, records its index,
 * and drops every other remaining candidate whose overlap with it (as returned
 * by `iou`) is greater than `iouThreshold`. A candidate whose overlap is exactly
 * `iouThreshold` is kept. A NaN overlap fails the `<=` test, so that candidate
 * is dropped.
 *
 * @param boxes - Candidate boxes; `boxes[i]` is paired with `scores[i]`.
 * @param scores - One score per candidate. Only indices `0..scores.length - 1`
 *   are considered, so boxes beyond `scores.length` are ignored.
 * @param iouThreshold - Largest overlap a candidate may have with an already
 *   picked box and still survive.
 * @param isIOU - Forwarded unchanged as the third argument of `iou`; presumably
 *   selects the overlap measure (IoU vs. an alternative) — see `iou`.
 * @returns Indices of the kept boxes, in the order they were picked
 *   (descending score).
 */
export function nonMaxSuppression(
  boxes: Box[],
  scores: number[],
  iouThreshold: number,
  isIOU = true,
): number[] {
  // Sort ascending so the best remaining candidate is always last and can be
  // taken with pop(). Array.prototype.sort is stable (ES2019+). Among equal
  // scores the higher index is therefore picked first; this order is kept
  // deliberately so existing callers see identical output.
  let remaining = scores
    .map((score, boxIndex) => ({ score, boxIndex }))
    .sort((c1, c2) => c1.score - c2.score)
    .map((c) => c.boxIndex);

  const pick: number[] = [];

  // pop() yields undefined only once `remaining` is empty, which ends the loop.
  // No type assertion is needed.
  let curr = remaining.pop();
  while (curr !== undefined) {
    pick.push(curr);

    // The current pick's box is invariant for this pass; look it up once.
    const currBox = boxes[curr];
    remaining = remaining.filter(
      (idx) => iou(currBox, boxes[idx], isIOU) <= iouThreshold,
    );

    curr = remaining.pop();
  }

  return pick;
}