face-api/src/tinyYolov2/TinyYolov2.ts

59 lines
1.8 KiB
TypeScript
Raw Normal View History

2020-12-23 18:58:47 +01:00
import * as tf from '../../dist/tfjs.esm';
2020-08-18 13:54:53 +02:00
2020-12-19 17:46:41 +01:00
import { FaceDetection, Point } from '../classes/index';
2020-08-18 13:54:53 +02:00
import { ParamMapping } from '../common/types';
import { TNetInput } from '../dom/types';
import {
BOX_ANCHORS,
BOX_ANCHORS_SEPARABLE,
DEFAULT_MODEL_NAME,
DEFAULT_MODEL_NAME_SEPARABLE_CONV,
IOU_THRESHOLD,
MEAN_RGB_SEPARABLE,
} from './const';
import { TinyYolov2Base } from './TinyYolov2Base';
import { ITinyYolov2Options } from './TinyYolov2Options';
import { TinyYolov2NetParams } from './types';
/**
 * Tiny YOLO v2 face detector.
 *
 * Configures {@link TinyYolov2Base} for a single `'face'` class. The anchor set
 * (and, for the separable variant, the RGB mean) is chosen from the
 * `withSeparableConvs` constructor flag.
 */
export class TinyYolov2 extends TinyYolov2Base {
  /**
   * @param withSeparableConvs - When `true` (the default), configure the
   *   separable-convolution variant: `BOX_ANCHORS_SEPARABLE` anchors and the
   *   `MEAN_RGB_SEPARABLE` mean. When `false`, use `BOX_ANCHORS` and enable
   *   `withClassScores`.
   */
  constructor(withSeparableConvs: boolean = true) {
    const config = {
      withSeparableConvs,
      iouThreshold: IOU_THRESHOLD,
      classes: ['face'],
      // Variant-specific settings: the separable model supplies a mean RGB,
      // while the regular model enables per-class scores instead.
      ...(withSeparableConvs
        ? {
          anchors: BOX_ANCHORS_SEPARABLE,
          meanRgb: MEAN_RGB_SEPARABLE,
        }
        : {
          anchors: BOX_ANCHORS,
          withClassScores: true,
        }),
    };
    super(config);
  }

  /** Whether this instance was configured for the separable-convolution variant. */
  public get withSeparableConvs(): boolean {
    return this.config.withSeparableConvs;
  }

  /** The anchor boxes selected in the constructor for the configured variant. */
  public get anchors(): Point[] {
    return this.config.anchors;
  }

  /**
   * Runs detection and converts each result into a {@link FaceDetection}.
   *
   * @param input - Image input accepted by the base `detect` method.
   * @param forwardParams - Options forwarded unchanged to `detect`.
   * @returns One `FaceDetection` per detection. Each carries the detection's
   *   score, its relative box, and the source image width and height.
   */
  public async locateFaces(input: TNetInput, forwardParams: ITinyYolov2Options): Promise<FaceDetection[]> {
    const objectDetections = await this.detect(input, forwardParams);
    return objectDetections.map(
      (det) => new FaceDetection(det.score, det.relativeBox, { width: det.imageWidth, height: det.imageHeight }),
    );
  }

  /** Returns the default model name matching the configured convolution variant. */
  protected getDefaultModelName(): string {
    return this.withSeparableConvs ? DEFAULT_MODEL_NAME_SEPARABLE_CONV : DEFAULT_MODEL_NAME;
  }

  /**
   * Delegates to the base implementation unchanged.
   * NOTE(review): presumably kept to pin the declared return type to
   * `TinyYolov2NetParams` for this subclass — confirm against the base class.
   */
  protected extractParamsFromWeightMap(weightMap: tf.NamedTensorMap): { params: TinyYolov2NetParams, paramMappings: ParamMapping[] } {
    return super.extractParamsFromWeightMap(weightMap);
  }
}