full re-lint and typings generation

pull/34/head
Vladimir Mandic 2020-12-23 11:26:55 -05:00
parent 2a7bf4080b
commit 0692fbf532
353 changed files with 7151 additions and 5729 deletions

47
.eslintrc.json Normal file
View File

@ -0,0 +1,47 @@
{
"globals": {},
"env": {
"browser": true,
"commonjs": true,
"es6": true,
"node": true,
"es2020": true
},
"parser": "@typescript-eslint/parser",
"parserOptions": { "ecmaVersion": 2020 },
"plugins": ["@typescript-eslint"],
"extends": [
"eslint:recommended",
"plugin:import/errors",
"plugin:import/warnings",
"plugin:import/typescript",
"plugin:node/recommended",
"plugin:promise/recommended",
"plugin:json/recommended-with-comments",
"airbnb-base"
],
"ignorePatterns": [ "node_modules", "types" ],
"settings": {
"import/resolver": {
"node": {
"extensions": [".js", ".ts"]
}
}
},
"rules": {
"max-len": [1, 275, 3],
"no-plusplus": "off",
"import/prefer-default-export": "off",
"node/no-unsupported-features/es-syntax": "off",
"import/no-cycle": "off",
"import/extensions": "off",
"node/no-missing-import": "off",
"no-underscore-dangle": "off",
"class-methods-use-this": "off",
"camelcase": "off",
"no-await-in-loop": "off",
"no-continue": "off",
"no-param-reassign": "off",
"prefer-destructuring": "off"
}
}

View File

@ -21,7 +21,8 @@ const tsconfig = {
noEmitOnError: false, noEmitOnError: false,
target: ts.ScriptTarget.ES2018, target: ts.ScriptTarget.ES2018,
module: ts.ModuleKind.ES2020, module: ts.ModuleKind.ES2020,
outFile: "dist/face-api.d.ts", // outFile: "dist/face-api.d.ts",
outDir: "types/",
declaration: true, declaration: true,
emitDeclarationOnly: true, emitDeclarationOnly: true,
emitDecoratorMetadata: true, emitDecoratorMetadata: true,

2109
dist/face-api.d.ts vendored

File diff suppressed because it is too large Load Diff

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because it is too large Load Diff

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

401
dist/face-api.esm.json vendored

File diff suppressed because it is too large Load Diff

4
dist/face-api.js vendored

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

400
dist/face-api.json vendored

File diff suppressed because it is too large Load Diff

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because it is too large Load Diff

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because it is too large Load Diff

4
dist/tfjs.esm.js vendored

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

13
dist/tfjs.esm.json vendored
View File

@ -20799,7 +20799,7 @@
] ]
}, },
"src/tfjs/tf-browser.js": { "src/tfjs/tf-browser.js": {
"bytes": 1784, "bytes": 1888,
"imports": [ "imports": [
{ {
"path": "node_modules/@tensorflow/tfjs/dist/index.js" "path": "node_modules/@tensorflow/tfjs/dist/index.js"
@ -20814,7 +20814,7 @@
"dist/tfjs.esm.js.map": { "dist/tfjs.esm.js.map": {
"imports": [], "imports": [],
"inputs": {}, "inputs": {},
"bytes": 1070617 "bytes": 1063970
}, },
"dist/tfjs.esm.js": { "dist/tfjs.esm.js": {
"imports": [], "imports": [],
@ -23195,7 +23195,7 @@
"bytesInOutput": 29890 "bytesInOutput": 29890
}, },
"node_modules/@tensorflow/tfjs-layers/dist/layers/convolutional_recurrent.js": { "node_modules/@tensorflow/tfjs-layers/dist/layers/convolutional_recurrent.js": {
"bytesInOutput": 9235 "bytesInOutput": 9196
}, },
"node_modules/@tensorflow/tfjs-layers/dist/layers/core.js": { "node_modules/@tensorflow/tfjs-layers/dist/layers/core.js": {
"bytesInOutput": 9961 "bytesInOutput": 9961
@ -24070,11 +24070,8 @@
"node_modules/@tensorflow/tfjs-backend-webgl/dist/version.js": { "node_modules/@tensorflow/tfjs-backend-webgl/dist/version.js": {
"bytesInOutput": 22 "bytesInOutput": 22
}, },
"node_modules/@tensorflow/tfjs-backend-webgl/dist/webgl.js": {
"bytesInOutput": 67
},
"node_modules/@tensorflow/tfjs-backend-webgl/dist/base.js": { "node_modules/@tensorflow/tfjs-backend-webgl/dist/base.js": {
"bytesInOutput": 113 "bytesInOutput": 85
}, },
"node_modules/@tensorflow/tfjs-backend-webgl/dist/index.js": { "node_modules/@tensorflow/tfjs-backend-webgl/dist/index.js": {
"bytesInOutput": 0 "bytesInOutput": 0
@ -25067,7 +25064,7 @@
"bytesInOutput": 0 "bytesInOutput": 0
} }
}, },
"bytes": 1572551 "bytes": 1572418
} }
} }
} }

1740
package-lock.json generated

File diff suppressed because it is too large Load Diff

View File

@ -5,13 +5,14 @@
"main": "dist/face-api.node.js", "main": "dist/face-api.node.js",
"module": "dist/face-api.esm.js", "module": "dist/face-api.esm.js",
"browser": "dist/face-api.esm.js", "browser": "dist/face-api.esm.js",
"typings": "dist/face-api.d.ts", "types": "types/index.d.ts",
"engines": { "engines": {
"node": ">=12.0.0" "node": ">=12.0.0"
}, },
"scripts": { "scripts": {
"start": "node --trace-warnings example/node.js", "start": "node --trace-warnings example/node-singleprocess.js",
"build": "rimraf dist/* && node ./build.js" "build": "rimraf dist/* && node ./build.js",
"lint": "eslint src/**/*"
}, },
"keywords": [ "keywords": [
"tensorflow", "tensorflow",
@ -44,7 +45,15 @@
"@tensorflow/tfjs-node": "^2.8.1", "@tensorflow/tfjs-node": "^2.8.1",
"@tensorflow/tfjs-node-gpu": "^2.8.1", "@tensorflow/tfjs-node-gpu": "^2.8.1",
"@types/node": "^14.14.14", "@types/node": "^14.14.14",
"esbuild": "^0.8.24", "@typescript-eslint/eslint-plugin": "^4.11.0",
"@typescript-eslint/parser": "^4.11.0",
"esbuild": "^0.8.26",
"eslint": "^7.16.0",
"eslint-config-airbnb-base": "^14.2.1",
"eslint-plugin-import": "^2.22.1",
"eslint-plugin-json": "^2.1.2",
"eslint-plugin-node": "^11.1.0",
"eslint-plugin-promise": "^4.2.1",
"rimraf": "^3.0.2", "rimraf": "^3.0.2",
"tslib": "^2.0.3", "tslib": "^2.0.3",
"typescript": "^4.1.3" "typescript": "^4.1.3"

View File

@ -5,120 +5,118 @@ import { seperateWeightMaps } from '../faceProcessor/util';
import { TinyXception } from '../xception/TinyXception'; import { TinyXception } from '../xception/TinyXception';
import { extractParams } from './extractParams'; import { extractParams } from './extractParams';
import { extractParamsFromWeigthMap } from './extractParamsFromWeigthMap'; import { extractParamsFromWeigthMap } from './extractParamsFromWeigthMap';
import { AgeAndGenderPrediction, Gender, NetOutput, NetParams } from './types'; import {
AgeAndGenderPrediction, Gender, NetOutput, NetParams,
} from './types';
import { NeuralNetwork } from '../NeuralNetwork'; import { NeuralNetwork } from '../NeuralNetwork';
import { NetInput, TNetInput, toNetInput } from '../dom/index'; import { NetInput, TNetInput, toNetInput } from '../dom/index';
export class AgeGenderNet extends NeuralNetwork<NetParams> { export class AgeGenderNet extends NeuralNetwork<NetParams> {
private _faceFeatureExtractor: TinyXception private _faceFeatureExtractor: TinyXception
constructor(faceFeatureExtractor: TinyXception = new TinyXception(2)) { constructor(faceFeatureExtractor: TinyXception = new TinyXception(2)) {
super('AgeGenderNet') super('AgeGenderNet');
this._faceFeatureExtractor = faceFeatureExtractor this._faceFeatureExtractor = faceFeatureExtractor;
} }
public get faceFeatureExtractor(): TinyXception { public get faceFeatureExtractor(): TinyXception {
return this._faceFeatureExtractor return this._faceFeatureExtractor;
} }
public runNet(input: NetInput | tf.Tensor4D): NetOutput { public runNet(input: NetInput | tf.Tensor4D): NetOutput {
const { params } = this;
const { params } = this
if (!params) { if (!params) {
throw new Error(`${this._name} - load model before inference`) throw new Error(`${this._name} - load model before inference`);
} }
return tf.tidy(() => { return tf.tidy(() => {
const bottleneckFeatures = input instanceof NetInput const bottleneckFeatures = input instanceof NetInput
? this.faceFeatureExtractor.forwardInput(input) ? this.faceFeatureExtractor.forwardInput(input)
: input : input;
const pooled = tf.avgPool(bottleneckFeatures, [7, 7], [2, 2], 'valid').as2D(bottleneckFeatures.shape[0], -1) const pooled = tf.avgPool(bottleneckFeatures, [7, 7], [2, 2], 'valid').as2D(bottleneckFeatures.shape[0], -1);
const age = fullyConnectedLayer(pooled, params.fc.age).as1D() const age = fullyConnectedLayer(pooled, params.fc.age).as1D();
const gender = fullyConnectedLayer(pooled, params.fc.gender) const gender = fullyConnectedLayer(pooled, params.fc.gender);
return { age, gender } return { age, gender };
}) });
} }
public forwardInput(input: NetInput | tf.Tensor4D): NetOutput { public forwardInput(input: NetInput | tf.Tensor4D): NetOutput {
return tf.tidy(() => { return tf.tidy(() => {
const { age, gender } = this.runNet(input) const { age, gender } = this.runNet(input);
return { age, gender: tf.softmax(gender) } return { age, gender: tf.softmax(gender) };
}) });
} }
public async forward(input: TNetInput): Promise<NetOutput> { public async forward(input: TNetInput): Promise<NetOutput> {
return this.forwardInput(await toNetInput(input)) return this.forwardInput(await toNetInput(input));
} }
public async predictAgeAndGender(input: TNetInput): Promise<AgeAndGenderPrediction | AgeAndGenderPrediction[]> { public async predictAgeAndGender(input: TNetInput): Promise<AgeAndGenderPrediction | AgeAndGenderPrediction[]> {
const netInput = await toNetInput(input) const netInput = await toNetInput(input);
const out = await this.forwardInput(netInput) const out = await this.forwardInput(netInput);
const ages = tf.unstack(out.age) const ages = tf.unstack(out.age);
const genders = tf.unstack(out.gender) const genders = tf.unstack(out.gender);
const ageAndGenderTensors = ages.map((ageTensor, i) => ({ const ageAndGenderTensors = ages.map((ageTensor, i) => ({
ageTensor, ageTensor,
genderTensor: genders[i] genderTensor: genders[i],
})) }));
const predictionsByBatch = await Promise.all( const predictionsByBatch = await Promise.all(
ageAndGenderTensors.map(async ({ ageTensor, genderTensor }) => { ageAndGenderTensors.map(async ({ ageTensor, genderTensor }) => {
const age = (await ageTensor.data())[0] const age = (await ageTensor.data())[0];
const probMale = (await genderTensor.data())[0] const probMale = (await genderTensor.data())[0];
const isMale = probMale > 0.5 const isMale = probMale > 0.5;
const gender = isMale ? Gender.MALE : Gender.FEMALE const gender = isMale ? Gender.MALE : Gender.FEMALE;
const genderProbability = isMale ? probMale : (1 - probMale) const genderProbability = isMale ? probMale : (1 - probMale);
ageTensor.dispose() ageTensor.dispose();
genderTensor.dispose() genderTensor.dispose();
return { age, gender, genderProbability } return { age, gender, genderProbability };
}) }),
) );
out.age.dispose() out.age.dispose();
out.gender.dispose() out.gender.dispose();
return netInput.isBatchInput ? predictionsByBatch as AgeAndGenderPrediction[] : predictionsByBatch[0] as AgeAndGenderPrediction return netInput.isBatchInput ? predictionsByBatch as AgeAndGenderPrediction[] : predictionsByBatch[0] as AgeAndGenderPrediction;
} }
protected getDefaultModelName(): string { protected getDefaultModelName(): string {
return 'age_gender_model' return 'age_gender_model';
} }
public dispose(throwOnRedispose: boolean = true) { public dispose(throwOnRedispose: boolean = true) {
this.faceFeatureExtractor.dispose(throwOnRedispose) this.faceFeatureExtractor.dispose(throwOnRedispose);
super.dispose(throwOnRedispose) super.dispose(throwOnRedispose);
} }
public loadClassifierParams(weights: Float32Array) { public loadClassifierParams(weights: Float32Array) {
const { params, paramMappings } = this.extractClassifierParams(weights) const { params, paramMappings } = this.extractClassifierParams(weights);
this._params = params this._params = params;
this._paramMappings = paramMappings this._paramMappings = paramMappings;
} }
public extractClassifierParams(weights: Float32Array) { public extractClassifierParams(weights: Float32Array) {
return extractParams(weights) return extractParams(weights);
} }
protected extractParamsFromWeigthMap(weightMap: tf.NamedTensorMap) { protected extractParamsFromWeigthMap(weightMap: tf.NamedTensorMap) {
const { featureExtractorMap, classifierMap } = seperateWeightMaps(weightMap);
const { featureExtractorMap, classifierMap } = seperateWeightMaps(weightMap) this.faceFeatureExtractor.loadFromWeightMap(featureExtractorMap);
this.faceFeatureExtractor.loadFromWeightMap(featureExtractorMap) return extractParamsFromWeigthMap(classifierMap);
return extractParamsFromWeigthMap(classifierMap)
} }
protected extractParams(weights: Float32Array) { protected extractParams(weights: Float32Array) {
const classifierWeightSize = (512 * 1 + 1) + (512 * 2 + 2);
const classifierWeightSize = (512 * 1 + 1) + (512 * 2 + 2) const featureExtractorWeights = weights.slice(0, weights.length - classifierWeightSize);
const classifierWeights = weights.slice(weights.length - classifierWeightSize);
const featureExtractorWeights = weights.slice(0, weights.length - classifierWeightSize) this.faceFeatureExtractor.extractWeights(featureExtractorWeights);
const classifierWeights = weights.slice(weights.length - classifierWeightSize) return this.extractClassifierParams(classifierWeights);
this.faceFeatureExtractor.extractWeights(featureExtractorWeights)
return this.extractClassifierParams(classifierWeights)
} }
} }

View File

@ -2,25 +2,24 @@ import { extractFCParamsFactory, extractWeightsFactory, ParamMapping } from '../
import { NetParams } from './types'; import { NetParams } from './types';
export function extractParams(weights: Float32Array): { params: NetParams, paramMappings: ParamMapping[] } { export function extractParams(weights: Float32Array): { params: NetParams, paramMappings: ParamMapping[] } {
const paramMappings: ParamMapping[] = [];
const paramMappings: ParamMapping[] = []
const { const {
extractWeights, extractWeights,
getRemainingWeights getRemainingWeights,
} = extractWeightsFactory(weights) } = extractWeightsFactory(weights);
const extractFCParams = extractFCParamsFactory(extractWeights, paramMappings) const extractFCParams = extractFCParamsFactory(extractWeights, paramMappings);
const age = extractFCParams(512, 1, 'fc/age') const age = extractFCParams(512, 1, 'fc/age');
const gender = extractFCParams(512, 2, 'fc/gender') const gender = extractFCParams(512, 2, 'fc/gender');
if (getRemainingWeights().length !== 0) { if (getRemainingWeights().length !== 0) {
throw new Error(`weights remaing after extract: ${getRemainingWeights().length}`) throw new Error(`weights remaing after extract: ${getRemainingWeights().length}`);
} }
return { return {
paramMappings, paramMappings,
params: { fc: { age, gender } } params: { fc: { age, gender } },
} };
} }

View File

@ -1,30 +1,31 @@
import * as tf from '../../dist/tfjs.esm.js'; import * as tf from '../../dist/tfjs.esm.js';
import { disposeUnusedWeightTensors, extractWeightEntryFactory, FCParams, ParamMapping } from '../common/index'; import {
disposeUnusedWeightTensors, extractWeightEntryFactory, FCParams, ParamMapping,
} from '../common/index';
import { NetParams } from './types'; import { NetParams } from './types';
export function extractParamsFromWeigthMap( export function extractParamsFromWeigthMap(
weightMap: tf.NamedTensorMap weightMap: tf.NamedTensorMap,
): { params: NetParams, paramMappings: ParamMapping[] } { ): { params: NetParams, paramMappings: ParamMapping[] } {
const paramMappings: ParamMapping[] = [];
const paramMappings: ParamMapping[] = [] const extractWeightEntry = extractWeightEntryFactory(weightMap, paramMappings);
const extractWeightEntry = extractWeightEntryFactory(weightMap, paramMappings)
function extractFcParams(prefix: string): FCParams { function extractFcParams(prefix: string): FCParams {
const weights = extractWeightEntry<tf.Tensor2D>(`${prefix}/weights`, 2) const weights = extractWeightEntry(`${prefix}/weights`, 2);
const bias = extractWeightEntry<tf.Tensor1D>(`${prefix}/bias`, 1) const bias = extractWeightEntry(`${prefix}/bias`, 1);
return { weights, bias } return { weights, bias };
} }
const params = { const params = {
fc: { fc: {
age: extractFcParams('fc/age'), age: extractFcParams('fc/age'),
gender: extractFcParams('fc/gender') gender: extractFcParams('fc/gender'),
} },
} };
disposeUnusedWeightTensors(weightMap, paramMappings) disposeUnusedWeightTensors(weightMap, paramMappings);
return { params, paramMappings } return { params, paramMappings };
} }

View File

@ -2,17 +2,20 @@ import * as tf from '../../dist/tfjs.esm.js';
import { FCParams } from '../common/index'; import { FCParams } from '../common/index';
// eslint-disable-next-line no-shadow
export enum Gender {
// eslint-disable-next-line no-unused-vars
FEMALE = 'female',
// eslint-disable-next-line no-unused-vars
MALE = 'male'
}
export type AgeAndGenderPrediction = { export type AgeAndGenderPrediction = {
age: number age: number
gender: Gender gender: Gender
genderProbability: number genderProbability: number
} }
export enum Gender {
FEMALE = 'female',
MALE = 'male'
}
export type NetOutput = { age: tf.Tensor1D, gender: tf.Tensor2D } export type NetOutput = { age: tf.Tensor1D, gender: tf.Tensor2D }
export type NetParams = { export type NetParams = {

View File

@ -9,6 +9,8 @@ export interface IBoundingBox {
export class BoundingBox extends Box<BoundingBox> implements IBoundingBox { export class BoundingBox extends Box<BoundingBox> implements IBoundingBox {
constructor(left: number, top: number, right: number, bottom: number, allowNegativeDimensions: boolean = false) { constructor(left: number, top: number, right: number, bottom: number, allowNegativeDimensions: boolean = false) {
super({ left, top, right, bottom }, allowNegativeDimensions) super({
left, top, right, bottom,
}, allowNegativeDimensions);
} }
} }

View File

@ -5,163 +5,197 @@ import { Point } from './Point';
import { IRect } from './Rect'; import { IRect } from './Rect';
export class Box<BoxType = any> implements IBoundingBox, IRect { export class Box<BoxType = any> implements IBoundingBox, IRect {
public static isRect(rect: any): boolean { public static isRect(rect: any): boolean {
return !!rect && [rect.x, rect.y, rect.width, rect.height].every(isValidNumber) return !!rect && [rect.x, rect.y, rect.width, rect.height].every(isValidNumber);
} }
public static assertIsValidBox(box: any, callee: string, allowNegativeDimensions: boolean = false) { public static assertIsValidBox(box: any, callee: string, allowNegativeDimensions: boolean = false) {
if (!Box.isRect(box)) { if (!Box.isRect(box)) {
throw new Error(`${callee} - invalid box: ${JSON.stringify(box)}, expected object with properties x, y, width, height`) throw new Error(`${callee} - invalid box: ${JSON.stringify(box)}, expected object with properties x, y, width, height`);
} }
if (!allowNegativeDimensions && (box.width < 0 || box.height < 0)) { if (!allowNegativeDimensions && (box.width < 0 || box.height < 0)) {
throw new Error(`${callee} - width (${box.width}) and height (${box.height}) must be positive numbers`) throw new Error(`${callee} - width (${box.width}) and height (${box.height}) must be positive numbers`);
} }
} }
private _x: number private _x: number
private _y: number private _y: number
private _width: number private _width: number
private _height: number private _height: number
constructor(_box: IBoundingBox | IRect, allowNegativeDimensions: boolean = true) { constructor(_box: IBoundingBox | IRect, allowNegativeDimensions: boolean = true) {
const box = (_box || {}) as any const box = (_box || {}) as any;
const isBbox = [box.left, box.top, box.right, box.bottom].every(isValidNumber) const isBbox = [box.left, box.top, box.right, box.bottom].every(isValidNumber);
const isRect = [box.x, box.y, box.width, box.height].every(isValidNumber) const isRect = [box.x, box.y, box.width, box.height].every(isValidNumber);
if (!isRect && !isBbox) { if (!isRect && !isBbox) {
throw new Error(`Box.constructor - expected box to be IBoundingBox | IRect, instead have ${JSON.stringify(box)}`) throw new Error(`Box.constructor - expected box to be IBoundingBox | IRect, instead have ${JSON.stringify(box)}`);
} }
const [x, y, width, height] = isRect const [x, y, width, height] = isRect
? [box.x, box.y, box.width, box.height] ? [box.x, box.y, box.width, box.height]
: [box.left, box.top, box.right - box.left, box.bottom - box.top] : [box.left, box.top, box.right - box.left, box.bottom - box.top];
Box.assertIsValidBox({ x, y, width, height }, 'Box.constructor', allowNegativeDimensions) Box.assertIsValidBox({
x, y, width, height,
}, 'Box.constructor', allowNegativeDimensions);
this._x = x this._x = x;
this._y = y this._y = y;
this._width = width this._width = width;
this._height = height this._height = height;
} }
public get x(): number { return this._x } public get x(): number { return this._x; }
public get y(): number { return this._y }
public get width(): number { return this._width } public get y(): number { return this._y; }
public get height(): number { return this._height }
public get left(): number { return this.x } public get width(): number { return this._width; }
public get top(): number { return this.y }
public get right(): number { return this.x + this.width } public get height(): number { return this._height; }
public get bottom(): number { return this.y + this.height }
public get area(): number { return this.width * this.height } public get left(): number { return this.x; }
public get topLeft(): Point { return new Point(this.left, this.top) }
public get topRight(): Point { return new Point(this.right, this.top) } public get top(): number { return this.y; }
public get bottomLeft(): Point { return new Point(this.left, this.bottom) }
public get bottomRight(): Point { return new Point(this.right, this.bottom) } public get right(): number { return this.x + this.width; }
public get bottom(): number { return this.y + this.height; }
public get area(): number { return this.width * this.height; }
public get topLeft(): Point { return new Point(this.left, this.top); }
public get topRight(): Point { return new Point(this.right, this.top); }
public get bottomLeft(): Point { return new Point(this.left, this.bottom); }
public get bottomRight(): Point { return new Point(this.right, this.bottom); }
public round(): Box<BoxType> { public round(): Box<BoxType> {
const [x, y, width, height] = [this.x, this.y, this.width, this.height] const [x, y, width, height] = [this.x, this.y, this.width, this.height]
.map(val => Math.round(val)) .map((val) => Math.round(val));
return new Box({ x, y, width, height }) return new Box({
x, y, width, height,
});
} }
public floor(): Box<BoxType> { public floor(): Box<BoxType> {
const [x, y, width, height] = [this.x, this.y, this.width, this.height] const [x, y, width, height] = [this.x, this.y, this.width, this.height]
.map(val => Math.floor(val)) .map((val) => Math.floor(val));
return new Box({ x, y, width, height }) return new Box({
x, y, width, height,
});
} }
public toSquare(): Box<BoxType> { public toSquare(): Box<BoxType> {
let { x, y, width, height } = this let {
const diff = Math.abs(width - height) x, y, width, height,
} = this;
const diff = Math.abs(width - height);
if (width < height) { if (width < height) {
x -= (diff / 2) x -= (diff / 2);
width += diff width += diff;
} }
if (height < width) { if (height < width) {
y -= (diff / 2) y -= (diff / 2);
height += diff height += diff;
} }
return new Box({ x, y, width, height }) return new Box({
x, y, width, height,
});
} }
public rescale(s: IDimensions | number): Box<BoxType> { public rescale(s: IDimensions | number): Box<BoxType> {
const scaleX = isDimensions(s) ? (s as IDimensions).width : s as number const scaleX = isDimensions(s) ? (s as IDimensions).width : s as number;
const scaleY = isDimensions(s) ? (s as IDimensions).height : s as number const scaleY = isDimensions(s) ? (s as IDimensions).height : s as number;
return new Box({ return new Box({
x: this.x * scaleX, x: this.x * scaleX,
y: this.y * scaleY, y: this.y * scaleY,
width: this.width * scaleX, width: this.width * scaleX,
height: this.height * scaleY height: this.height * scaleY,
}) });
} }
public pad(padX: number, padY: number): Box<BoxType> { public pad(padX: number, padY: number): Box<BoxType> {
let [x, y, width, height] = [ const [x, y, width, height] = [
this.x - (padX / 2), this.x - (padX / 2),
this.y - (padY / 2), this.y - (padY / 2),
this.width + padX, this.width + padX,
this.height + padY this.height + padY,
] ];
return new Box({ x, y, width, height }) return new Box({
x, y, width, height,
});
} }
public clipAtImageBorders(imgWidth: number, imgHeight: number): Box<BoxType> { public clipAtImageBorders(imgWidth: number, imgHeight: number): Box<BoxType> {
const { x, y, right, bottom } = this const {
const clippedX = Math.max(x, 0) x, y, right, bottom,
const clippedY = Math.max(y, 0) } = this;
const clippedX = Math.max(x, 0);
const clippedY = Math.max(y, 0);
const newWidth = right - clippedX const newWidth = right - clippedX;
const newHeight = bottom - clippedY const newHeight = bottom - clippedY;
const clippedWidth = Math.min(newWidth, imgWidth - clippedX) const clippedWidth = Math.min(newWidth, imgWidth - clippedX);
const clippedHeight = Math.min(newHeight, imgHeight - clippedY) const clippedHeight = Math.min(newHeight, imgHeight - clippedY);
return (new Box({ x: clippedX, y: clippedY, width: clippedWidth, height: clippedHeight})).floor() return (new Box({
x: clippedX, y: clippedY, width: clippedWidth, height: clippedHeight,
})).floor();
} }
public shift(sx: number, sy: number): Box<BoxType> { public shift(sx: number, sy: number): Box<BoxType> {
const { width, height } = this const { width, height } = this;
const x = this.x + sx const x = this.x + sx;
const y = this.y + sy const y = this.y + sy;
return new Box({ x, y, width, height }) return new Box({
x, y, width, height,
});
} }
public padAtBorders(imageHeight: number, imageWidth: number) { public padAtBorders(imageHeight: number, imageWidth: number) {
const w = this.width + 1 const w = this.width + 1;
const h = this.height + 1 const h = this.height + 1;
let dx = 1 const dx = 1;
let dy = 1 const dy = 1;
let edx = w let edx = w;
let edy = h let edy = h;
let x = this.left let x = this.left;
let y = this.top let y = this.top;
let ex = this.right let ex = this.right;
let ey = this.bottom let ey = this.bottom;
if (ex > imageWidth) { if (ex > imageWidth) {
edx = -ex + imageWidth + w edx = -ex + imageWidth + w;
ex = imageWidth ex = imageWidth;
} }
if (ey > imageHeight) { if (ey > imageHeight) {
edy = -ey + imageHeight + h edy = -ey + imageHeight + h;
ey = imageHeight ey = imageHeight;
} }
if (x < 1) { if (x < 1) {
edy = 2 - x edy = 2 - x;
x = 1 x = 1;
} }
if (y < 1) { if (y < 1) {
edy = 2 - y edy = 2 - y;
y = 1 y = 1;
} }
return { dy, edy, dx, edx, y, ey, x, ex, w, h } return {
dy, edy, dx, edx, y, ey, x, ex, w, h,
};
} }
public calibrate(region: Box) { public calibrate(region: Box) {
@ -169,7 +203,7 @@ export class Box<BoxType = any> implements IBoundingBox, IRect {
left: this.left + (region.left * this.width), left: this.left + (region.left * this.width),
top: this.top + (region.top * this.height), top: this.top + (region.top * this.height),
right: this.right + (region.right * this.width), right: this.right + (region.right * this.width),
bottom: this.bottom + (region.bottom * this.height) bottom: this.bottom + (region.bottom * this.height),
}).toSquare().round() }).toSquare().round();
} }
} }

View File

@ -6,23 +6,24 @@ export interface IDimensions {
} }
export class Dimensions implements IDimensions { export class Dimensions implements IDimensions {
private _width: number private _width: number
private _height: number private _height: number
constructor(width: number, height: number) { constructor(width: number, height: number) {
if (!isValidNumber(width) || !isValidNumber(height)) { if (!isValidNumber(width) || !isValidNumber(height)) {
throw new Error(`Dimensions.constructor - expected width and height to be valid numbers, instead have ${JSON.stringify({ width, height })}`) throw new Error(`Dimensions.constructor - expected width and height to be valid numbers, instead have ${JSON.stringify({ width, height })}`);
} }
this._width = width this._width = width;
this._height = height this._height = height;
} }
public get width(): number { return this._width } public get width(): number { return this._width; }
public get height(): number { return this._height }
public get height(): number { return this._height; }
public reverse(): Dimensions { public reverse(): Dimensions {
return new Dimensions(1 / this.width, 1 / this.height) return new Dimensions(1 / this.width, 1 / this.height);
} }
} }

View File

@ -12,13 +12,13 @@ export class FaceDetection extends ObjectDetection implements IFaceDetecion {
constructor( constructor(
score: number, score: number,
relativeBox: Rect, relativeBox: Rect,
imageDims: IDimensions imageDims: IDimensions,
) { ) {
super(score, score, '', relativeBox, imageDims) super(score, score, '', relativeBox, imageDims);
} }
public forSize(width: number, height: number): FaceDetection { public forSize(width: number, height: number): FaceDetection {
const { score, relativeBox, imageDims } = super.forSize(width, height) const { score, relativeBox, imageDims } = super.forSize(width, height);
return new FaceDetection(score, relativeBox, imageDims) return new FaceDetection(score, relativeBox, imageDims);
} }
} }

