From 97fcf059cd2728d506f59eac0e5c024a61eed103 Mon Sep 17 00:00:00 2001 From: Marco Godoy Date: Mon, 19 Jul 2021 13:27:16 -0400 Subject: [PATCH] proposal #141 --- src/blazeface/blazeface.ts | 7 +++++-- src/blazeface/facepipeline.ts | 2 +- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/src/blazeface/blazeface.ts b/src/blazeface/blazeface.ts index c4de09c5..6bd603e5 100644 --- a/src/blazeface/blazeface.ts +++ b/src/blazeface/blazeface.ts @@ -1,4 +1,4 @@ -import { log, join } from '../helpers'; +import { log, join, mergeDeep } from '../helpers'; import * as tf from '../../dist/tfjs.esm.js'; import * as box from './box'; import * as util from './util'; @@ -37,7 +37,7 @@ export class BlazeFaceModel { this.config = config; } - async getBoundingBoxes(inputImage: Tensor) { + async getBoundingBoxes(inputImage: Tensor, userConfig: Config) { // sanity check on input // @ts-ignore isDisposed is internal property if ((!inputImage) || (inputImage.isDisposedInternal) || (inputImage.shape.length !== 4) || (inputImage.shape[1] < 1) || (inputImage.shape[2] < 1)) return null; @@ -60,6 +60,9 @@ export class BlazeFaceModel { const scoresOut = tf.sigmoid(logits).squeeze().dataSync(); return [batchOut, boxesOut, scoresOut]; }); + + this.config = mergeDeep(this.config, userConfig) as Config; + const nmsTensor = await tf.image.nonMaxSuppressionAsync(boxes, scores, this.config.face.detector.maxDetected, this.config.face.detector.iouThreshold, this.config.face.detector.minConfidence); const nms = nmsTensor.arraySync(); nmsTensor.dispose(); diff --git a/src/blazeface/facepipeline.ts b/src/blazeface/facepipeline.ts index 585f69f4..a2bf79da 100644 --- a/src/blazeface/facepipeline.ts +++ b/src/blazeface/facepipeline.ts @@ -158,7 +158,7 @@ export class Pipeline { // run new detector every skipFrames unless we only want box to start with let detector; if ((this.skipped === 0) || (this.skipped > config.face.detector.skipFrames) || !config.face.mesh.enabled || !config.skipFrame) { - detector = await this.boundingBoxDetector.getBoundingBoxes(input); + detector = await this.boundingBoxDetector.getBoundingBoxes(input, config); this.skipped = 0; } if (config.skipFrame) this.skipped++;