View File

@ -8,9 +8,9 @@ import { Point } from './Point';
import { IRect, Rect } from './Rect'; import { IRect, Rect } from './Rect';
// face alignment constants // face alignment constants
const relX = 0.5 const relX = 0.5;
const relY = 0.43 const relY = 0.43;
const relScale = 0.45 const relScale = 0.45;
export interface IFaceLandmarks { export interface IFaceLandmarks {
positions: Point[] positions: Point[]
@ -19,49 +19,55 @@ export interface IFaceLandmarks {
export class FaceLandmarks implements IFaceLandmarks { export class FaceLandmarks implements IFaceLandmarks {
protected _shift: Point protected _shift: Point
protected _positions: Point[] protected _positions: Point[]
protected _imgDims: Dimensions protected _imgDims: Dimensions
constructor( constructor(
relativeFaceLandmarkPositions: Point[], relativeFaceLandmarkPositions: Point[],
imgDims: IDimensions, imgDims: IDimensions,
shift: Point = new Point(0, 0) shift: Point = new Point(0, 0),
) { ) {
const { width, height } = imgDims const { width, height } = imgDims;
this._imgDims = new Dimensions(width, height) this._imgDims = new Dimensions(width, height);
this._shift = shift this._shift = shift;
this._positions = relativeFaceLandmarkPositions.map( this._positions = relativeFaceLandmarkPositions.map(
pt => pt.mul(new Point(width, height)).add(shift) (pt) => pt.mul(new Point(width, height)).add(shift),
) );
} }
public get shift(): Point { return new Point(this._shift.x, this._shift.y) } public get shift(): Point { return new Point(this._shift.x, this._shift.y); }
public get imageWidth(): number { return this._imgDims.width }
public get imageHeight(): number { return this._imgDims.height } public get imageWidth(): number { return this._imgDims.width; }
public get positions(): Point[] { return this._positions }
public get imageHeight(): number { return this._imgDims.height; }
public get positions(): Point[] { return this._positions; }
public get relativePositions(): Point[] { public get relativePositions(): Point[] {
return this._positions.map( return this._positions.map(
pt => pt.sub(this._shift).div(new Point(this.imageWidth, this.imageHeight)) (pt) => pt.sub(this._shift).div(new Point(this.imageWidth, this.imageHeight)),
) );
} }
public forSize<T extends FaceLandmarks>(width: number, height: number): T { public forSize<T extends FaceLandmarks>(width: number, height: number): T {
return new (this.constructor as any)( return new (this.constructor as any)(
this.relativePositions, this.relativePositions,
{ width, height } { width, height },
) );
} }
public shiftBy<T extends FaceLandmarks>(x: number, y: number): T { public shiftBy<T extends FaceLandmarks>(x: number, y: number): T {
return new (this.constructor as any)( return new (this.constructor as any)(
this.relativePositions, this.relativePositions,
this._imgDims, this._imgDims,
new Point(x, y) new Point(x, y),
) );
} }
public shiftByPoint<T extends FaceLandmarks>(pt: Point): T { public shiftByPoint<T extends FaceLandmarks>(pt: Point): T {
return this.shiftBy(pt.x, pt.y) return this.shiftBy(pt.x, pt.y);
} }
/** /**
@ -77,49 +83,48 @@ export class FaceLandmarks implements IFaceLandmarks {
*/ */
public align( public align(
detection?: FaceDetection | IRect | IBoundingBox | null, detection?: FaceDetection | IRect | IBoundingBox | null,
options: { useDlibAlignment?: boolean, minBoxPadding?: number } = { } options: { useDlibAlignment?: boolean, minBoxPadding?: number } = { },
): Box { ): Box {
if (detection) { if (detection) {
const box = detection instanceof FaceDetection const box = detection instanceof FaceDetection
? detection.box.floor() ? detection.box.floor()
: new Box(detection) : new Box(detection);
return this.shiftBy(box.x, box.y).align(null, options) return this.shiftBy(box.x, box.y).align(null, options);
} }
const { useDlibAlignment, minBoxPadding } = Object.assign({}, { useDlibAlignment: false, minBoxPadding: 0.2 }, options) const { useDlibAlignment, minBoxPadding } = { useDlibAlignment: false, minBoxPadding: 0.2, ...options };
if (useDlibAlignment) { if (useDlibAlignment) {
return this.alignDlib() return this.alignDlib();
} }
return this.alignMinBbox(minBoxPadding) return this.alignMinBbox(minBoxPadding);
} }
private alignDlib(): Box { private alignDlib(): Box {
const centers = this.getRefPointsForAlignment();
const centers = this.getRefPointsForAlignment() const [leftEyeCenter, rightEyeCenter, mouthCenter] = centers;
const distToMouth = (pt: Point) => mouthCenter.sub(pt).magnitude();
const eyeToMouthDist = (distToMouth(leftEyeCenter) + distToMouth(rightEyeCenter)) / 2;
const [leftEyeCenter, rightEyeCenter, mouthCenter] = centers const size = Math.floor(eyeToMouthDist / relScale);
const distToMouth = (pt: Point) => mouthCenter.sub(pt).magnitude()
const eyeToMouthDist = (distToMouth(leftEyeCenter) + distToMouth(rightEyeCenter)) / 2
const size = Math.floor(eyeToMouthDist / relScale) const refPoint = getCenterPoint(centers);
const refPoint = getCenterPoint(centers)
// TODO: pad in case rectangle is out of image bounds // TODO: pad in case rectangle is out of image bounds
const x = Math.floor(Math.max(0, refPoint.x - (relX * size))) const x = Math.floor(Math.max(0, refPoint.x - (relX * size)));
const y = Math.floor(Math.max(0, refPoint.y - (relY * size))) const y = Math.floor(Math.max(0, refPoint.y - (relY * size)));
return new Rect(x, y, Math.min(size, this.imageWidth + x), Math.min(size, this.imageHeight + y)) return new Rect(x, y, Math.min(size, this.imageWidth + x), Math.min(size, this.imageHeight + y));
} }
private alignMinBbox(padding: number): Box { private alignMinBbox(padding: number): Box {
const box = minBbox(this.positions) const box = minBbox(this.positions);
return box.pad(box.width * padding, box.height * padding) return box.pad(box.width * padding, box.height * padding);
} }
protected getRefPointsForAlignment(): Point[] { protected getRefPointsForAlignment(): Point[] {
throw new Error('getRefPointsForAlignment not implemented by base class') throw new Error('getRefPointsForAlignment not implemented by base class');
} }
} }

View File

@ -2,15 +2,13 @@ import { getCenterPoint } from '../utils/index';
import { FaceLandmarks } from './FaceLandmarks'; import { FaceLandmarks } from './FaceLandmarks';
import { Point } from './Point'; import { Point } from './Point';
export class FaceLandmarks5 extends FaceLandmarks { export class FaceLandmarks5 extends FaceLandmarks {
protected getRefPointsForAlignment(): Point[] { protected getRefPointsForAlignment(): Point[] {
const pts = this.positions const pts = this.positions;
return [ return [
pts[0], pts[0],
pts[1], pts[1],
getCenterPoint([pts[3], pts[4]]) getCenterPoint([pts[3], pts[4]]),
] ];
} }
} }

View File

@ -4,38 +4,38 @@ import { Point } from './Point';
export class FaceLandmarks68 extends FaceLandmarks { export class FaceLandmarks68 extends FaceLandmarks {
public getJawOutline(): Point[] { public getJawOutline(): Point[] {
return this.positions.slice(0, 17) return this.positions.slice(0, 17);
} }
public getLeftEyeBrow(): Point[] { public getLeftEyeBrow(): Point[] {
return this.positions.slice(17, 22) return this.positions.slice(17, 22);
} }
public getRightEyeBrow(): Point[] { public getRightEyeBrow(): Point[] {
return this.positions.slice(22, 27) return this.positions.slice(22, 27);
} }
public getNose(): Point[] { public getNose(): Point[] {
return this.positions.slice(27, 36) return this.positions.slice(27, 36);
} }
public getLeftEye(): Point[] { public getLeftEye(): Point[] {
return this.positions.slice(36, 42) return this.positions.slice(36, 42);
} }
public getRightEye(): Point[] { public getRightEye(): Point[] {
return this.positions.slice(42, 48) return this.positions.slice(42, 48);
} }
public getMouth(): Point[] { public getMouth(): Point[] {
return this.positions.slice(48, 68) return this.positions.slice(48, 68);
} }
protected getRefPointsForAlignment(): Point[] { protected getRefPointsForAlignment(): Point[] {
return [ return [
this.getLeftEye(), this.getLeftEye(),
this.getRightEye(), this.getRightEye(),
this.getMouth() this.getMouth(),
].map(getCenterPoint) ].map(getCenterPoint);
} }
} }

View File

@ -7,17 +7,19 @@ export interface IFaceMatch {
export class FaceMatch implements IFaceMatch { export class FaceMatch implements IFaceMatch {
private _label: string private _label: string
private _distance: number private _distance: number
constructor(label: string, distance: number) { constructor(label: string, distance: number) {
this._label = label this._label = label;
this._distance = distance this._distance = distance;
} }
public get label(): string { return this._label } public get label(): string { return this._label; }
public get distance(): number { return this._distance }
public get distance(): number { return this._distance; }
public toString(withDistance: boolean = true): string { public toString(withDistance: boolean = true): string {
return `${this.label}${withDistance ? ` (${round(this.distance)})` : ''}` return `${this.label}${withDistance ? ` (${round(this.distance)})` : ''}`;
} }
} }

View File

@ -4,22 +4,20 @@ import { Box } from './Box';
import { IRect } from './Rect'; import { IRect } from './Rect';
export class LabeledBox extends Box<LabeledBox> { export class LabeledBox extends Box<LabeledBox> {
public static assertIsValidLabeledBox(box: any, callee: string) { public static assertIsValidLabeledBox(box: any, callee: string) {
Box.assertIsValidBox(box, callee) Box.assertIsValidBox(box, callee);
if (!isValidNumber(box.label)) { if (!isValidNumber(box.label)) {
throw new Error(`${callee} - expected property label (${box.label}) to be a number`) throw new Error(`${callee} - expected property label (${box.label}) to be a number`);
} }
} }
private _label: number private _label: number
constructor(box: IBoundingBox | IRect | any, label: number) { constructor(box: IBoundingBox | IRect | any, label: number) {
super(box) super(box);
this._label = label this._label = label;
} }
public get label(): number { return this._label } public get label(): number { return this._label; }
} }

View File

@ -1,35 +1,34 @@
export class LabeledFaceDescriptors { export class LabeledFaceDescriptors {
private _label: string private _label: string
private _descriptors: Float32Array[] private _descriptors: Float32Array[]
constructor(label: string, descriptors: Float32Array[]) { constructor(label: string, descriptors: Float32Array[]) {
if (!(typeof label === 'string')) { if (!(typeof label === 'string')) {
throw new Error('LabeledFaceDescriptors - constructor expected label to be a string') throw new Error('LabeledFaceDescriptors - constructor expected label to be a string');
} }
if (!Array.isArray(descriptors) || descriptors.some(desc => !(desc instanceof Float32Array))) { if (!Array.isArray(descriptors) || descriptors.some((desc) => !(desc instanceof Float32Array))) {
throw new Error('LabeledFaceDescriptors - constructor expected descriptors to be an array of Float32Array') throw new Error('LabeledFaceDescriptors - constructor expected descriptors to be an array of Float32Array');
} }
this._label = label this._label = label;
this._descriptors = descriptors this._descriptors = descriptors;
} }
public get label(): string { return this._label } public get label(): string { return this._label; }
public get descriptors(): Float32Array[] { return this._descriptors }
public get descriptors(): Float32Array[] { return this._descriptors; }
public toJSON(): any { public toJSON(): any {
return { return {
label: this.label, label: this.label,
descriptors: this.descriptors.map((d) => Array.from(d)) descriptors: this.descriptors.map((d) => Array.from(d)),
}; };
} }
public static fromJSON(json: any): LabeledFaceDescriptors { public static fromJSON(json: any): LabeledFaceDescriptors {
const descriptors = json.descriptors.map((d: any) => { const descriptors = json.descriptors.map((d: any) => new Float32Array(d));
return new Float32Array(d);
});
return new LabeledFaceDescriptors(json.label, descriptors); return new LabeledFaceDescriptors(json.label, descriptors);
} }
} }

View File

@ -4,9 +4,13 @@ import { IRect, Rect } from './Rect';
export class ObjectDetection { export class ObjectDetection {
private _score: number private _score: number
private _classScore: number private _classScore: number
private _className: string private _className: string
private _box: Rect private _box: Rect
private _imageDims: Dimensions private _imageDims: Dimensions
constructor( constructor(
@ -14,23 +18,30 @@ export class ObjectDetection {
classScore: number, classScore: number,
className: string, className: string,
relativeBox: IRect, relativeBox: IRect,
imageDims: IDimensions imageDims: IDimensions,
) { ) {
this._imageDims = new Dimensions(imageDims.width, imageDims.height) this._imageDims = new Dimensions(imageDims.width, imageDims.height);
this._score = score this._score = score;
this._classScore = classScore this._classScore = classScore;
this._className = className this._className = className;
this._box = new Box(relativeBox).rescale(this._imageDims) this._box = new Box(relativeBox).rescale(this._imageDims);
} }
public get score(): number { return this._score } public get score(): number { return this._score; }
public get classScore(): number { return this._classScore }
public get className(): string { return this._className } public get classScore(): number { return this._classScore; }
public get box(): Box { return this._box }
public get imageDims(): Dimensions { return this._imageDims } public get className(): string { return this._className; }
public get imageWidth(): number { return this.imageDims.width }
public get imageHeight(): number { return this.imageDims.height } public get box(): Box { return this._box; }
public get relativeBox(): Box { return new Box(this._box).rescale(this.imageDims.reverse()) }
public get imageDims(): Dimensions { return this._imageDims; }
public get imageWidth(): number { return this.imageDims.width; }
public get imageHeight(): number { return this.imageDims.height; }
public get relativeBox(): Box { return new Box(this._box).rescale(this.imageDims.reverse()); }
public forSize(width: number, height: number): ObjectDetection { public forSize(width: number, height: number): ObjectDetection {
return new ObjectDetection( return new ObjectDetection(
@ -38,7 +49,7 @@ export class ObjectDetection {
this.classScore, this.classScore,
this.className, this.className,
this.relativeBox, this.relativeBox,
{ width, height} { width, height },
) );
} }
} }

View File

@ -5,41 +5,43 @@ export interface IPoint {
export class Point implements IPoint { export class Point implements IPoint {
private _x: number private _x: number
private _y: number private _y: number
constructor(x: number, y: number) { constructor(x: number, y: number) {
this._x = x this._x = x;
this._y = y this._y = y;
} }
get x(): number { return this._x } get x(): number { return this._x; }
get y(): number { return this._y }
get y(): number { return this._y; }
public add(pt: IPoint): Point { public add(pt: IPoint): Point {
return new Point(this.x + pt.x, this.y + pt.y) return new Point(this.x + pt.x, this.y + pt.y);
} }
public sub(pt: IPoint): Point { public sub(pt: IPoint): Point {
return new Point(this.x - pt.x, this.y - pt.y) return new Point(this.x - pt.x, this.y - pt.y);
} }
public mul(pt: IPoint): Point { public mul(pt: IPoint): Point {
return new Point(this.x * pt.x, this.y * pt.y) return new Point(this.x * pt.x, this.y * pt.y);
} }
public div(pt: IPoint): Point { public div(pt: IPoint): Point {
return new Point(this.x / pt.x, this.y / pt.y) return new Point(this.x / pt.x, this.y / pt.y);
} }
public abs(): Point { public abs(): Point {
return new Point(Math.abs(this.x), Math.abs(this.y)) return new Point(Math.abs(this.x), Math.abs(this.y));
} }
public magnitude(): number { public magnitude(): number {
return Math.sqrt(Math.pow(this.x, 2) + Math.pow(this.y, 2)) return Math.sqrt((this.x ** 2) + (this.y ** 2));
} }
public floor(): Point { public floor(): Point {
return new Point(Math.floor(this.x), Math.floor(this.y)) return new Point(Math.floor(this.x), Math.floor(this.y));
} }
} }

View File

@ -4,28 +4,28 @@ import { LabeledBox } from './LabeledBox';
import { IRect } from './Rect'; import { IRect } from './Rect';
export class PredictedBox extends LabeledBox { export class PredictedBox extends LabeledBox {
public static assertIsValidPredictedBox(box: any, callee: string) { public static assertIsValidPredictedBox(box: any, callee: string) {
LabeledBox.assertIsValidLabeledBox(box, callee) LabeledBox.assertIsValidLabeledBox(box, callee);
if ( if (
!isValidProbablitiy(box.score) !isValidProbablitiy(box.score)
|| !isValidProbablitiy(box.classScore) || !isValidProbablitiy(box.classScore)
) { ) {
throw new Error(`${callee} - expected properties score (${box.score}) and (${box.classScore}) to be a number between [0, 1]`) throw new Error(`${callee} - expected properties score (${box.score}) and (${box.classScore}) to be a number between [0, 1]`);
} }
} }
private _score: number private _score: number
private _classScore: number private _classScore: number
constructor(box: IBoundingBox | IRect | any, label: number, score: number, classScore: number) { constructor(box: IBoundingBox | IRect | any, label: number, score: number, classScore: number) {
super(box, label) super(box, label);
this._score = score this._score = score;
this._classScore = classScore this._classScore = classScore;
} }
public get score(): number { return this._score } public get score(): number { return this._score; }
public get classScore(): number { return this._classScore }
public get classScore(): number { return this._classScore; }
} }

View File

@ -9,6 +9,8 @@ export interface IRect {
export class Rect extends Box<Rect> implements IRect { export class Rect extends Box<Rect> implements IRect {
constructor(x: number, y: number, width: number, height: number, allowNegativeDimensions: boolean = false) { constructor(x: number, y: number, width: number, height: number, allowNegativeDimensions: boolean = false) {
super({ x, y, width, height }, allowNegativeDimensions) super({
x, y, width, height,
}, allowNegativeDimensions);
} }
} }

View File

@ -1,14 +1,14 @@
export * from './BoundingBox' export * from './BoundingBox';
export * from './Box' export * from './Box';
export * from './Dimensions' export * from './Dimensions';
export * from './FaceDetection'; export * from './FaceDetection';
export * from './FaceLandmarks'; export * from './FaceLandmarks';
export * from './FaceLandmarks5'; export * from './FaceLandmarks5';
export * from './FaceLandmarks68'; export * from './FaceLandmarks68';
export * from './FaceMatch'; export * from './FaceMatch';
export * from './LabeledBox' export * from './LabeledBox';
export * from './LabeledFaceDescriptors'; export * from './LabeledFaceDescriptors';
export * from './ObjectDetection' export * from './ObjectDetection';
export * from './Point' export * from './Point';
export * from './PredictedBox' export * from './PredictedBox';
export * from './Rect' export * from './Rect';

View File

@ -6,14 +6,14 @@ export function convLayer(
x: tf.Tensor4D, x: tf.Tensor4D,
params: ConvParams, params: ConvParams,
padding: 'valid' | 'same' = 'same', padding: 'valid' | 'same' = 'same',
withRelu: boolean = false withRelu: boolean = false,
): tf.Tensor4D { ): tf.Tensor4D {
return tf.tidy(() => { return tf.tidy(() => {
const out = tf.add( const out = tf.add(
tf.conv2d(x, params.filters, [1, 1], padding), tf.conv2d(x, params.filters, [1, 1], padding),
params.bias params.bias,
) as tf.Tensor4D ) as tf.Tensor4D;
return withRelu ? tf.relu(out) : out return withRelu ? tf.relu(out) : out;
}) });
} }

View File

@ -5,11 +5,11 @@ import { SeparableConvParams } from './types';
export function depthwiseSeparableConv( export function depthwiseSeparableConv(
x: tf.Tensor4D, x: tf.Tensor4D,
params: SeparableConvParams, params: SeparableConvParams,
stride: [number, number] stride: [number, number],
): tf.Tensor4D { ): tf.Tensor4D {
return tf.tidy(() => { return tf.tidy(() => {
let out = tf.separableConv2d(x, params.depthwise_filter, params.pointwise_filter, stride, 'same') let out = tf.separableConv2d(x, params.depthwise_filter, params.pointwise_filter, stride, 'same');
out = tf.add(out, params.bias) out = tf.add(out, params.bias);
return out return out;
}) });
} }

View File

@ -1,9 +1,9 @@
import { ParamMapping } from './types'; import { ParamMapping } from './types';
export function disposeUnusedWeightTensors(weightMap: any, paramMappings: ParamMapping[]) { export function disposeUnusedWeightTensors(weightMap: any, paramMappings: ParamMapping[]) {
Object.keys(weightMap).forEach(path => { Object.keys(weightMap).forEach((path) => {
if (!paramMappings.some(pm => pm.originalPath === path)) { if (!paramMappings.some((pm) => pm.originalPath === path)) {
weightMap[path].dispose() weightMap[path].dispose();
} }
}) });
} }

View File

@ -4,28 +4,25 @@ import { ConvParams, ExtractWeightsFunction, ParamMapping } from './types';
export function extractConvParamsFactory( export function extractConvParamsFactory(
extractWeights: ExtractWeightsFunction, extractWeights: ExtractWeightsFunction,
paramMappings: ParamMapping[] paramMappings: ParamMapping[],
) { ) {
return (
return function(
channelsIn: number, channelsIn: number,
channelsOut: number, channelsOut: number,
filterSize: number, filterSize: number,
mappedPrefix: string mappedPrefix: string,
): ConvParams { ): ConvParams => {
const filters = tf.tensor4d( const filters = tf.tensor4d(
extractWeights(channelsIn * channelsOut * filterSize * filterSize), extractWeights(channelsIn * channelsOut * filterSize * filterSize),
[filterSize, filterSize, channelsIn, channelsOut] [filterSize, filterSize, channelsIn, channelsOut],
) );
const bias = tf.tensor1d(extractWeights(channelsOut)) const bias = tf.tensor1d(extractWeights(channelsOut));
paramMappings.push( paramMappings.push(
{ paramPath: `${mappedPrefix}/filters` }, { paramPath: `${mappedPrefix}/filters` },
{ paramPath: `${mappedPrefix}/bias` } { paramPath: `${mappedPrefix}/bias` },
) );
return { filters, bias }
}
return { filters, bias };
};
} }

View File

@ -2,30 +2,26 @@ import * as tf from '../../dist/tfjs.esm.js';
import { ExtractWeightsFunction, FCParams, ParamMapping } from './types'; import { ExtractWeightsFunction, FCParams, ParamMapping } from './types';
export function extractFCParamsFactory( export function extractFCParamsFactory(
extractWeights: ExtractWeightsFunction, extractWeights: ExtractWeightsFunction,
paramMappings: ParamMapping[] paramMappings: ParamMapping[],
) { ) {
return (
return function(
channelsIn: number, channelsIn: number,
channelsOut: number, channelsOut: number,
mappedPrefix: string mappedPrefix: string,
): FCParams { ): FCParams => {
const fc_weights = tf.tensor2d(extractWeights(channelsIn * channelsOut), [channelsIn, channelsOut]);
const fc_weights = tf.tensor2d(extractWeights(channelsIn * channelsOut), [channelsIn, channelsOut]) const fc_bias = tf.tensor1d(extractWeights(channelsOut));
const fc_bias = tf.tensor1d(extractWeights(channelsOut))
paramMappings.push( paramMappings.push(
{ paramPath: `${mappedPrefix}/weights` }, { paramPath: `${mappedPrefix}/weights` },
{ paramPath: `${mappedPrefix}/bias` } { paramPath: `${mappedPrefix}/bias` },
) );
return { return {
weights: fc_weights, weights: fc_weights,
bias: fc_bias bias: fc_bias,
} };
} };
} }

View File

@ -4,43 +4,40 @@ import { ExtractWeightsFunction, ParamMapping, SeparableConvParams } from './typ
export function extractSeparableConvParamsFactory( export function extractSeparableConvParamsFactory(
extractWeights: ExtractWeightsFunction, extractWeights: ExtractWeightsFunction,
paramMappings: ParamMapping[] paramMappings: ParamMapping[],
) { ) {
return (channelsIn: number, channelsOut: number, mappedPrefix: string): SeparableConvParams => {
return function(channelsIn: number, channelsOut: number, mappedPrefix: string): SeparableConvParams { const depthwise_filter = tf.tensor4d(extractWeights(3 * 3 * channelsIn), [3, 3, channelsIn, 1]);
const depthwise_filter = tf.tensor4d(extractWeights(3 * 3 * channelsIn), [3, 3, channelsIn, 1]) const pointwise_filter = tf.tensor4d(extractWeights(channelsIn * channelsOut), [1, 1, channelsIn, channelsOut]);
const pointwise_filter = tf.tensor4d(extractWeights(channelsIn * channelsOut), [1, 1, channelsIn, channelsOut]) const bias = tf.tensor1d(extractWeights(channelsOut));
const bias = tf.tensor1d(extractWeights(channelsOut))
paramMappings.push( paramMappings.push(
{ paramPath: `${mappedPrefix}/depthwise_filter` }, { paramPath: `${mappedPrefix}/depthwise_filter` },
{ paramPath: `${mappedPrefix}/pointwise_filter` }, { paramPath: `${mappedPrefix}/pointwise_filter` },
{ paramPath: `${mappedPrefix}/bias` } { paramPath: `${mappedPrefix}/bias` },
) );
return new SeparableConvParams( return new SeparableConvParams(
depthwise_filter, depthwise_filter,
pointwise_filter, pointwise_filter,
bias bias,
) );
} };
} }
export function loadSeparableConvParamsFactory( export function loadSeparableConvParamsFactory(
extractWeightEntry: <T>(originalPath: string, paramRank: number) => T // eslint-disable-next-line no-unused-vars
extractWeightEntry: <T>(originalPath: string, paramRank: number) => T,
) { ) {
return (prefix: string): SeparableConvParams => {
return function (prefix: string): SeparableConvParams { const depthwise_filter = extractWeightEntry<tf.Tensor4D>(`${prefix}/depthwise_filter`, 4);
const depthwise_filter = extractWeightEntry<tf.Tensor4D>(`${prefix}/depthwise_filter`, 4) const pointwise_filter = extractWeightEntry<tf.Tensor4D>(`${prefix}/pointwise_filter`, 4);
const pointwise_filter = extractWeightEntry<tf.Tensor4D>(`${prefix}/pointwise_filter`, 4) const bias = extractWeightEntry<tf.Tensor1D>(`${prefix}/bias`, 1);
const bias = extractWeightEntry<tf.Tensor1D>(`${prefix}/bias`, 1)
return new SeparableConvParams( return new SeparableConvParams(
depthwise_filter, depthwise_filter,
pointwise_filter, pointwise_filter,
bias bias,
) );
} };
} }

View File

@ -2,19 +2,17 @@ import { isTensor } from '../utils/index';
import { ParamMapping } from './types'; import { ParamMapping } from './types';
export function extractWeightEntryFactory(weightMap: any, paramMappings: ParamMapping[]) { export function extractWeightEntryFactory(weightMap: any, paramMappings: ParamMapping[]) {
return (originalPath: string, paramRank: number, mappedPath?: string) => {
return function<T> (originalPath: string, paramRank: number, mappedPath?: string): T { const tensor = weightMap[originalPath];
const tensor = weightMap[originalPath]
if (!isTensor(tensor, paramRank)) { if (!isTensor(tensor, paramRank)) {
throw new Error(`expected weightMap[${originalPath}] to be a Tensor${paramRank}D, instead have ${tensor}`) throw new Error(`expected weightMap[${originalPath}] to be a Tensor${paramRank}D, instead have ${tensor}`);
} }
paramMappings.push( paramMappings.push(
{ originalPath, paramPath: mappedPath || originalPath } { originalPath, paramPath: mappedPath || originalPath },
) );
return tensor
}
return tensor;
};
} }

View File

@ -1,18 +1,18 @@
export function extractWeightsFactory(weights: Float32Array) { export function extractWeightsFactory(weights: Float32Array) {
let remainingWeights = weights let remainingWeights = weights;
function extractWeights(numWeights: number): Float32Array { function extractWeights(numWeights: number): Float32Array {
const ret = remainingWeights.slice(0, numWeights) const ret = remainingWeights.slice(0, numWeights);
remainingWeights = remainingWeights.slice(numWeights) remainingWeights = remainingWeights.slice(numWeights);
return ret return ret;
} }
function getRemainingWeights(): Float32Array { function getRemainingWeights(): Float32Array {
return remainingWeights return remainingWeights;
} }
return { return {
extractWeights, extractWeights,
getRemainingWeights getRemainingWeights,
} };
} }

View File

@ -4,12 +4,10 @@ import { FCParams } from './types';
export function fullyConnectedLayer( export function fullyConnectedLayer(
x: tf.Tensor2D, x: tf.Tensor2D,
params: FCParams params: FCParams,
): tf.Tensor2D { ): tf.Tensor2D {
return tf.tidy(() => return tf.tidy(() => tf.add(
tf.add( tf.matMul(x, params.weights),
tf.matMul(x, params.weights), params.bias,
params.bias ));
)
)
} }

View File

@ -1,33 +1,34 @@
export function getModelUris(uri: string | undefined, defaultModelName: string) { export function getModelUris(uri: string | undefined, defaultModelName: string) {
const defaultManifestFilename = `${defaultModelName}-weights_manifest.json` const defaultManifestFilename = `${defaultModelName}-weights_manifest.json`;
if (!uri) { if (!uri) {
return { return {
modelBaseUri: '', modelBaseUri: '',
manifestUri: defaultManifestFilename manifestUri: defaultManifestFilename,
} };
} }
if (uri === '/') { if (uri === '/') {
return { return {
modelBaseUri: '/', modelBaseUri: '/',
manifestUri: `/${defaultManifestFilename}` manifestUri: `/${defaultManifestFilename}`,
} };
} }
// eslint-disable-next-line no-nested-ternary
const protocol = uri.startsWith('http://') ? 'http://' : uri.startsWith('https://') ? 'https://' : ''; const protocol = uri.startsWith('http://') ? 'http://' : uri.startsWith('https://') ? 'https://' : '';
uri = uri.replace(protocol, ''); uri = uri.replace(protocol, '');
const parts = uri.split('/').filter(s => s) const parts = uri.split('/').filter((s) => s);
const manifestFile = uri.endsWith('.json') const manifestFile = uri.endsWith('.json')
? parts[parts.length - 1] ? parts[parts.length - 1]
: defaultManifestFilename : defaultManifestFilename;
let modelBaseUri = protocol + (uri.endsWith('.json') ? parts.slice(0, parts.length - 1) : parts).join('/') let modelBaseUri = protocol + (uri.endsWith('.json') ? parts.slice(0, parts.length - 1) : parts).join('/');
modelBaseUri = uri.startsWith('/') ? `/${modelBaseUri}` : modelBaseUri modelBaseUri = uri.startsWith('/') ? `/${modelBaseUri}` : modelBaseUri;
return { return {
modelBaseUri, modelBaseUri,
manifestUri: modelBaseUri === '/' ? `/${manifestFile}` : `${modelBaseUri}/${manifestFile}` manifestUri: modelBaseUri === '/' ? `/${manifestFile}` : `${modelBaseUri}/${manifestFile}`,
} };
} }

View File

@ -1,10 +1,10 @@
export * from './convLayer' export * from './convLayer';
export * from './depthwiseSeparableConv' export * from './depthwiseSeparableConv';
export * from './disposeUnusedWeightTensors' export * from './disposeUnusedWeightTensors';
export * from './extractConvParamsFactory' export * from './extractConvParamsFactory';
export * from './extractFCParamsFactory' export * from './extractFCParamsFactory';
export * from './extractSeparableConvParamsFactory' export * from './extractSeparableConvParamsFactory';
export * from './extractWeightEntryFactory' export * from './extractWeightEntryFactory';
export * from './extractWeightsFactory' export * from './extractWeightsFactory';
export * from './getModelUris' export * from './getModelUris';
export * from './types' export * from './types';

View File

@ -2,11 +2,12 @@ import * as tf from '../../dist/tfjs.esm.js';
import { ConvParams } from './types'; import { ConvParams } from './types';
// eslint-disable-next-line no-unused-vars
export function loadConvParamsFactory(extractWeightEntry: <T>(originalPath: string, paramRank: number) => T) { export function loadConvParamsFactory(extractWeightEntry: <T>(originalPath: string, paramRank: number) => T) {
return function(prefix: string): ConvParams { return (prefix: string): ConvParams => {
const filters = extractWeightEntry<tf.Tensor4D>(`${prefix}/filters`, 4) const filters = extractWeightEntry<tf.Tensor4D>(`${prefix}/filters`, 4);
const bias = extractWeightEntry<tf.Tensor1D>(`${prefix}/bias`, 1) const bias = extractWeightEntry<tf.Tensor1D>(`${prefix}/bias`, 1);
return { filters, bias } return { filters, bias };
} };
} }

View File

@ -1,5 +1,6 @@
import * as tf from '../../dist/tfjs.esm.js'; import * as tf from '../../dist/tfjs.esm.js';
// eslint-disable-next-line no-unused-vars
export type ExtractWeightsFunction = (numWeights: number) => Float32Array export type ExtractWeightsFunction = (numWeights: number) => Float32Array
export type ParamMapping = { export type ParamMapping = {
@ -18,9 +19,14 @@ export type FCParams = {
} }
export class SeparableConvParams { export class SeparableConvParams {
// eslint-disable-next-line no-useless-constructor
constructor( constructor(
// eslint-disable-next-line no-unused-vars
public depthwise_filter: tf.Tensor4D, public depthwise_filter: tf.Tensor4D,
// eslint-disable-next-line no-unused-vars
public pointwise_filter: tf.Tensor4D, public pointwise_filter: tf.Tensor4D,
public bias: tf.Tensor1D // eslint-disable-next-line no-unused-vars
public bias: tf.Tensor1D,
// eslint-disable-next-line no-empty-function
) {} ) {}
} }

View File

@ -3,110 +3,115 @@ import * as tf from '../../dist/tfjs.esm.js';
import { Dimensions } from '../classes/Dimensions'; import { Dimensions } from '../classes/Dimensions';
import { env } from '../env/index'; import { env } from '../env/index';
import { padToSquare } from '../ops/padToSquare'; import { padToSquare } from '../ops/padToSquare';
import { computeReshapedDimensions, isTensor3D, isTensor4D, range } from '../utils/index'; import {
computeReshapedDimensions, isTensor3D, isTensor4D, range,
} from '../utils/index';
import { createCanvasFromMedia } from './createCanvas'; import { createCanvasFromMedia } from './createCanvas';
import { imageToSquare } from './imageToSquare'; import { imageToSquare } from './imageToSquare';
import { TResolvedNetInput } from './types'; import { TResolvedNetInput } from './types';
export class NetInput { export class NetInput {
private _imageTensors: Array<tf.Tensor3D | tf.Tensor4D> = [] private _imageTensors: Array<tf.Tensor3D | tf.Tensor4D> = []
private _canvases: HTMLCanvasElement[] = [] private _canvases: HTMLCanvasElement[] = []
private _batchSize: number private _batchSize: number
private _treatAsBatchInput: boolean = false private _treatAsBatchInput: boolean = false
private _inputDimensions: number[][] = [] private _inputDimensions: number[][] = []
private _inputSize: number private _inputSize: number
constructor( constructor(
inputs: Array<TResolvedNetInput>, inputs: Array<TResolvedNetInput>,
treatAsBatchInput: boolean = false treatAsBatchInput: boolean = false,
) { ) {
if (!Array.isArray(inputs)) { if (!Array.isArray(inputs)) {
throw new Error(`NetInput.constructor - expected inputs to be an Array of TResolvedNetInput or to be instanceof tf.Tensor4D, instead have ${inputs}`) throw new Error(`NetInput.constructor - expected inputs to be an Array of TResolvedNetInput or to be instanceof tf.Tensor4D, instead have ${inputs}`);
} }
this._treatAsBatchInput = treatAsBatchInput this._treatAsBatchInput = treatAsBatchInput;
this._batchSize = inputs.length this._batchSize = inputs.length;
inputs.forEach((input, idx) => { inputs.forEach((input, idx) => {
if (isTensor3D(input)) { if (isTensor3D(input)) {
this._imageTensors[idx] = input this._imageTensors[idx] = input;
this._inputDimensions[idx] = input.shape this._inputDimensions[idx] = input.shape;
return return;
} }
if (isTensor4D(input)) { if (isTensor4D(input)) {
const batchSize = (input as any).shape[0] const batchSize = (input as any).shape[0];
if (batchSize !== 1) { if (batchSize !== 1) {
throw new Error(`NetInput - tf.Tensor4D with batchSize ${batchSize} passed, but not supported in input array`) throw new Error(`NetInput - tf.Tensor4D with batchSize ${batchSize} passed, but not supported in input array`);
} }
this._imageTensors[idx] = input this._imageTensors[idx] = input;
this._inputDimensions[idx] = (input as any).shape.slice(1) this._inputDimensions[idx] = (input as any).shape.slice(1);
return return;
} }
const canvas = (input as any) instanceof env.getEnv().Canvas ? input : createCanvasFromMedia(input) const canvas = (input as any) instanceof env.getEnv().Canvas ? input : createCanvasFromMedia(input);
this._canvases[idx] = canvas this._canvases[idx] = canvas;
this._inputDimensions[idx] = [canvas.height, canvas.width, 3] this._inputDimensions[idx] = [canvas.height, canvas.width, 3];
}) });
} }
public get imageTensors(): Array<tf.Tensor3D | tf.Tensor4D> { public get imageTensors(): Array<tf.Tensor3D | tf.Tensor4D> {
return this._imageTensors return this._imageTensors;
} }
public get canvases(): HTMLCanvasElement[] { public get canvases(): HTMLCanvasElement[] {
return this._canvases return this._canvases;
} }
public get isBatchInput(): boolean { public get isBatchInput(): boolean {
return this.batchSize > 1 || this._treatAsBatchInput return this.batchSize > 1 || this._treatAsBatchInput;
} }
public get batchSize(): number { public get batchSize(): number {
return this._batchSize return this._batchSize;
} }
public get inputDimensions(): number[][] { public get inputDimensions(): number[][] {
return this._inputDimensions return this._inputDimensions;
} }
public get inputSize(): number | undefined { public get inputSize(): number | undefined {
return this._inputSize return this._inputSize;
} }
public get reshapedInputDimensions(): Dimensions[] { public get reshapedInputDimensions(): Dimensions[] {
return range(this.batchSize, 0, 1).map( return range(this.batchSize, 0, 1).map(
(_, batchIdx) => this.getReshapedInputDimensions(batchIdx) (_, batchIdx) => this.getReshapedInputDimensions(batchIdx),
) );
} }
public getInput(batchIdx: number): tf.Tensor3D | tf.Tensor4D | HTMLCanvasElement { public getInput(batchIdx: number): tf.Tensor3D | tf.Tensor4D | HTMLCanvasElement {
return this.canvases[batchIdx] || this.imageTensors[batchIdx] return this.canvases[batchIdx] || this.imageTensors[batchIdx];
} }
public getInputDimensions(batchIdx: number): number[] { public getInputDimensions(batchIdx: number): number[] {
return this._inputDimensions[batchIdx] return this._inputDimensions[batchIdx];
} }
public getInputHeight(batchIdx: number): number { public getInputHeight(batchIdx: number): number {
return this._inputDimensions[batchIdx][0] return this._inputDimensions[batchIdx][0];
} }
public getInputWidth(batchIdx: number): number { public getInputWidth(batchIdx: number): number {
return this._inputDimensions[batchIdx][1] return this._inputDimensions[batchIdx][1];
} }
public getReshapedInputDimensions(batchIdx: number): Dimensions { public getReshapedInputDimensions(batchIdx: number): Dimensions {
if (typeof this.inputSize !== 'number') { if (typeof this.inputSize !== 'number') {
throw new Error('getReshapedInputDimensions - inputSize not set, toBatchTensor has not been called yet') throw new Error('getReshapedInputDimensions - inputSize not set, toBatchTensor has not been called yet');
} }
const width = this.getInputWidth(batchIdx) const width = this.getInputWidth(batchIdx);
const height = this.getInputHeight(batchIdx) const height = this.getInputHeight(batchIdx);
return computeReshapedDimensions({ width, height }, this.inputSize) return computeReshapedDimensions({ width, height }, this.inputSize);
} }
/** /**
@ -119,39 +124,37 @@ export class NetInput {
* @returns The batch tensor. * @returns The batch tensor.
*/ */
public toBatchTensor(inputSize: number, isCenterInputs: boolean = true): tf.Tensor4D { public toBatchTensor(inputSize: number, isCenterInputs: boolean = true): tf.Tensor4D {
this._inputSize = inputSize;
this._inputSize = inputSize
return tf.tidy(() => { return tf.tidy(() => {
const inputTensors = range(this.batchSize, 0, 1).map((batchIdx) => {
const inputTensors = range(this.batchSize, 0, 1).map(batchIdx => { const input = this.getInput(batchIdx);
const input = this.getInput(batchIdx)
if (input instanceof tf.Tensor) { if (input instanceof tf.Tensor) {
// @ts-ignore: error TS2344: Type 'Rank.R4' does not satisfy the constraint 'Tensor<Rank>'. // @ts-ignore: error TS2344: Type 'Rank.R4' does not satisfy the constraint 'Tensor<Rank>'.
let imgTensor = isTensor4D(input) ? input : input.expandDims<tf.Rank.R4>() let imgTensor = isTensor4D(input) ? input : input.expandDims<tf.Rank.R4>();
// @ts-ignore: error TS2344: Type 'Rank.R4' does not satisfy the constraint 'Tensor<Rank>'. // @ts-ignore: error TS2344: Type 'Rank.R4' does not satisfy the constraint 'Tensor<Rank>'.
imgTensor = padToSquare(imgTensor, isCenterInputs) imgTensor = padToSquare(imgTensor, isCenterInputs);
if (imgTensor.shape[1] !== inputSize || imgTensor.shape[2] !== inputSize) { if (imgTensor.shape[1] !== inputSize || imgTensor.shape[2] !== inputSize) {
imgTensor = tf.image.resizeBilinear(imgTensor, [inputSize, inputSize]) imgTensor = tf.image.resizeBilinear(imgTensor, [inputSize, inputSize]);
} }
return imgTensor.as3D(inputSize, inputSize, 3) return imgTensor.as3D(inputSize, inputSize, 3);
} }
if (input instanceof env.getEnv().Canvas) { if (input instanceof env.getEnv().Canvas) {
return tf.browser.fromPixels(imageToSquare(input, inputSize, isCenterInputs)) return tf.browser.fromPixels(imageToSquare(input, inputSize, isCenterInputs));
} }
throw new Error(`toBatchTensor - at batchIdx ${batchIdx}, expected input to be instanceof tf.Tensor or instanceof HTMLCanvasElement, instead have ${input}`) throw new Error(`toBatchTensor - at batchIdx ${batchIdx}, expected input to be instanceof tf.Tensor or instanceof HTMLCanvasElement, instead have ${input}`);
}) });
// const batchTensor = tf.stack(inputTensors.map(t => t.toFloat())).as4D(this.batchSize, inputSize, inputSize, 3) // const batchTensor = tf.stack(inputTensors.map(t => t.toFloat())).as4D(this.batchSize, inputSize, inputSize, 3)
const batchTensor = tf.stack(inputTensors.map(t => tf.cast(t, 'float32'))).as4D(this.batchSize, inputSize, inputSize, 3) const batchTensor = tf.stack(inputTensors.map((t) => tf.cast(t, 'float32'))).as4D(this.batchSize, inputSize, inputSize, 3);
// const batchTensor = tf.stack(inputTensors.map(t => tf.Tensor.as4D(tf.cast(t, 'float32'))), this.batchSize, inputSize, inputSize, 3); // const batchTensor = tf.stack(inputTensors.map(t => tf.Tensor.as4D(tf.cast(t, 'float32'))), this.batchSize, inputSize, inputSize, 3);
return batchTensor return batchTensor;
}) });
} }
} }

View File

@ -2,27 +2,28 @@ import { env } from '../env/index';
import { isMediaLoaded } from './isMediaLoaded'; import { isMediaLoaded } from './isMediaLoaded';
export function awaitMediaLoaded(media: HTMLImageElement | HTMLVideoElement | HTMLCanvasElement) { export function awaitMediaLoaded(media: HTMLImageElement | HTMLVideoElement | HTMLCanvasElement) {
// eslint-disable-next-line consistent-return
return new Promise((resolve, reject) => { return new Promise((resolve, reject) => {
if (media instanceof env.getEnv().Canvas || isMediaLoaded(media)) { if (media instanceof env.getEnv().Canvas || isMediaLoaded(media)) {
return resolve(null) return resolve(null);
}
function onLoad(e: Event) {
if (!e.currentTarget) return
e.currentTarget.removeEventListener('load', onLoad)
e.currentTarget.removeEventListener('error', onError)
resolve(e)
} }
function onError(e: Event) { function onError(e: Event) {
if (!e.currentTarget) return if (!e.currentTarget) return;
e.currentTarget.removeEventListener('load', onLoad) // eslint-disable-next-line no-use-before-define
e.currentTarget.removeEventListener('error', onError) e.currentTarget.removeEventListener('load', onLoad);
reject(e) e.currentTarget.removeEventListener('error', onError);
reject(e);
} }
media.addEventListener('load', onLoad) function onLoad(e: Event) {
media.addEventListener('error', onError) if (!e.currentTarget) return;
}) e.currentTarget.removeEventListener('load', onLoad);
e.currentTarget.removeEventListener('error', onError);
resolve(e);
}
media.addEventListener('load', onLoad);
media.addEventListener('error', onError);
});
} }

View File

@ -2,22 +2,16 @@ import { env } from '../env/index';
export function bufferToImage(buf: Blob): Promise<HTMLImageElement> { export function bufferToImage(buf: Blob): Promise<HTMLImageElement> {
return new Promise((resolve, reject) => { return new Promise((resolve, reject) => {
if (!(buf instanceof Blob)) { if (!(buf instanceof Blob)) reject(new Error('bufferToImage - expected buf to be of type: Blob'));
return reject('bufferToImage - expected buf to be of type: Blob') const reader = new FileReader();
}
const reader = new FileReader()
reader.onload = () => { reader.onload = () => {
if (typeof reader.result !== 'string') { if (typeof reader.result !== 'string') reject(new Error('bufferToImage - expected reader.result to be a string, in onload'));
return reject('bufferToImage - expected reader.result to be a string, in onload') const img = env.getEnv().createImageElement();
} img.onload = () => resolve(img);
img.onerror = reject;
const img = env.getEnv().createImageElement() img.src = reader.result as string;
img.onload = () => resolve(img) };
img.onerror = reject reader.onerror = reject;
img.src = reader.result reader.readAsDataURL(buf);
} });
reader.onerror = reject
reader.readAsDataURL(buf)
})
} }

View File

@ -5,29 +5,27 @@ import { getMediaDimensions } from './getMediaDimensions';
import { isMediaLoaded } from './isMediaLoaded'; import { isMediaLoaded } from './isMediaLoaded';
export function createCanvas({ width, height }: IDimensions): HTMLCanvasElement { export function createCanvas({ width, height }: IDimensions): HTMLCanvasElement {
const { createCanvasElement } = env.getEnv();
const { createCanvasElement } = env.getEnv() const canvas = createCanvasElement();
const canvas = createCanvasElement() canvas.width = width;
canvas.width = width canvas.height = height;
canvas.height = height return canvas;
return canvas
} }
export function createCanvasFromMedia(media: HTMLImageElement | HTMLVideoElement | ImageData, dims?: IDimensions): HTMLCanvasElement { export function createCanvasFromMedia(media: HTMLImageElement | HTMLVideoElement | ImageData, dims?: IDimensions): HTMLCanvasElement {
const { ImageData } = env.getEnv();
const { ImageData } = env.getEnv()
if (!(media instanceof ImageData) && !isMediaLoaded(media)) { if (!(media instanceof ImageData) && !isMediaLoaded(media)) {
throw new Error('createCanvasFromMedia - media has not finished loading yet') throw new Error('createCanvasFromMedia - media has not finished loading yet');
} }
const { width, height } = dims || getMediaDimensions(media) const { width, height } = dims || getMediaDimensions(media);
const canvas = createCanvas({ width, height }) const canvas = createCanvas({ width, height });
if (media instanceof ImageData) { if (media instanceof ImageData) {
getContext2dOrThrow(canvas).putImageData(media, 0, 0) getContext2dOrThrow(canvas).putImageData(media, 0, 0);
} else { } else {
getContext2dOrThrow(canvas).drawImage(media, 0, 0, width, height) getContext2dOrThrow(canvas).drawImage(media, 0, 0, width, height);
} }
return canvas return canvas;
} }

View File

@ -16,31 +16,30 @@ import { isTensor3D, isTensor4D } from '../utils/index';
*/ */
export async function extractFaceTensors( export async function extractFaceTensors(
imageTensor: tf.Tensor3D | tf.Tensor4D, imageTensor: tf.Tensor3D | tf.Tensor4D,
detections: Array<FaceDetection | Rect> detections: Array<FaceDetection | Rect>,
): Promise<tf.Tensor3D[]> { ): Promise<tf.Tensor3D[]> {
if (!isTensor3D(imageTensor) && !isTensor4D(imageTensor)) { if (!isTensor3D(imageTensor) && !isTensor4D(imageTensor)) {
throw new Error('extractFaceTensors - expected image tensor to be 3D or 4D') throw new Error('extractFaceTensors - expected image tensor to be 3D or 4D');
} }
if (isTensor4D(imageTensor) && imageTensor.shape[0] > 1) { if (isTensor4D(imageTensor) && imageTensor.shape[0] > 1) {
throw new Error('extractFaceTensors - batchSize > 1 not supported') throw new Error('extractFaceTensors - batchSize > 1 not supported');
} }
return tf.tidy(() => { return tf.tidy(() => {
const [imgHeight, imgWidth, numChannels] = imageTensor.shape.slice(isTensor4D(imageTensor) ? 1 : 0) const [imgHeight, imgWidth, numChannels] = imageTensor.shape.slice(isTensor4D(imageTensor) ? 1 : 0);
const boxes = detections.map( const boxes = detections.map(
det => det instanceof FaceDetection (det) => (det instanceof FaceDetection
? det.forSize(imgWidth, imgHeight).box ? det.forSize(imgWidth, imgHeight).box
: det : det),
) )
.map(box => box.clipAtImageBorders(imgWidth, imgHeight)) .map((box) => box.clipAtImageBorders(imgWidth, imgHeight));
const faceTensors = boxes.map(({ x, y, width, height }) => const faceTensors = boxes.map(({
tf.slice3d(imageTensor.as3D(imgHeight, imgWidth, numChannels), [y, x, 0], [height, width, numChannels]) x, y, width, height,
) }) => tf.slice3d(imageTensor.as3D(imgHeight, imgWidth, numChannels), [y, x, 0], [height, width, numChannels]));
return faceTensors return faceTensors;
}) });
} }

View File

@ -16,38 +16,39 @@ import { TNetInput } from './types';
*/ */
export async function extractFaces( export async function extractFaces(
input: TNetInput, input: TNetInput,
detections: Array<FaceDetection | Rect> detections: Array<FaceDetection | Rect>,
): Promise<HTMLCanvasElement[]> { ): Promise<HTMLCanvasElement[]> {
const { Canvas } = env.getEnv();
const { Canvas } = env.getEnv() let canvas = input as HTMLCanvasElement;
let canvas = input as HTMLCanvasElement
if (!(input instanceof Canvas)) { if (!(input instanceof Canvas)) {
const netInput = await toNetInput(input) const netInput = await toNetInput(input);
if (netInput.batchSize > 1) { if (netInput.batchSize > 1) {
throw new Error('extractFaces - batchSize > 1 not supported') throw new Error('extractFaces - batchSize > 1 not supported');
} }
const tensorOrCanvas = netInput.getInput(0) const tensorOrCanvas = netInput.getInput(0);
canvas = tensorOrCanvas instanceof Canvas canvas = tensorOrCanvas instanceof Canvas
? tensorOrCanvas ? tensorOrCanvas
: await imageTensorToCanvas(tensorOrCanvas) : await imageTensorToCanvas(tensorOrCanvas);
} }
const ctx = getContext2dOrThrow(canvas) const ctx = getContext2dOrThrow(canvas);
const boxes = detections.map( const boxes = detections.map(
det => det instanceof FaceDetection (det) => (det instanceof FaceDetection
? det.forSize(canvas.width, canvas.height).box.floor() ? det.forSize(canvas.width, canvas.height).box.floor()
: det : det),
) )
.map(box => box.clipAtImageBorders(canvas.width, canvas.height)) .map((box) => box.clipAtImageBorders(canvas.width, canvas.height));
return boxes.map(({ x, y, width, height }) => { return boxes.map(({
const faceImg = createCanvas({ width, height }) x, y, width, height,
}) => {
const faceImg = createCanvas({ width, height });
getContext2dOrThrow(faceImg) getContext2dOrThrow(faceImg)
.putImageData(ctx.getImageData(x, y, width, height), 0, 0) .putImageData(ctx.getImageData(x, y, width, height), 0, 0);
return faceImg return faceImg;
}) });
} }

View File

@ -2,11 +2,11 @@ import { bufferToImage } from './bufferToImage';
import { fetchOrThrow } from './fetchOrThrow'; import { fetchOrThrow } from './fetchOrThrow';
export async function fetchImage(uri: string): Promise<HTMLImageElement> { export async function fetchImage(uri: string): Promise<HTMLImageElement> {
const res = await fetchOrThrow(uri) const res = await fetchOrThrow(uri);
const blob = await (res).blob() const blob = await (res).blob();
if (!blob.type.startsWith('image/')) { if (!blob.type.startsWith('image/')) {
throw new Error(`fetchImage - expected blob type to be of type image/*, instead have: ${blob.type}, for url: ${res.url}`) throw new Error(`fetchImage - expected blob type to be of type image/*, instead have: ${blob.type}, for url: ${res.url}`);
} }
return bufferToImage(blob) return bufferToImage(blob);
} }

View File

@ -1,5 +1,5 @@
import { fetchOrThrow } from './fetchOrThrow'; import { fetchOrThrow } from './fetchOrThrow';
export async function fetchJson<T>(uri: string): Promise<T> { export async function fetchJson<T>(uri: string): Promise<T> {
return (await fetchOrThrow(uri)).json() return (await fetchOrThrow(uri)).json();
} }

View File

@ -1,5 +1,5 @@
import { fetchOrThrow } from './fetchOrThrow'; import { fetchOrThrow } from './fetchOrThrow';
export async function fetchNetWeights(uri: string): Promise<Float32Array> { export async function fetchNetWeights(uri: string): Promise<Float32Array> {
return new Float32Array(await (await fetchOrThrow(uri)).arrayBuffer()) return new Float32Array(await (await fetchOrThrow(uri)).arrayBuffer());
} }

View File

@ -2,13 +2,13 @@ import { env } from '../env/index';
export async function fetchOrThrow( export async function fetchOrThrow(
url: string, url: string,
init?: RequestInit // eslint-disable-next-line no-undef
init?: RequestInit,
): Promise<Response> { ): Promise<Response> {
const { fetch } = env.getEnv();
const fetch = env.getEnv().fetch const res = await fetch(url, init);
const res = await fetch(url, init)
if (!(res.status < 400)) { if (!(res.status < 400)) {
throw new Error(`failed to fetch: (${res.status}) ${res.statusText}, from url: ${res.url}`) throw new Error(`failed to fetch: (${res.status}) ${res.statusText}, from url: ${res.url}`);
} }
return res return res;
} }

View File

@ -2,23 +2,22 @@ import { env } from '../env/index';
import { resolveInput } from './resolveInput'; import { resolveInput } from './resolveInput';
export function getContext2dOrThrow(canvasArg: string | HTMLCanvasElement | CanvasRenderingContext2D): CanvasRenderingContext2D { export function getContext2dOrThrow(canvasArg: string | HTMLCanvasElement | CanvasRenderingContext2D): CanvasRenderingContext2D {
const { Canvas, CanvasRenderingContext2D } = env.getEnv();
const { Canvas, CanvasRenderingContext2D } = env.getEnv()
if (canvasArg instanceof CanvasRenderingContext2D) { if (canvasArg instanceof CanvasRenderingContext2D) {
return canvasArg return canvasArg;
} }
const canvas = resolveInput(canvasArg) const canvas = resolveInput(canvasArg);
if (!(canvas instanceof Canvas)) { if (!(canvas instanceof Canvas)) {
throw new Error('resolveContext2d - expected canvas to be of instance of Canvas') throw new Error('resolveContext2d - expected canvas to be of instance of Canvas');
} }
const ctx = canvas.getContext('2d') const ctx = canvas.getContext('2d');
if (!ctx) { if (!ctx) {
throw new Error('resolveContext2d - canvas 2d context is null') throw new Error('resolveContext2d - canvas 2d context is null');
} }
return ctx return ctx;
} }

View File

@ -2,14 +2,13 @@ import { Dimensions, IDimensions } from '../classes/Dimensions';
import { env } from '../env/index'; import { env } from '../env/index';
export function getMediaDimensions(input: HTMLImageElement | HTMLCanvasElement | HTMLVideoElement | IDimensions): Dimensions { export function getMediaDimensions(input: HTMLImageElement | HTMLCanvasElement | HTMLVideoElement | IDimensions): Dimensions {
const { Image, Video } = env.getEnv();
const { Image, Video } = env.getEnv()
if (input instanceof Image) { if (input instanceof Image) {
return new Dimensions(input.naturalWidth, input.naturalHeight) return new Dimensions(input.naturalWidth, input.naturalHeight);
} }
if (input instanceof Video) { if (input instanceof Video) {
return new Dimensions(input.videoWidth, input.videoHeight) return new Dimensions(input.videoWidth, input.videoHeight);
} }
return new Dimensions(input.width, input.height) return new Dimensions(input.width, input.height);
} }

View File

@ -5,16 +5,15 @@ import { isTensor4D } from '../utils/index';
export async function imageTensorToCanvas( export async function imageTensorToCanvas(
imgTensor: tf.Tensor, imgTensor: tf.Tensor,
canvas?: HTMLCanvasElement canvas?: HTMLCanvasElement,
): Promise<HTMLCanvasElement> { ): Promise<HTMLCanvasElement> {
const targetCanvas = canvas || env.getEnv().createCanvasElement();
const targetCanvas = canvas || env.getEnv().createCanvasElement() const [height, width, numChannels] = imgTensor.shape.slice(isTensor4D(imgTensor) ? 1 : 0);
const imgTensor3D = tf.tidy(() => imgTensor.as3D(height, width, numChannels).toInt());
await tf.browser.toPixels(imgTensor3D, targetCanvas);
const [height, width, numChannels] = imgTensor.shape.slice(isTensor4D(imgTensor) ? 1 : 0) imgTensor3D.dispose();
const imgTensor3D = tf.tidy(() => imgTensor.as3D(height, width, numChannels).toInt())
await tf.browser.toPixels(imgTensor3D, targetCanvas)
imgTensor3D.dispose() return targetCanvas;
return targetCanvas
} }

View File

@ -4,25 +4,24 @@ import { getContext2dOrThrow } from './getContext2dOrThrow';
import { getMediaDimensions } from './getMediaDimensions'; import { getMediaDimensions } from './getMediaDimensions';
export function imageToSquare(input: HTMLImageElement | HTMLCanvasElement, inputSize: number, centerImage: boolean = false) { export function imageToSquare(input: HTMLImageElement | HTMLCanvasElement, inputSize: number, centerImage: boolean = false) {
const { Image, Canvas } = env.getEnv();
const { Image, Canvas } = env.getEnv()
if (!(input instanceof Image || input instanceof Canvas)) { if (!(input instanceof Image || input instanceof Canvas)) {
throw new Error('imageToSquare - expected arg0 to be HTMLImageElement | HTMLCanvasElement') throw new Error('imageToSquare - expected arg0 to be HTMLImageElement | HTMLCanvasElement');
} }
const dims = getMediaDimensions(input) const dims = getMediaDimensions(input);
const scale = inputSize / Math.max(dims.height, dims.width) const scale = inputSize / Math.max(dims.height, dims.width);
const width = scale * dims.width const width = scale * dims.width;
const height = scale * dims.height const height = scale * dims.height;
const targetCanvas = createCanvas({ width: inputSize, height: inputSize }) const targetCanvas = createCanvas({ width: inputSize, height: inputSize });
const inputCanvas = input instanceof Canvas ? input : createCanvasFromMedia(input) const inputCanvas = input instanceof Canvas ? input : createCanvasFromMedia(input);
const offset = Math.abs(width - height) / 2 const offset = Math.abs(width - height) / 2;
const dx = centerImage && width < height ? offset : 0 const dx = centerImage && width < height ? offset : 0;
const dy = centerImage && height < width ? offset : 0 const dy = centerImage && height < width ? offset : 0;
getContext2dOrThrow(targetCanvas).drawImage(inputCanvas, dx, dy, width, height) getContext2dOrThrow(targetCanvas).drawImage(inputCanvas, dx, dy, width, height);
return targetCanvas return targetCanvas;
} }

View File

@ -1,21 +1,21 @@
export * from './awaitMediaLoaded' export * from './awaitMediaLoaded';
export * from './bufferToImage' export * from './bufferToImage';
export * from './createCanvas' export * from './createCanvas';
export * from './extractFaces' export * from './extractFaces';
export * from './extractFaceTensors' export * from './extractFaceTensors';
export * from './fetchImage' export * from './fetchImage';
export * from './fetchJson' export * from './fetchJson';
export * from './fetchNetWeights' export * from './fetchNetWeights';
export * from './fetchOrThrow' export * from './fetchOrThrow';
export * from './getContext2dOrThrow' export * from './getContext2dOrThrow';
export * from './getMediaDimensions' export * from './getMediaDimensions';
export * from './imageTensorToCanvas' export * from './imageTensorToCanvas';
export * from './imageToSquare' export * from './imageToSquare';
export * from './isMediaElement' export * from './isMediaElement';
export * from './isMediaLoaded' export * from './isMediaLoaded';
export * from './loadWeightMap' export * from './loadWeightMap';
export * from './matchDimensions' export * from './matchDimensions';
export * from './NetInput' export * from './NetInput';
export * from './resolveInput' export * from './resolveInput';
export * from './toNetInput' export * from './toNetInput';
export * from './types' export * from './types';

View File

@ -1,10 +1,9 @@
import { env } from '../env/index'; import { env } from '../env/index';
export function isMediaElement(input: any) { export function isMediaElement(input: any) {
const { Image, Canvas, Video } = env.getEnv();
const { Image, Canvas, Video } = env.getEnv()
return input instanceof Image return input instanceof Image
|| input instanceof Canvas || input instanceof Canvas
|| input instanceof Video || input instanceof Video;
} }

View File

@ -1,9 +1,8 @@
import { env } from '../env/index'; import { env } from '../env/index';
export function isMediaLoaded(media: HTMLImageElement | HTMLVideoElement) : boolean { export function isMediaLoaded(media: HTMLImageElement | HTMLVideoElement) : boolean {
const { Image, Video } = env.getEnv();
const { Image, Video } = env.getEnv()
return (media instanceof Image && media.complete) return (media instanceof Image && media.complete)
|| (media instanceof Video && media.readyState >= 3) || (media instanceof Video && media.readyState >= 3);
} }

View File

@ -7,8 +7,8 @@ export async function loadWeightMap(
uri: string | undefined, uri: string | undefined,
defaultModelName: string, defaultModelName: string,
): Promise<tf.NamedTensorMap> { ): Promise<tf.NamedTensorMap> {
const { manifestUri, modelBaseUri } = getModelUris(uri, defaultModelName) const { manifestUri, modelBaseUri } = getModelUris(uri, defaultModelName);
let manifest = await fetchJson<tf.io.WeightsManifestConfig>(manifestUri) const manifest = await fetchJson<tf.io.WeightsManifestConfig>(manifestUri);
// if (manifest['weightsManifest']) manifest = manifest['weightsManifest']; // if (manifest['weightsManifest']) manifest = manifest['weightsManifest'];
return tf.io.loadWeights(manifest, modelBaseUri) return tf.io.loadWeights(manifest, modelBaseUri);
} }

View File

@ -4,8 +4,8 @@ import { getMediaDimensions } from './getMediaDimensions';
export function matchDimensions(input: IDimensions, reference: IDimensions, useMediaDimensions: boolean = false) { export function matchDimensions(input: IDimensions, reference: IDimensions, useMediaDimensions: boolean = false) {
const { width, height } = useMediaDimensions const { width, height } = useMediaDimensions
? getMediaDimensions(reference) ? getMediaDimensions(reference)
: reference : reference;
input.width = width input.width = width;
input.height = height input.height = height;
return { width, height } return { width, height };
} }

View File

@ -2,7 +2,7 @@ import { env } from '../env/index';
export function resolveInput(arg: string | any) { export function resolveInput(arg: string | any) {
if (!env.isNodejs() && typeof arg === 'string') { if (!env.isNodejs() && typeof arg === 'string') {
return document.getElementById(arg) return document.getElementById(arg);
} }
return arg return arg;
} }

View File

@ -14,44 +14,43 @@ import { TNetInput } from './types';
*/ */
export async function toNetInput(inputs: TNetInput): Promise<NetInput> { export async function toNetInput(inputs: TNetInput): Promise<NetInput> {
if (inputs instanceof NetInput) { if (inputs instanceof NetInput) {
return inputs return inputs;
} }
let inputArgArray = Array.isArray(inputs) const inputArgArray = Array.isArray(inputs)
? inputs ? inputs
: [inputs] : [inputs];
if (!inputArgArray.length) { if (!inputArgArray.length) {
throw new Error('toNetInput - empty array passed as input') throw new Error('toNetInput - empty array passed as input');
} }
const getIdxHint = (idx: number) => Array.isArray(inputs) ? ` at input index ${idx}:` : '' const getIdxHint = (idx: number) => (Array.isArray(inputs) ? ` at input index ${idx}:` : '');
const inputArray = inputArgArray.map(resolveInput) const inputArray = inputArgArray.map(resolveInput);
inputArray.forEach((input, i) => { inputArray.forEach((input, i) => {
if (!isMediaElement(input) && !isTensor3D(input) && !isTensor4D(input)) { if (!isMediaElement(input) && !isTensor3D(input) && !isTensor4D(input)) {
if (typeof inputArgArray[i] === 'string') { if (typeof inputArgArray[i] === 'string') {
throw new Error(`toNetInput -${getIdxHint(i)} string passed, but could not resolve HTMLElement for element id ${inputArgArray[i]}`) throw new Error(`toNetInput -${getIdxHint(i)} string passed, but could not resolve HTMLElement for element id ${inputArgArray[i]}`);
} }
throw new Error(`toNetInput -${getIdxHint(i)} expected media to be of type HTMLImageElement | HTMLVideoElement | HTMLCanvasElement | tf.Tensor3D, or to be an element id`) throw new Error(`toNetInput -${getIdxHint(i)} expected media to be of type HTMLImageElement | HTMLVideoElement | HTMLCanvasElement | tf.Tensor3D, or to be an element id`);
} }
if (isTensor4D(input)) { if (isTensor4D(input)) {
// if tf.Tensor4D is passed in the input array, the batch size has to be 1 // if tf.Tensor4D is passed in the input array, the batch size has to be 1
const batchSize = input.shape[0] const batchSize = input.shape[0];
if (batchSize !== 1) { if (batchSize !== 1) {
throw new Error(`toNetInput -${getIdxHint(i)} tf.Tensor4D with batchSize ${batchSize} passed, but not supported in input array`) throw new Error(`toNetInput -${getIdxHint(i)} tf.Tensor4D with batchSize ${batchSize} passed, but not supported in input array`);
} }
} }
}) });
// wait for all media elements being loaded // wait for all media elements being loaded
await Promise.all( await Promise.all(
inputArray.map(input => isMediaElement(input) && awaitMediaLoaded(input)) inputArray.map((input) => isMediaElement(input) && awaitMediaLoaded(input)),
) );
return new NetInput(inputArray, Array.isArray(inputs)) return new NetInput(inputArray, Array.isArray(inputs));
} }

View File

@ -1,6 +1,9 @@
/* eslint-disable max-classes-per-file */
import { Box, IBoundingBox, IRect } from '../classes/index'; import { Box, IBoundingBox, IRect } from '../classes/index';
import { getContext2dOrThrow } from '../dom/getContext2dOrThrow'; import { getContext2dOrThrow } from '../dom/getContext2dOrThrow';
import { AnchorPosition, DrawTextField, DrawTextFieldOptions, IDrawTextFieldOptions } from './DrawTextField'; import {
AnchorPosition, DrawTextField, DrawTextFieldOptions, IDrawTextFieldOptions,
} from './DrawTextField';
export interface IDrawBoxOptions { export interface IDrawBoxOptions {
boxColor?: string boxColor?: string
@ -11,49 +14,57 @@ export interface IDrawBoxOptions {
export class DrawBoxOptions { export class DrawBoxOptions {
public boxColor: string public boxColor: string
public lineWidth: number public lineWidth: number
public drawLabelOptions: DrawTextFieldOptions public drawLabelOptions: DrawTextFieldOptions
public label?: string public label?: string
constructor(options: IDrawBoxOptions = {}) { constructor(options: IDrawBoxOptions = {}) {
const { boxColor, lineWidth, label, drawLabelOptions } = options const {
this.boxColor = boxColor || 'rgba(0, 0, 255, 1)' boxColor, lineWidth, label, drawLabelOptions,
this.lineWidth = lineWidth || 2 } = options;
this.label = label this.boxColor = boxColor || 'rgba(0, 0, 255, 1)';
this.lineWidth = lineWidth || 2;
this.label = label;
const defaultDrawLabelOptions = { const defaultDrawLabelOptions = {
anchorPosition: AnchorPosition.BOTTOM_LEFT, anchorPosition: AnchorPosition.BOTTOM_LEFT,
backgroundColor: this.boxColor backgroundColor: this.boxColor,
} };
this.drawLabelOptions = new DrawTextFieldOptions(Object.assign({}, defaultDrawLabelOptions, drawLabelOptions)) this.drawLabelOptions = new DrawTextFieldOptions({ ...defaultDrawLabelOptions, ...drawLabelOptions });
} }
} }
export class DrawBox { export class DrawBox {
public box: Box public box: Box
public options: DrawBoxOptions public options: DrawBoxOptions
constructor( constructor(
box: IBoundingBox | IRect, box: IBoundingBox | IRect,
options: IDrawBoxOptions = {} options: IDrawBoxOptions = {},
) { ) {
this.box = new Box(box) this.box = new Box(box);
this.options = new DrawBoxOptions(options) this.options = new DrawBoxOptions(options);
} }
draw(canvasArg: string | HTMLCanvasElement | CanvasRenderingContext2D) { draw(canvasArg: string | HTMLCanvasElement | CanvasRenderingContext2D) {
const ctx = getContext2dOrThrow(canvasArg) const ctx = getContext2dOrThrow(canvasArg);
const { boxColor, lineWidth } = this.options const { boxColor, lineWidth } = this.options;
const { x, y, width, height } = this.box const {
ctx.strokeStyle = boxColor x, y, width, height,
ctx.lineWidth = lineWidth } = this.box;
ctx.strokeRect(x, y, width, height) ctx.strokeStyle = boxColor;
ctx.lineWidth = lineWidth;
ctx.strokeRect(x, y, width, height);
const { label } = this.options const { label } = this.options;
if (label) { if (label) {
new DrawTextField([label], { x: x - (lineWidth / 2), y }, this.options.drawLabelOptions).draw(canvasArg) new DrawTextField([label], { x: x - (lineWidth / 2), y }, this.options.drawLabelOptions).draw(canvasArg);
} }
} }
} }

View File

@ -1,3 +1,4 @@
/* eslint-disable max-classes-per-file */
import { IPoint } from '../classes/index'; import { IPoint } from '../classes/index';
import { FaceLandmarks } from '../classes/FaceLandmarks'; import { FaceLandmarks } from '../classes/FaceLandmarks';
import { FaceLandmarks68 } from '../classes/FaceLandmarks68'; import { FaceLandmarks68 } from '../classes/FaceLandmarks68';
@ -17,62 +18,72 @@ export interface IDrawFaceLandmarksOptions {
export class DrawFaceLandmarksOptions { export class DrawFaceLandmarksOptions {
public drawLines: boolean public drawLines: boolean
public drawPoints: boolean public drawPoints: boolean
public lineWidth: number public lineWidth: number
public pointSize: number public pointSize: number
public lineColor: string public lineColor: string
public pointColor: string public pointColor: string
constructor(options: IDrawFaceLandmarksOptions = {}) { constructor(options: IDrawFaceLandmarksOptions = {}) {
const { drawLines = true, drawPoints = true, lineWidth, lineColor, pointSize, pointColor } = options const {
this.drawLines = drawLines drawLines = true, drawPoints = true, lineWidth, lineColor, pointSize, pointColor,
this.drawPoints = drawPoints } = options;
this.lineWidth = lineWidth || 1 this.drawLines = drawLines;
this.pointSize = pointSize || 2 this.drawPoints = drawPoints;
this.lineColor = lineColor || 'rgba(0, 255, 255, 1)' this.lineWidth = lineWidth || 1;
this.pointColor = pointColor || 'rgba(255, 0, 255, 1)' this.pointSize = pointSize || 2;
this.lineColor = lineColor || 'rgba(0, 255, 255, 1)';
this.pointColor = pointColor || 'rgba(255, 0, 255, 1)';
} }
} }
export class DrawFaceLandmarks { export class DrawFaceLandmarks {
public faceLandmarks: FaceLandmarks public faceLandmarks: FaceLandmarks
public options: DrawFaceLandmarksOptions public options: DrawFaceLandmarksOptions
constructor( constructor(
faceLandmarks: FaceLandmarks, faceLandmarks: FaceLandmarks,
options: IDrawFaceLandmarksOptions = {} options: IDrawFaceLandmarksOptions = {},
) { ) {
this.faceLandmarks = faceLandmarks this.faceLandmarks = faceLandmarks;
this.options = new DrawFaceLandmarksOptions(options) this.options = new DrawFaceLandmarksOptions(options);
} }
draw(canvasArg: string | HTMLCanvasElement | CanvasRenderingContext2D) { draw(canvasArg: string | HTMLCanvasElement | CanvasRenderingContext2D) {
const ctx = getContext2dOrThrow(canvasArg) const ctx = getContext2dOrThrow(canvasArg);
const { drawLines, drawPoints, lineWidth, lineColor, pointSize, pointColor } = this.options const {
drawLines, drawPoints, lineWidth, lineColor, pointSize, pointColor,
} = this.options;
if (drawLines && this.faceLandmarks instanceof FaceLandmarks68) { if (drawLines && this.faceLandmarks instanceof FaceLandmarks68) {
ctx.strokeStyle = lineColor ctx.strokeStyle = lineColor;
ctx.lineWidth = lineWidth ctx.lineWidth = lineWidth;
drawContour(ctx, this.faceLandmarks.getJawOutline()) drawContour(ctx, this.faceLandmarks.getJawOutline());
drawContour(ctx, this.faceLandmarks.getLeftEyeBrow()) drawContour(ctx, this.faceLandmarks.getLeftEyeBrow());
drawContour(ctx, this.faceLandmarks.getRightEyeBrow()) drawContour(ctx, this.faceLandmarks.getRightEyeBrow());
drawContour(ctx, this.faceLandmarks.getNose()) drawContour(ctx, this.faceLandmarks.getNose());
drawContour(ctx, this.faceLandmarks.getLeftEye(), true) drawContour(ctx, this.faceLandmarks.getLeftEye(), true);
drawContour(ctx, this.faceLandmarks.getRightEye(), true) drawContour(ctx, this.faceLandmarks.getRightEye(), true);
drawContour(ctx, this.faceLandmarks.getMouth(), true) drawContour(ctx, this.faceLandmarks.getMouth(), true);
} }
if (drawPoints) { if (drawPoints) {
ctx.strokeStyle = pointColor ctx.strokeStyle = pointColor;
ctx.fillStyle = pointColor ctx.fillStyle = pointColor;
const drawPoint = (pt: IPoint) => { const drawPoint = (pt: IPoint) => {
ctx.beginPath() ctx.beginPath();
ctx.arc(pt.x, pt.y, pointSize, 0, 2 * Math.PI) ctx.arc(pt.x, pt.y, pointSize, 0, 2 * Math.PI);
ctx.fill() ctx.fill();
} };
this.faceLandmarks.positions.forEach(drawPoint) this.faceLandmarks.positions.forEach(drawPoint);
} }
} }
} }
@ -81,17 +92,18 @@ export type DrawFaceLandmarksInput = FaceLandmarks | WithFaceLandmarks<WithFaceD
export function drawFaceLandmarks( export function drawFaceLandmarks(
canvasArg: string | HTMLCanvasElement, canvasArg: string | HTMLCanvasElement,
faceLandmarks: DrawFaceLandmarksInput | Array<DrawFaceLandmarksInput> faceLandmarks: DrawFaceLandmarksInput | Array<DrawFaceLandmarksInput>,
) { ) {
const faceLandmarksArray = Array.isArray(faceLandmarks) ? faceLandmarks : [faceLandmarks] const faceLandmarksArray = Array.isArray(faceLandmarks) ? faceLandmarks : [faceLandmarks];
faceLandmarksArray.forEach(f => { faceLandmarksArray.forEach((f) => {
// eslint-disable-next-line no-nested-ternary
const landmarks = f instanceof FaceLandmarks const landmarks = f instanceof FaceLandmarks
? f ? f
: (isWithFaceLandmarks(f) ? f.landmarks : undefined) : (isWithFaceLandmarks(f) ? f.landmarks : undefined);
if (!landmarks) { if (!landmarks) {
throw new Error('drawFaceLandmarks - expected faceExpressions to be FaceLandmarks | WithFaceLandmarks<WithFaceDetection<{}>> or array thereof') throw new Error('drawFaceLandmarks - expected faceExpressions to be FaceLandmarks | WithFaceLandmarks<WithFaceDetection<{}>> or array thereof');
} }
new DrawFaceLandmarks(landmarks).draw(canvasArg) new DrawFaceLandmarks(landmarks).draw(canvasArg);
}) });
} }

View File

@ -1,11 +1,17 @@
/* eslint-disable max-classes-per-file */
import { IDimensions, IPoint } from '../classes/index'; import { IDimensions, IPoint } from '../classes/index';
import { getContext2dOrThrow } from '../dom/getContext2dOrThrow'; import { getContext2dOrThrow } from '../dom/getContext2dOrThrow';
import { resolveInput } from '../dom/resolveInput'; import { resolveInput } from '../dom/resolveInput';
// eslint-disable-next-line no-shadow
export enum AnchorPosition { export enum AnchorPosition {
// eslint-disable-next-line no-unused-vars
TOP_LEFT = 'TOP_LEFT', TOP_LEFT = 'TOP_LEFT',
// eslint-disable-next-line no-unused-vars
TOP_RIGHT = 'TOP_RIGHT', TOP_RIGHT = 'TOP_RIGHT',
// eslint-disable-next-line no-unused-vars
BOTTOM_LEFT = 'BOTTOM_LEFT', BOTTOM_LEFT = 'BOTTOM_LEFT',
// eslint-disable-next-line no-unused-vars
BOTTOM_RIGHT = 'BOTTOM_RIGHT' BOTTOM_RIGHT = 'BOTTOM_RIGHT'
} }
@ -20,89 +26,101 @@ export interface IDrawTextFieldOptions {
export class DrawTextFieldOptions implements IDrawTextFieldOptions { export class DrawTextFieldOptions implements IDrawTextFieldOptions {
public anchorPosition: AnchorPosition public anchorPosition: AnchorPosition
public backgroundColor: string public backgroundColor: string
public fontColor: string public fontColor: string
public fontSize: number public fontSize: number
public fontStyle: string public fontStyle: string
public padding: number public padding: number
constructor(options: IDrawTextFieldOptions = {}) { constructor(options: IDrawTextFieldOptions = {}) {
const { anchorPosition, backgroundColor, fontColor, fontSize, fontStyle, padding } = options const {
this.anchorPosition = anchorPosition || AnchorPosition.TOP_LEFT anchorPosition, backgroundColor, fontColor, fontSize, fontStyle, padding,
this.backgroundColor = backgroundColor || 'rgba(0, 0, 0, 0.5)' } = options;
this.fontColor = fontColor || 'rgba(255, 255, 255, 1)' this.anchorPosition = anchorPosition || AnchorPosition.TOP_LEFT;
this.fontSize = fontSize || 14 this.backgroundColor = backgroundColor || 'rgba(0, 0, 0, 0.5)';
this.fontStyle = fontStyle || 'Georgia' this.fontColor = fontColor || 'rgba(255, 255, 255, 1)';
this.padding = padding || 4 this.fontSize = fontSize || 14;
this.fontStyle = fontStyle || 'Georgia';
this.padding = padding || 4;
} }
} }
export class DrawTextField { export class DrawTextField {
public text: string[] public text: string[]
public anchor : IPoint public anchor : IPoint
public options: DrawTextFieldOptions public options: DrawTextFieldOptions
constructor( constructor(
text: string | string[] | DrawTextField, text: string | string[] | DrawTextField,
anchor: IPoint, anchor: IPoint,
options: IDrawTextFieldOptions = {} options: IDrawTextFieldOptions = {},
) { ) {
// eslint-disable-next-line no-nested-ternary
this.text = typeof text === 'string' this.text = typeof text === 'string'
? [text] ? [text]
: (text instanceof DrawTextField ? text.text : text) : (text instanceof DrawTextField ? text.text : text);
this.anchor = anchor this.anchor = anchor;
this.options = new DrawTextFieldOptions(options) this.options = new DrawTextFieldOptions(options);
} }
measureWidth(ctx: CanvasRenderingContext2D): number { measureWidth(ctx: CanvasRenderingContext2D): number {
const { padding } = this.options const { padding } = this.options;
return this.text.map(l => ctx.measureText(l).width).reduce((w0, w1) => w0 < w1 ? w1 : w0, 0) + (2 * padding) return this.text.map((l) => ctx.measureText(l).width).reduce((w0, w1) => (w0 < w1 ? w1 : w0), 0) + (2 * padding);
} }
measureHeight(): number { measureHeight(): number {
const { fontSize, padding } = this.options const { fontSize, padding } = this.options;
return this.text.length * fontSize + (2 * padding) return this.text.length * fontSize + (2 * padding);
} }
getUpperLeft(ctx: CanvasRenderingContext2D, canvasDims?: IDimensions): IPoint { getUpperLeft(ctx: CanvasRenderingContext2D, canvasDims?: IDimensions): IPoint {
const { anchorPosition } = this.options const { anchorPosition } = this.options;
const isShiftLeft = anchorPosition === AnchorPosition.BOTTOM_RIGHT || anchorPosition === AnchorPosition.TOP_RIGHT const isShiftLeft = anchorPosition === AnchorPosition.BOTTOM_RIGHT || anchorPosition === AnchorPosition.TOP_RIGHT;
const isShiftTop = anchorPosition === AnchorPosition.BOTTOM_LEFT || anchorPosition === AnchorPosition.BOTTOM_RIGHT const isShiftTop = anchorPosition === AnchorPosition.BOTTOM_LEFT || anchorPosition === AnchorPosition.BOTTOM_RIGHT;
const textFieldWidth = this.measureWidth(ctx) const textFieldWidth = this.measureWidth(ctx);
const textFieldHeight = this.measureHeight() const textFieldHeight = this.measureHeight();
const x = (isShiftLeft ? this.anchor.x - textFieldWidth : this.anchor.x) const x = (isShiftLeft ? this.anchor.x - textFieldWidth : this.anchor.x);
const y = isShiftTop ? this.anchor.y - textFieldHeight : this.anchor.y const y = isShiftTop ? this.anchor.y - textFieldHeight : this.anchor.y;
// adjust anchor if text box exceeds canvas borders // adjust anchor if text box exceeds canvas borders
if (canvasDims) { if (canvasDims) {
const { width, height } = canvasDims const { width, height } = canvasDims;
const newX = Math.max(Math.min(x, width - textFieldWidth), 0) const newX = Math.max(Math.min(x, width - textFieldWidth), 0);
const newY = Math.max(Math.min(y, height - textFieldHeight), 0) const newY = Math.max(Math.min(y, height - textFieldHeight), 0);
return { x: newX, y: newY } return { x: newX, y: newY };
} }
return { x, y } return { x, y };
} }
draw(canvasArg: string | HTMLCanvasElement | CanvasRenderingContext2D) { draw(canvasArg: string | HTMLCanvasElement | CanvasRenderingContext2D) {
const canvas = resolveInput(canvasArg) const canvas = resolveInput(canvasArg);
const ctx = getContext2dOrThrow(canvas) const ctx = getContext2dOrThrow(canvas);
const { backgroundColor, fontColor, fontSize, fontStyle, padding } = this.options const {
backgroundColor, fontColor, fontSize, fontStyle, padding,
} = this.options;
ctx.font = `${fontSize}px ${fontStyle}` ctx.font = `${fontSize}px ${fontStyle}`;
const maxTextWidth = this.measureWidth(ctx) const maxTextWidth = this.measureWidth(ctx);
const textHeight = this.measureHeight() const textHeight = this.measureHeight();
ctx.fillStyle = backgroundColor ctx.fillStyle = backgroundColor;
const upperLeft = this.getUpperLeft(ctx, canvas) const upperLeft = this.getUpperLeft(ctx, canvas);
ctx.fillRect(upperLeft.x, upperLeft.y, maxTextWidth, textHeight) ctx.fillRect(upperLeft.x, upperLeft.y, maxTextWidth, textHeight);
ctx.fillStyle = fontColor; ctx.fillStyle = fontColor;
this.text.forEach((textLine, i) => { this.text.forEach((textLine, i) => {
const x = padding + upperLeft.x const x = padding + upperLeft.x;
const y = padding + upperLeft.y + ((i + 1) * fontSize) const y = padding + upperLeft.y + ((i + 1) * fontSize);
ctx.fillText(textLine, x, y) ctx.fillText(textLine, x, y);
}) });
} }
} }

View File

@ -3,26 +3,26 @@ import { Point } from '../classes/index';
export function drawContour( export function drawContour(
ctx: CanvasRenderingContext2D, ctx: CanvasRenderingContext2D,
points: Point[], points: Point[],
isClosed: boolean = false isClosed: boolean = false,
) { ) {
ctx.beginPath() ctx.beginPath();
points.slice(1).forEach(({ x, y }, prevIdx) => { points.slice(1).forEach(({ x, y }, prevIdx) => {
const from = points[prevIdx] const from = points[prevIdx];
ctx.moveTo(from.x, from.y) ctx.moveTo(from.x, from.y);
ctx.lineTo(x, y) ctx.lineTo(x, y);
}) });
if (isClosed) { if (isClosed) {
const from = points[points.length - 1] const from = points[points.length - 1];
const to = points[0] const to = points[0];
if (!from || !to) { if (!from || !to) {
return return;
} }
ctx.moveTo(from.x, from.y) ctx.moveTo(from.x, from.y);
ctx.lineTo(to.x, to.y) ctx.lineTo(to.x, to.y);
} }
ctx.stroke() ctx.stroke();
} }

View File

@ -8,20 +8,22 @@ export type TDrawDetectionsInput = IRect | IBoundingBox | FaceDetection | WithFa
export function drawDetections( export function drawDetections(
canvasArg: string | HTMLCanvasElement, canvasArg: string | HTMLCanvasElement,
detections: TDrawDetectionsInput | Array<TDrawDetectionsInput> detections: TDrawDetectionsInput | Array<TDrawDetectionsInput>,
) { ) {
const detectionsArray = Array.isArray(detections) ? detections : [detections] const detectionsArray = Array.isArray(detections) ? detections : [detections];
detectionsArray.forEach(det => { detectionsArray.forEach((det) => {
// eslint-disable-next-line no-nested-ternary
const score = det instanceof FaceDetection const score = det instanceof FaceDetection
? det.score ? det.score
: (isWithFaceDetection(det) ? det.detection.score : undefined) : (isWithFaceDetection(det) ? det.detection.score : undefined);
// eslint-disable-next-line no-nested-ternary
const box = det instanceof FaceDetection const box = det instanceof FaceDetection
? det.box ? det.box
: (isWithFaceDetection(det) ? det.detection.box : new Box(det)) : (isWithFaceDetection(det) ? det.detection.box : new Box(det));
const label = score ? `${round(score)}` : undefined const label = score ? `${round(score)}` : undefined;
new DrawBox(box, { label }).draw(canvasArg) new DrawBox(box, { label }).draw(canvasArg);
}) });
} }

View File

@ -11,29 +11,30 @@ export function drawFaceExpressions(
canvasArg: string | HTMLCanvasElement, canvasArg: string | HTMLCanvasElement,
faceExpressions: DrawFaceExpressionsInput | Array<DrawFaceExpressionsInput>, faceExpressions: DrawFaceExpressionsInput | Array<DrawFaceExpressionsInput>,
minConfidence = 0.1, minConfidence = 0.1,
textFieldAnchor?: IPoint textFieldAnchor?: IPoint,
) { ) {
const faceExpressionsArray = Array.isArray(faceExpressions) ? faceExpressions : [faceExpressions] const faceExpressionsArray = Array.isArray(faceExpressions) ? faceExpressions : [faceExpressions];
faceExpressionsArray.forEach(e => { faceExpressionsArray.forEach((e) => {
// eslint-disable-next-line no-nested-ternary
const expr = e instanceof FaceExpressions const expr = e instanceof FaceExpressions
? e ? e
: (isWithFaceExpressions(e) ? e.expressions : undefined) : (isWithFaceExpressions(e) ? e.expressions : undefined);
if (!expr) { if (!expr) {
throw new Error('drawFaceExpressions - expected faceExpressions to be FaceExpressions | WithFaceExpressions<{}> or array thereof') throw new Error('drawFaceExpressions - expected faceExpressions to be FaceExpressions | WithFaceExpressions<{}> or array thereof');
} }
const sorted = expr.asSortedArray() const sorted = expr.asSortedArray();
const resultsToDisplay = sorted.filter(expr => expr.probability > minConfidence) const resultsToDisplay = sorted.filter((exprLocal) => exprLocal.probability > minConfidence);
const anchor = isWithFaceDetection(e) const anchor = isWithFaceDetection(e)
? e.detection.box.bottomLeft ? e.detection.box.bottomLeft
: (textFieldAnchor || new Point(0, 0)) : (textFieldAnchor || new Point(0, 0));
const drawTextField = new DrawTextField( const drawTextField = new DrawTextField(
resultsToDisplay.map(expr => `${expr.expression} (${round(expr.probability)})`), resultsToDisplay.map((exprLocal) => `${exprLocal.expression} (${round(exprLocal.probability)})`),
anchor anchor,
) );
drawTextField.draw(canvasArg) drawTextField.draw(canvasArg);
}) });
} }

View File

@ -1,6 +1,6 @@
export * from './drawContour' export * from './drawContour';
export * from './drawDetections' export * from './drawDetections';
export * from './drawFaceExpressions' export * from './drawFaceExpressions';
export * from './DrawBox' export * from './DrawBox';
export * from './DrawFaceLandmarks' export * from './DrawFaceLandmarks';
export * from './DrawTextField' export * from './DrawTextField';

View File

@ -1,24 +1,22 @@
import { Environment } from './types'; import { Environment } from './types';
export function createBrowserEnv(): Environment { export function createBrowserEnv(): Environment {
const fetch = window.fetch;
if (!fetch) throw new Error('fetch - missing fetch implementation for browser environment');
const fetch = window['fetch'] || function() { const readFile = () => {
throw new Error('fetch - missing fetch implementation for browser environment') throw new Error('readFile - filesystem not available for browser environment');
} };
const readFile = function() {
throw new Error('readFile - filesystem not available for browser environment')
}
return { return {
Canvas: HTMLCanvasElement, Canvas: HTMLCanvasElement,
CanvasRenderingContext2D: CanvasRenderingContext2D, CanvasRenderingContext2D,
Image: HTMLImageElement, Image: HTMLImageElement,
ImageData: ImageData, ImageData,
Video: HTMLVideoElement, Video: HTMLVideoElement,
createCanvasElement: () => document.createElement('canvas'), createCanvasElement: () => document.createElement('canvas'),
createImageElement: () => document.createElement('img'), createImageElement: () => document.createElement('img'),
fetch, fetch,
readFile readFile,
} };
} }

View File

@ -1,30 +1,26 @@
import { FileSystem } from './types'; import { FileSystem } from './types';
export function createFileSystem(fs?: any): FileSystem { export function createFileSystem(fs?: any): FileSystem {
let requireFsError = '';
let requireFsError = ''
if (!fs) { if (!fs) {
try { try {
fs = require('fs') // eslint-disable-next-line global-require
fs = require('fs');
} catch (err) { } catch (err) {
requireFsError = err.toString() requireFsError = err.toString();
} }
} }
const readFile = fs const readFile = fs
? function(filePath: string) { ? (filePath: string) => new Promise<Buffer>((resolve, reject) => {
return new Promise<Buffer>((res, rej) => { fs.readFile(filePath, (err: any, buffer: Buffer) => (err ? reject(err) : resolve(buffer)));
fs.readFile(filePath, function(err: any, buffer: Buffer) { })
return err ? rej(err) : res(buffer) : () => {
}) throw new Error(`readFile - failed to require fs in nodejs environment with error: ${requireFsError}`);
}) };
}
: function() {
throw new Error(`readFile - failed to require fs in nodejs environment with error: ${requireFsError}`)
}
return { return {
readFile readFile,
} };
} }

View File

@ -1,40 +1,36 @@
/* eslint-disable max-classes-per-file */
import { createFileSystem } from './createFileSystem'; import { createFileSystem } from './createFileSystem';
import { Environment } from './types'; import { Environment } from './types';
export function createNodejsEnv(): Environment { export function createNodejsEnv(): Environment {
// eslint-disable-next-line dot-notation
const Canvas = global['Canvas'] || global.HTMLCanvasElement;
const Image = global.Image || global.HTMLImageElement;
const Canvas = global['Canvas'] || global['HTMLCanvasElement'] const createCanvasElement = () => {
const Image = global['Image'] || global['HTMLImageElement'] if (Canvas) return new Canvas();
throw new Error('createCanvasElement - missing Canvas implementation for nodejs environment');
};
const createCanvasElement = function() { const createImageElement = () => {
if (Canvas) { if (Image) return new Image();
return new Canvas() throw new Error('createImageElement - missing Image implementation for nodejs environment');
} };
throw new Error('createCanvasElement - missing Canvas implementation for nodejs environment')
}
const createImageElement = function() { const fetch = global.fetch;
if (Image) { // if (!fetch) throw new Error('fetch - missing fetch implementation for nodejs environment');
return new Image()
}
throw new Error('createImageElement - missing Image implementation for nodejs environment')
}
const fetch = global['fetch'] || function() { const fileSystem = createFileSystem();
throw new Error('fetch - missing fetch implementation for nodejs environment')
}
const fileSystem = createFileSystem()
return { return {
Canvas: Canvas || class {}, Canvas: Canvas || class {},
CanvasRenderingContext2D: global['CanvasRenderingContext2D'] || class {}, CanvasRenderingContext2D: global.CanvasRenderingContext2D || class {},
Image: Image || class {}, Image: Image || class {},
ImageData: global['ImageData'] || class {}, ImageData: global.ImageData || class {},
Video: global['HTMLVideoElement'] || class {}, Video: global.HTMLVideoElement || class {},
createCanvasElement, createCanvasElement,
createImageElement, createImageElement,
fetch, fetch,
...fileSystem ...fileSystem,
} };
} }

47
src/env/index.ts vendored
View File

@ -5,49 +5,46 @@ import { isBrowser } from './isBrowser';
import { isNodejs } from './isNodejs'; import { isNodejs } from './isNodejs';
import { Environment } from './types'; import { Environment } from './types';
let environment: Environment | null let environment: Environment | null;
function getEnv(): Environment { function getEnv(): Environment {
if (!environment) { if (!environment) {
throw new Error('getEnv - environment is not defined, check isNodejs() and isBrowser()') throw new Error('getEnv - environment is not defined, check isNodejs() and isBrowser()');
} }
return environment return environment;
} }
function setEnv(env: Environment) { function setEnv(env: Environment) {
environment = env environment = env;
} }
function initialize() { function initialize() {
// check for isBrowser() first to prevent electron renderer process // check for isBrowser() first to prevent electron renderer process
// to be initialized with wrong environment due to isNodejs() returning true // to be initialized with wrong environment due to isNodejs() returning true
if (isBrowser()) { if (isBrowser()) return setEnv(createBrowserEnv());
return setEnv(createBrowserEnv()) if (isNodejs()) return setEnv(createNodejsEnv());
} return null;
if (isNodejs()) {
return setEnv(createNodejsEnv())
}
} }
function monkeyPatch(env: Partial<Environment>) { function monkeyPatch(env: Partial<Environment>) {
if (!environment) { if (!environment) {
initialize() initialize();
} }
if (!environment) { if (!environment) {
throw new Error('monkeyPatch - environment is not defined, check isNodejs() and isBrowser()') throw new Error('monkeyPatch - environment is not defined, check isNodejs() and isBrowser()');
} }
const { Canvas = environment.Canvas, Image = environment.Image } = env const { Canvas = environment.Canvas, Image = environment.Image } = env;
environment.Canvas = Canvas environment.Canvas = Canvas;
environment.Image = Image environment.Image = Image;
environment.createCanvasElement = env.createCanvasElement || (() => new Canvas()) environment.createCanvasElement = env.createCanvasElement || (() => new Canvas());
environment.createImageElement = env.createImageElement || (() => new Image()) environment.createImageElement = env.createImageElement || (() => new Image());
environment.ImageData = env.ImageData || environment.ImageData environment.ImageData = env.ImageData || environment.ImageData;
environment.Video = env.Video || environment.Video environment.Video = env.Video || environment.Video;
environment.fetch = env.fetch || environment.fetch environment.fetch = env.fetch || environment.fetch;
environment.readFile = env.readFile || environment.readFile environment.readFile = env.readFile || environment.readFile;
} }
export const env = { export const env = {
@ -59,9 +56,9 @@ export const env = {
createNodejsEnv, createNodejsEnv,
monkeyPatch, monkeyPatch,
isBrowser, isBrowser,
isNodejs isNodejs,
} };
initialize() initialize();
export * from './types' export * from './types';

View File

@ -5,5 +5,5 @@ export function isBrowser(): boolean {
&& typeof HTMLCanvasElement !== 'undefined' && typeof HTMLCanvasElement !== 'undefined'
&& typeof HTMLVideoElement !== 'undefined' && typeof HTMLVideoElement !== 'undefined'
&& typeof ImageData !== 'undefined' && typeof ImageData !== 'undefined'
&& typeof CanvasRenderingContext2D !== 'undefined' && typeof CanvasRenderingContext2D !== 'undefined';
} }

2
src/env/isNodejs.ts vendored
View File

@ -4,5 +4,5 @@ export function isNodejs(): boolean {
&& typeof module !== 'undefined' && typeof module !== 'undefined'
// issues with gatsby.js: module.exports is undefined // issues with gatsby.js: module.exports is undefined
// && !!module.exports // && !!module.exports
&& typeof process !== 'undefined' && !!process.version && typeof process !== 'undefined' && !!process.version;
} }

2
src/env/types.ts vendored
View File

@ -1,4 +1,5 @@
export type FileSystem = { export type FileSystem = {
// eslint-disable-next-line no-unused-vars
readFile: (filePath: string) => Promise<Buffer> readFile: (filePath: string) => Promise<Buffer>
} }
@ -10,5 +11,6 @@ export type Environment = FileSystem & {
Video: typeof HTMLVideoElement Video: typeof HTMLVideoElement
createCanvasElement: () => HTMLCanvasElement createCanvasElement: () => HTMLCanvasElement
createImageElement: () => HTMLImageElement createImageElement: () => HTMLImageElement
// eslint-disable-next-line no-undef, no-unused-vars
fetch: (url: string, init?: RequestInit) => Promise<Response> fetch: (url: string, init?: RequestInit) => Promise<Response>
} }

View File

@ -7,46 +7,45 @@ import { FaceProcessor } from '../faceProcessor/FaceProcessor';
import { FaceExpressions } from './FaceExpressions'; import { FaceExpressions } from './FaceExpressions';
export class FaceExpressionNet extends FaceProcessor<FaceFeatureExtractorParams> { export class FaceExpressionNet extends FaceProcessor<FaceFeatureExtractorParams> {
constructor(faceFeatureExtractor: FaceFeatureExtractor = new FaceFeatureExtractor()) { constructor(faceFeatureExtractor: FaceFeatureExtractor = new FaceFeatureExtractor()) {
super('FaceExpressionNet', faceFeatureExtractor) super('FaceExpressionNet', faceFeatureExtractor);
} }
public forwardInput(input: NetInput | tf.Tensor4D): tf.Tensor2D { public forwardInput(input: NetInput | tf.Tensor4D): tf.Tensor2D {
return tf.tidy(() => tf.softmax(this.runNet(input))) return tf.tidy(() => tf.softmax(this.runNet(input)));
} }
public async forward(input: TNetInput): Promise<tf.Tensor2D> { public async forward(input: TNetInput): Promise<tf.Tensor2D> {
return this.forwardInput(await toNetInput(input)) return this.forwardInput(await toNetInput(input));
} }
public async predictExpressions(input: TNetInput) { public async predictExpressions(input: TNetInput) {
const netInput = await toNetInput(input) const netInput = await toNetInput(input);
const out = await this.forwardInput(netInput) const out = await this.forwardInput(netInput);
const probabilitesByBatch = await Promise.all(tf.unstack(out).map(async t => { const probabilitesByBatch = await Promise.all(tf.unstack(out).map(async (t) => {
const data = await t.data() const data = await t.data();
t.dispose() t.dispose();
return data return data;
})) }));
out.dispose() out.dispose();
const predictionsByBatch = probabilitesByBatch const predictionsByBatch = probabilitesByBatch
.map(probabilites => new FaceExpressions(probabilites as Float32Array)) .map((probabilites) => new FaceExpressions(probabilites as Float32Array));
return netInput.isBatchInput return netInput.isBatchInput
? predictionsByBatch ? predictionsByBatch
: predictionsByBatch[0] : predictionsByBatch[0];
} }
protected getDefaultModelName(): string { protected getDefaultModelName(): string {
return 'face_expression_model' return 'face_expression_model';
} }
protected getClassifierChannelsIn(): number { protected getClassifierChannelsIn(): number {
return 256 return 256;
} }
protected getClassifierChannelsOut(): number { protected getClassifierChannelsOut(): number {
return 7 return 7;
} }
} }

View File

@ -1,27 +1,33 @@
export const FACE_EXPRESSION_LABELS = ['neutral', 'happy', 'sad', 'angry', 'fearful', 'disgusted', 'surprised'] export const FACE_EXPRESSION_LABELS = ['neutral', 'happy', 'sad', 'angry', 'fearful', 'disgusted', 'surprised'];
export class FaceExpressions { export class FaceExpressions {
public neutral: number public neutral: number
public happy: number public happy: number
public sad: number public sad: number
public angry: number public angry: number
public fearful: number public fearful: number
public disgusted: number public disgusted: number
public surprised: number public surprised: number
constructor(probabilities: number[] | Float32Array) { constructor(probabilities: number[] | Float32Array) {
if (probabilities.length !== 7) { if (probabilities.length !== 7) {
throw new Error(`FaceExpressions.constructor - expected probabilities.length to be 7, have: ${probabilities.length}`) throw new Error(`FaceExpressions.constructor - expected probabilities.length to be 7, have: ${probabilities.length}`);
} }
FACE_EXPRESSION_LABELS.forEach((expression, idx) => { FACE_EXPRESSION_LABELS.forEach((expression, idx) => {
this[expression] = probabilities[idx] this[expression] = probabilities[idx];
}) });
} }
asSortedArray() { asSortedArray() {
return FACE_EXPRESSION_LABELS return FACE_EXPRESSION_LABELS
.map(expression => ({ expression, probability: this[expression] as number })) .map((expression) => ({ expression, probability: this[expression] as number }))
.sort((e0, e1) => e1.probability - e0.probability) .sort((e0, e1) => e1.probability - e0.probability);
} }
} }

View File

@ -9,47 +9,45 @@ import { extractParamsFromWeigthMap } from './extractParamsFromWeigthMap';
import { FaceFeatureExtractorParams, IFaceFeatureExtractor } from './types'; import { FaceFeatureExtractorParams, IFaceFeatureExtractor } from './types';
export class FaceFeatureExtractor extends NeuralNetwork<FaceFeatureExtractorParams> implements IFaceFeatureExtractor<FaceFeatureExtractorParams> { export class FaceFeatureExtractor extends NeuralNetwork<FaceFeatureExtractorParams> implements IFaceFeatureExtractor<FaceFeatureExtractorParams> {
constructor() { constructor() {
super('FaceFeatureExtractor') super('FaceFeatureExtractor');
} }
public forwardInput(input: NetInput): tf.Tensor4D { public forwardInput(input: NetInput): tf.Tensor4D {
const { params } = this;
const { params } = this
if (!params) { if (!params) {
throw new Error('FaceFeatureExtractor - load model before inference') throw new Error('FaceFeatureExtractor - load model before inference');
} }
return tf.tidy(() => { return tf.tidy(() => {
const batchTensor = tf.cast(input.toBatchTensor(112, true), 'float32'); const batchTensor = tf.cast(input.toBatchTensor(112, true), 'float32');
const meanRgb = [122.782, 117.001, 104.298] const meanRgb = [122.782, 117.001, 104.298];
const normalized = normalize(batchTensor, meanRgb).div(tf.scalar(255)) as tf.Tensor4D const normalized = normalize(batchTensor, meanRgb).div(tf.scalar(255)) as tf.Tensor4D;
let out = denseBlock4(normalized, params.dense0, true) let out = denseBlock4(normalized, params.dense0, true);
out = denseBlock4(out, params.dense1) out = denseBlock4(out, params.dense1);
out = denseBlock4(out, params.dense2) out = denseBlock4(out, params.dense2);
out = denseBlock4(out, params.dense3) out = denseBlock4(out, params.dense3);
out = tf.avgPool(out, [7, 7], [2, 2], 'valid') out = tf.avgPool(out, [7, 7], [2, 2], 'valid');
return out return out;
}) });
} }
public async forward(input: TNetInput): Promise<tf.Tensor4D> { public async forward(input: TNetInput): Promise<tf.Tensor4D> {
return this.forwardInput(await toNetInput(input)) return this.forwardInput(await toNetInput(input));
} }
protected getDefaultModelName(): string { protected getDefaultModelName(): string {
return 'face_feature_extractor_model' return 'face_feature_extractor_model';
} }
protected extractParamsFromWeigthMap(weightMap: tf.NamedTensorMap) { protected extractParamsFromWeigthMap(weightMap: tf.NamedTensorMap) {
return extractParamsFromWeigthMap(weightMap) return extractParamsFromWeigthMap(weightMap);
} }
protected extractParams(weights: Float32Array) { protected extractParams(weights: Float32Array) {
return extractParams(weights) return extractParams(weights);
} }
} }

View File

@ -9,46 +9,44 @@ import { extractParamsTiny } from './extractParamsTiny';
import { IFaceFeatureExtractor, TinyFaceFeatureExtractorParams } from './types'; import { IFaceFeatureExtractor, TinyFaceFeatureExtractorParams } from './types';
export class TinyFaceFeatureExtractor extends NeuralNetwork<TinyFaceFeatureExtractorParams> implements IFaceFeatureExtractor<TinyFaceFeatureExtractorParams> { export class TinyFaceFeatureExtractor extends NeuralNetwork<TinyFaceFeatureExtractorParams> implements IFaceFeatureExtractor<TinyFaceFeatureExtractorParams> {
constructor() { constructor() {
super('TinyFaceFeatureExtractor') super('TinyFaceFeatureExtractor');
} }
public forwardInput(input: NetInput): tf.Tensor4D { public forwardInput(input: NetInput): tf.Tensor4D {
const { params } = this;
const { params } = this
if (!params) { if (!params) {
throw new Error('TinyFaceFeatureExtractor - load model before inference') throw new Error('TinyFaceFeatureExtractor - load model before inference');
} }
return tf.tidy(() => { return tf.tidy(() => {
const batchTensor = tf.cast(input.toBatchTensor(112, true), 'float32'); const batchTensor = tf.cast(input.toBatchTensor(112, true), 'float32');
const meanRgb = [122.782, 117.001, 104.298] const meanRgb = [122.782, 117.001, 104.298];
const normalized = normalize(batchTensor, meanRgb).div(tf.scalar(255)) as tf.Tensor4D const normalized = normalize(batchTensor, meanRgb).div(tf.scalar(255)) as tf.Tensor4D;
let out = denseBlock3(normalized, params.dense0, true) let out = denseBlock3(normalized, params.dense0, true);
out = denseBlock3(out, params.dense1) out = denseBlock3(out, params.dense1);
out = denseBlock3(out, params.dense2) out = denseBlock3(out, params.dense2);
out = tf.avgPool(out, [14, 14], [2, 2], 'valid') out = tf.avgPool(out, [14, 14], [2, 2], 'valid');
return out return out;
}) });
} }
public async forward(input: TNetInput): Promise<tf.Tensor4D> { public async forward(input: TNetInput): Promise<tf.Tensor4D> {
return this.forwardInput(await toNetInput(input)) return this.forwardInput(await toNetInput(input));
} }
protected getDefaultModelName(): string { protected getDefaultModelName(): string {
return 'face_feature_extractor_tiny_model' return 'face_feature_extractor_tiny_model';
} }
protected extractParamsFromWeigthMap(weightMap: tf.NamedTensorMap) { protected extractParamsFromWeigthMap(weightMap: tf.NamedTensorMap) {
return extractParamsFromWeigthMapTiny(weightMap) return extractParamsFromWeigthMapTiny(weightMap);
} }
protected extractParams(weights: Float32Array) { protected extractParams(weights: Float32Array) {
return extractParamsTiny(weights) return extractParamsTiny(weights);
} }
} }

View File

@ -7,49 +7,49 @@ import { DenseBlock3Params, DenseBlock4Params } from './types';
export function denseBlock3( export function denseBlock3(
x: tf.Tensor4D, x: tf.Tensor4D,
denseBlockParams: DenseBlock3Params, denseBlockParams: DenseBlock3Params,
isFirstLayer: boolean = false isFirstLayer: boolean = false,
): tf.Tensor4D { ): tf.Tensor4D {
return tf.tidy(() => { return tf.tidy(() => {
const out1 = tf.relu( const out1 = tf.relu(
isFirstLayer isFirstLayer
? tf.add( ? tf.add(
tf.conv2d(x, (denseBlockParams.conv0 as ConvParams).filters, [2, 2], 'same'), tf.conv2d(x, (denseBlockParams.conv0 as ConvParams).filters, [2, 2], 'same'),
denseBlockParams.conv0.bias denseBlockParams.conv0.bias,
) )
: depthwiseSeparableConv(x, denseBlockParams.conv0 as SeparableConvParams, [2, 2]) : depthwiseSeparableConv(x, denseBlockParams.conv0 as SeparableConvParams, [2, 2]),
) as tf.Tensor4D ) as tf.Tensor4D;
const out2 = depthwiseSeparableConv(out1, denseBlockParams.conv1, [1, 1]) const out2 = depthwiseSeparableConv(out1, denseBlockParams.conv1, [1, 1]);
const in3 = tf.relu(tf.add(out1, out2)) as tf.Tensor4D const in3 = tf.relu(tf.add(out1, out2)) as tf.Tensor4D;
const out3 = depthwiseSeparableConv(in3, denseBlockParams.conv2, [1, 1]) const out3 = depthwiseSeparableConv(in3, denseBlockParams.conv2, [1, 1]);
return tf.relu(tf.add(out1, tf.add(out2, out3))) as tf.Tensor4D return tf.relu(tf.add(out1, tf.add(out2, out3))) as tf.Tensor4D;
}) });
} }
export function denseBlock4( export function denseBlock4(
x: tf.Tensor4D, x: tf.Tensor4D,
denseBlockParams: DenseBlock4Params, denseBlockParams: DenseBlock4Params,
isFirstLayer: boolean = false, isFirstLayer: boolean = false,
isScaleDown: boolean = true isScaleDown: boolean = true,
): tf.Tensor4D { ): tf.Tensor4D {
return tf.tidy(() => { return tf.tidy(() => {
const out1 = tf.relu( const out1 = tf.relu(
isFirstLayer isFirstLayer
? tf.add( ? tf.add(
tf.conv2d(x, (denseBlockParams.conv0 as ConvParams).filters, isScaleDown ? [2, 2] : [1, 1], 'same'), tf.conv2d(x, (denseBlockParams.conv0 as ConvParams).filters, isScaleDown ? [2, 2] : [1, 1], 'same'),
denseBlockParams.conv0.bias denseBlockParams.conv0.bias,
) )
: depthwiseSeparableConv(x, denseBlockParams.conv0 as SeparableConvParams, isScaleDown ? [2, 2] : [1, 1]) : depthwiseSeparableConv(x, denseBlockParams.conv0 as SeparableConvParams, isScaleDown ? [2, 2] : [1, 1]),
) as tf.Tensor4D ) as tf.Tensor4D;
const out2 = depthwiseSeparableConv(out1, denseBlockParams.conv1, [1, 1]) const out2 = depthwiseSeparableConv(out1, denseBlockParams.conv1, [1, 1]);
const in3 = tf.relu(tf.add(out1, out2)) as tf.Tensor4D const in3 = tf.relu(tf.add(out1, out2)) as tf.Tensor4D;
const out3 = depthwiseSeparableConv(in3, denseBlockParams.conv2, [1, 1]) const out3 = depthwiseSeparableConv(in3, denseBlockParams.conv2, [1, 1]);
const in4 = tf.relu(tf.add(out1, tf.add(out2, out3))) as tf.Tensor4D const in4 = tf.relu(tf.add(out1, tf.add(out2, out3))) as tf.Tensor4D;
const out4 = depthwiseSeparableConv(in4, denseBlockParams.conv3, [1, 1]) const out4 = depthwiseSeparableConv(in4, denseBlockParams.conv3, [1, 1]);
return tf.relu(tf.add(out1, tf.add(out2, tf.add(out3, out4)))) as tf.Tensor4D return tf.relu(tf.add(out1, tf.add(out2, tf.add(out3, out4)))) as tf.Tensor4D;
}) });
} }

View File

@ -2,31 +2,31 @@ import { extractWeightsFactory, ParamMapping } from '../common/index';
import { extractorsFactory } from './extractorsFactory'; import { extractorsFactory } from './extractorsFactory';
import { FaceFeatureExtractorParams } from './types'; import { FaceFeatureExtractorParams } from './types';
export function extractParams(weights: Float32Array): { params: FaceFeatureExtractorParams, paramMappings: ParamMapping[] } { export function extractParams(weights: Float32Array): { params: FaceFeatureExtractorParams, paramMappings: ParamMapping[] } {
const paramMappings: ParamMapping[] = [];
const paramMappings: ParamMapping[] = []
const { const {
extractWeights, extractWeights,
getRemainingWeights getRemainingWeights,
} = extractWeightsFactory(weights) } = extractWeightsFactory(weights);
const { const {
extractDenseBlock4Params extractDenseBlock4Params,
} = extractorsFactory(extractWeights, paramMappings) } = extractorsFactory(extractWeights, paramMappings);
const dense0 = extractDenseBlock4Params(3, 32, 'dense0', true) const dense0 = extractDenseBlock4Params(3, 32, 'dense0', true);
const dense1 = extractDenseBlock4Params(32, 64, 'dense1') const dense1 = extractDenseBlock4Params(32, 64, 'dense1');
const dense2 = extractDenseBlock4Params(64, 128, 'dense2') const dense2 = extractDenseBlock4Params(64, 128, 'dense2');
const dense3 = extractDenseBlock4Params(128, 256, 'dense3') const dense3 = extractDenseBlock4Params(128, 256, 'dense3');
if (getRemainingWeights().length !== 0) { if (getRemainingWeights().length !== 0) {
throw new Error(`weights remaing after extract: ${getRemainingWeights().length}`) throw new Error(`weights remaing after extract: ${getRemainingWeights().length}`);
} }
return { return {
paramMappings, paramMappings,
params: { dense0, dense1, dense2, dense3 } params: {
} dense0, dense1, dense2, dense3,
},
};
} }

View File

@ -5,23 +5,22 @@ import { loadParamsFactory } from './loadParamsFactory';
import { FaceFeatureExtractorParams } from './types'; import { FaceFeatureExtractorParams } from './types';
export function extractParamsFromWeigthMap( export function extractParamsFromWeigthMap(
weightMap: tf.NamedTensorMap weightMap: tf.NamedTensorMap,
): { params: FaceFeatureExtractorParams, paramMappings: ParamMapping[] } { ): { params: FaceFeatureExtractorParams, paramMappings: ParamMapping[] } {
const paramMappings: ParamMapping[] = [];
const paramMappings: ParamMapping[] = []
const { const {
extractDenseBlock4Params extractDenseBlock4Params,
} = loadParamsFactory(weightMap, paramMappings) } = loadParamsFactory(weightMap, paramMappings);
const params = { const params = {
dense0: extractDenseBlock4Params('dense0', true), dense0: extractDenseBlock4Params('dense0', true),
dense1: extractDenseBlock4Params('dense1'), dense1: extractDenseBlock4Params('dense1'),
dense2: extractDenseBlock4Params('dense2'), dense2: extractDenseBlock4Params('dense2'),
dense3: extractDenseBlock4Params('dense3') dense3: extractDenseBlock4Params('dense3'),
} };
disposeUnusedWeightTensors(weightMap, paramMappings) disposeUnusedWeightTensors(weightMap, paramMappings);
return { params, paramMappings } return { params, paramMappings };
} }

View File

@ -5,22 +5,21 @@ import { loadParamsFactory } from './loadParamsFactory';
import { TinyFaceFeatureExtractorParams } from './types'; import { TinyFaceFeatureExtractorParams } from './types';
export function extractParamsFromWeigthMapTiny( export function extractParamsFromWeigthMapTiny(
weightMap: tf.NamedTensorMap weightMap: tf.NamedTensorMap,
): { params: TinyFaceFeatureExtractorParams, paramMappings: ParamMapping[] } { ): { params: TinyFaceFeatureExtractorParams, paramMappings: ParamMapping[] } {
const paramMappings: ParamMapping[] = [];
const paramMappings: ParamMapping[] = []
const { const {
extractDenseBlock3Params extractDenseBlock3Params,
} = loadParamsFactory(weightMap, paramMappings) } = loadParamsFactory(weightMap, paramMappings);
const params = { const params = {
dense0: extractDenseBlock3Params('dense0', true), dense0: extractDenseBlock3Params('dense0', true),
dense1: extractDenseBlock3Params('dense1'), dense1: extractDenseBlock3Params('dense1'),
dense2: extractDenseBlock3Params('dense2') dense2: extractDenseBlock3Params('dense2'),
} };
disposeUnusedWeightTensors(weightMap, paramMappings) disposeUnusedWeightTensors(weightMap, paramMappings);
return { params, paramMappings } return { params, paramMappings };
} }

View File

@ -2,31 +2,28 @@ import { extractWeightsFactory, ParamMapping } from '../common/index';
import { extractorsFactory } from './extractorsFactory'; import { extractorsFactory } from './extractorsFactory';
import { TinyFaceFeatureExtractorParams } from './types'; import { TinyFaceFeatureExtractorParams } from './types';
export function extractParamsTiny(weights: Float32Array): { params: TinyFaceFeatureExtractorParams, paramMappings: ParamMapping[] } { export function extractParamsTiny(weights: Float32Array): { params: TinyFaceFeatureExtractorParams, paramMappings: ParamMapping[] } {
const paramMappings: ParamMapping[] = [];
const paramMappings: ParamMapping[] = []
const { const {
extractWeights, extractWeights,
getRemainingWeights getRemainingWeights,
} = extractWeightsFactory(weights) } = extractWeightsFactory(weights);
const { const {
extractDenseBlock3Params extractDenseBlock3Params,
} = extractorsFactory(extractWeights, paramMappings) } = extractorsFactory(extractWeights, paramMappings);
const dense0 = extractDenseBlock3Params(3, 32, 'dense0', true) const dense0 = extractDenseBlock3Params(3, 32, 'dense0', true);
const dense1 = extractDenseBlock3Params(32, 64, 'dense1') const dense1 = extractDenseBlock3Params(32, 64, 'dense1');
const dense2 = extractDenseBlock3Params(64, 128, 'dense2') const dense2 = extractDenseBlock3Params(64, 128, 'dense2');
if (getRemainingWeights().length !== 0) { if (getRemainingWeights().length !== 0) {
throw new Error(`weights remaing after extract: ${getRemainingWeights().length}`) throw new Error(`weights remaing after extract: ${getRemainingWeights().length}`);
} }
return { return {
paramMappings, paramMappings,
params: { dense0, dense1, dense2 } params: { dense0, dense1, dense2 },
} };
} }

View File

@ -7,32 +7,30 @@ import {
import { DenseBlock3Params, DenseBlock4Params } from './types'; import { DenseBlock3Params, DenseBlock4Params } from './types';
export function extractorsFactory(extractWeights: ExtractWeightsFunction, paramMappings: ParamMapping[]) { export function extractorsFactory(extractWeights: ExtractWeightsFunction, paramMappings: ParamMapping[]) {
const extractConvParams = extractConvParamsFactory(extractWeights, paramMappings);
const extractConvParams = extractConvParamsFactory(extractWeights, paramMappings) const extractSeparableConvParams = extractSeparableConvParamsFactory(extractWeights, paramMappings);
const extractSeparableConvParams = extractSeparableConvParamsFactory(extractWeights, paramMappings)
function extractDenseBlock3Params(channelsIn: number, channelsOut: number, mappedPrefix: string, isFirstLayer: boolean = false): DenseBlock3Params { function extractDenseBlock3Params(channelsIn: number, channelsOut: number, mappedPrefix: string, isFirstLayer: boolean = false): DenseBlock3Params {
const conv0 = isFirstLayer const conv0 = isFirstLayer
? extractConvParams(channelsIn, channelsOut, 3, `${mappedPrefix}/conv0`) ? extractConvParams(channelsIn, channelsOut, 3, `${mappedPrefix}/conv0`)
: extractSeparableConvParams(channelsIn, channelsOut, `${mappedPrefix}/conv0`) : extractSeparableConvParams(channelsIn, channelsOut, `${mappedPrefix}/conv0`);
const conv1 = extractSeparableConvParams(channelsOut, channelsOut, `${mappedPrefix}/conv1`) const conv1 = extractSeparableConvParams(channelsOut, channelsOut, `${mappedPrefix}/conv1`);
const conv2 = extractSeparableConvParams(channelsOut, channelsOut, `${mappedPrefix}/conv2`) const conv2 = extractSeparableConvParams(channelsOut, channelsOut, `${mappedPrefix}/conv2`);
return { conv0, conv1, conv2 } return { conv0, conv1, conv2 };
} }
function extractDenseBlock4Params(channelsIn: number, channelsOut: number, mappedPrefix: string, isFirstLayer: boolean = false): DenseBlock4Params { function extractDenseBlock4Params(channelsIn: number, channelsOut: number, mappedPrefix: string, isFirstLayer: boolean = false): DenseBlock4Params {
const { conv0, conv1, conv2 } = extractDenseBlock3Params(channelsIn, channelsOut, mappedPrefix, isFirstLayer);
const conv3 = extractSeparableConvParams(channelsOut, channelsOut, `${mappedPrefix}/conv3`);
const { conv0, conv1, conv2 } = extractDenseBlock3Params(channelsIn, channelsOut, mappedPrefix, isFirstLayer) return {
const conv3 = extractSeparableConvParams(channelsOut, channelsOut, `${mappedPrefix}/conv3`) conv0, conv1, conv2, conv3,
};
return { conv0, conv1, conv2, conv3 }
} }
return { return {
extractDenseBlock3Params, extractDenseBlock3Params,
extractDenseBlock4Params extractDenseBlock4Params,
} };
} }

Some files were not shown because too many files have changed in this diff Show More