full re-lint and typings generation
parent
2a7bf4080b
commit
0692fbf532
|
@ -0,0 +1,47 @@
|
|||
{
|
||||
"globals": {},
|
||||
"env": {
|
||||
"browser": true,
|
||||
"commonjs": true,
|
||||
"es6": true,
|
||||
"node": true,
|
||||
"es2020": true
|
||||
},
|
||||
"parser": "@typescript-eslint/parser",
|
||||
"parserOptions": { "ecmaVersion": 2020 },
|
||||
"plugins": ["@typescript-eslint"],
|
||||
"extends": [
|
||||
"eslint:recommended",
|
||||
"plugin:import/errors",
|
||||
"plugin:import/warnings",
|
||||
"plugin:import/typescript",
|
||||
"plugin:node/recommended",
|
||||
"plugin:promise/recommended",
|
||||
"plugin:json/recommended-with-comments",
|
||||
"airbnb-base"
|
||||
],
|
||||
"ignorePatterns": [ "node_modules", "types" ],
|
||||
"settings": {
|
||||
"import/resolver": {
|
||||
"node": {
|
||||
"extensions": [".js", ".ts"]
|
||||
}
|
||||
}
|
||||
},
|
||||
"rules": {
|
||||
"max-len": [1, 275, 3],
|
||||
"no-plusplus": "off",
|
||||
"import/prefer-default-export": "off",
|
||||
"node/no-unsupported-features/es-syntax": "off",
|
||||
"import/no-cycle": "off",
|
||||
"import/extensions": "off",
|
||||
"node/no-missing-import": "off",
|
||||
"no-underscore-dangle": "off",
|
||||
"class-methods-use-this": "off",
|
||||
"camelcase": "off",
|
||||
"no-await-in-loop": "off",
|
||||
"no-continue": "off",
|
||||
"no-param-reassign": "off",
|
||||
"prefer-destructuring": "off"
|
||||
}
|
||||
}
|
3
build.js
3
build.js
|
@ -21,7 +21,8 @@ const tsconfig = {
|
|||
noEmitOnError: false,
|
||||
target: ts.ScriptTarget.ES2018,
|
||||
module: ts.ModuleKind.ES2020,
|
||||
outFile: "dist/face-api.d.ts",
|
||||
// outFile: "dist/face-api.d.ts",
|
||||
outDir: "types/",
|
||||
declaration: true,
|
||||
emitDeclarationOnly: true,
|
||||
emitDecoratorMetadata: true,
|
||||
|
|
File diff suppressed because it is too large
Load Diff
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because it is too large
Load Diff
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because it is too large
Load Diff
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because it is too large
Load Diff
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because it is too large
Load Diff
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because it is too large
Load Diff
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
|
@ -20799,7 +20799,7 @@
|
|||
]
|
||||
},
|
||||
"src/tfjs/tf-browser.js": {
|
||||
"bytes": 1784,
|
||||
"bytes": 1888,
|
||||
"imports": [
|
||||
{
|
||||
"path": "node_modules/@tensorflow/tfjs/dist/index.js"
|
||||
|
@ -20814,7 +20814,7 @@
|
|||
"dist/tfjs.esm.js.map": {
|
||||
"imports": [],
|
||||
"inputs": {},
|
||||
"bytes": 1070617
|
||||
"bytes": 1063970
|
||||
},
|
||||
"dist/tfjs.esm.js": {
|
||||
"imports": [],
|
||||
|
@ -23195,7 +23195,7 @@
|
|||
"bytesInOutput": 29890
|
||||
},
|
||||
"node_modules/@tensorflow/tfjs-layers/dist/layers/convolutional_recurrent.js": {
|
||||
"bytesInOutput": 9235
|
||||
"bytesInOutput": 9196
|
||||
},
|
||||
"node_modules/@tensorflow/tfjs-layers/dist/layers/core.js": {
|
||||
"bytesInOutput": 9961
|
||||
|
@ -24070,11 +24070,8 @@
|
|||
"node_modules/@tensorflow/tfjs-backend-webgl/dist/version.js": {
|
||||
"bytesInOutput": 22
|
||||
},
|
||||
"node_modules/@tensorflow/tfjs-backend-webgl/dist/webgl.js": {
|
||||
"bytesInOutput": 67
|
||||
},
|
||||
"node_modules/@tensorflow/tfjs-backend-webgl/dist/base.js": {
|
||||
"bytesInOutput": 113
|
||||
"bytesInOutput": 85
|
||||
},
|
||||
"node_modules/@tensorflow/tfjs-backend-webgl/dist/index.js": {
|
||||
"bytesInOutput": 0
|
||||
|
@ -25067,7 +25064,7 @@
|
|||
"bytesInOutput": 0
|
||||
}
|
||||
},
|
||||
"bytes": 1572551
|
||||
"bytes": 1572418
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
File diff suppressed because it is too large
Load Diff
17
package.json
17
package.json
|
@ -5,13 +5,14 @@
|
|||
"main": "dist/face-api.node.js",
|
||||
"module": "dist/face-api.esm.js",
|
||||
"browser": "dist/face-api.esm.js",
|
||||
"typings": "dist/face-api.d.ts",
|
||||
"types": "types/index.d.ts",
|
||||
"engines": {
|
||||
"node": ">=12.0.0"
|
||||
},
|
||||
"scripts": {
|
||||
"start": "node --trace-warnings example/node.js",
|
||||
"build": "rimraf dist/* && node ./build.js"
|
||||
"start": "node --trace-warnings example/node-singleprocess.js",
|
||||
"build": "rimraf dist/* && node ./build.js",
|
||||
"lint": "eslint src/**/*"
|
||||
},
|
||||
"keywords": [
|
||||
"tensorflow",
|
||||
|
@ -44,7 +45,15 @@
|
|||
"@tensorflow/tfjs-node": "^2.8.1",
|
||||
"@tensorflow/tfjs-node-gpu": "^2.8.1",
|
||||
"@types/node": "^14.14.14",
|
||||
"esbuild": "^0.8.24",
|
||||
"@typescript-eslint/eslint-plugin": "^4.11.0",
|
||||
"@typescript-eslint/parser": "^4.11.0",
|
||||
"esbuild": "^0.8.26",
|
||||
"eslint": "^7.16.0",
|
||||
"eslint-config-airbnb-base": "^14.2.1",
|
||||
"eslint-plugin-import": "^2.22.1",
|
||||
"eslint-plugin-json": "^2.1.2",
|
||||
"eslint-plugin-node": "^11.1.0",
|
||||
"eslint-plugin-promise": "^4.2.1",
|
||||
"rimraf": "^3.0.2",
|
||||
"tslib": "^2.0.3",
|
||||
"typescript": "^4.1.3"
|
||||
|
|
|
@ -5,120 +5,118 @@ import { seperateWeightMaps } from '../faceProcessor/util';
|
|||
import { TinyXception } from '../xception/TinyXception';
|
||||
import { extractParams } from './extractParams';
|
||||
import { extractParamsFromWeigthMap } from './extractParamsFromWeigthMap';
|
||||
import { AgeAndGenderPrediction, Gender, NetOutput, NetParams } from './types';
|
||||
import {
|
||||
AgeAndGenderPrediction, Gender, NetOutput, NetParams,
|
||||
} from './types';
|
||||
import { NeuralNetwork } from '../NeuralNetwork';
|
||||
import { NetInput, TNetInput, toNetInput } from '../dom/index';
|
||||
|
||||
export class AgeGenderNet extends NeuralNetwork<NetParams> {
|
||||
|
||||
private _faceFeatureExtractor: TinyXception
|
||||
|
||||
constructor(faceFeatureExtractor: TinyXception = new TinyXception(2)) {
|
||||
super('AgeGenderNet')
|
||||
this._faceFeatureExtractor = faceFeatureExtractor
|
||||
super('AgeGenderNet');
|
||||
this._faceFeatureExtractor = faceFeatureExtractor;
|
||||
}
|
||||
|
||||
public get faceFeatureExtractor(): TinyXception {
|
||||
return this._faceFeatureExtractor
|
||||
return this._faceFeatureExtractor;
|
||||
}
|
||||
|
||||
public runNet(input: NetInput | tf.Tensor4D): NetOutput {
|
||||
|
||||
const { params } = this
|
||||
const { params } = this;
|
||||
|
||||
if (!params) {
|
||||
throw new Error(`${this._name} - load model before inference`)
|
||||
throw new Error(`${this._name} - load model before inference`);
|
||||
}
|
||||
|
||||
return tf.tidy(() => {
|
||||
const bottleneckFeatures = input instanceof NetInput
|
||||
? this.faceFeatureExtractor.forwardInput(input)
|
||||
: input
|
||||
: input;
|
||||
|
||||
const pooled = tf.avgPool(bottleneckFeatures, [7, 7], [2, 2], 'valid').as2D(bottleneckFeatures.shape[0], -1)
|
||||
const age = fullyConnectedLayer(pooled, params.fc.age).as1D()
|
||||
const gender = fullyConnectedLayer(pooled, params.fc.gender)
|
||||
return { age, gender }
|
||||
})
|
||||
const pooled = tf.avgPool(bottleneckFeatures, [7, 7], [2, 2], 'valid').as2D(bottleneckFeatures.shape[0], -1);
|
||||
const age = fullyConnectedLayer(pooled, params.fc.age).as1D();
|
||||
const gender = fullyConnectedLayer(pooled, params.fc.gender);
|
||||
return { age, gender };
|
||||
});
|
||||
}
|
||||
|
||||
public forwardInput(input: NetInput | tf.Tensor4D): NetOutput {
|
||||
return tf.tidy(() => {
|
||||
const { age, gender } = this.runNet(input)
|
||||
return { age, gender: tf.softmax(gender) }
|
||||
})
|
||||
const { age, gender } = this.runNet(input);
|
||||
return { age, gender: tf.softmax(gender) };
|
||||
});
|
||||
}
|
||||
|
||||
public async forward(input: TNetInput): Promise<NetOutput> {
|
||||
return this.forwardInput(await toNetInput(input))
|
||||
return this.forwardInput(await toNetInput(input));
|
||||
}
|
||||
|
||||
public async predictAgeAndGender(input: TNetInput): Promise<AgeAndGenderPrediction | AgeAndGenderPrediction[]> {
|
||||
const netInput = await toNetInput(input)
|
||||
const out = await this.forwardInput(netInput)
|
||||
const netInput = await toNetInput(input);
|
||||
const out = await this.forwardInput(netInput);
|
||||
|
||||
const ages = tf.unstack(out.age)
|
||||
const genders = tf.unstack(out.gender)
|
||||
const ages = tf.unstack(out.age);
|
||||
const genders = tf.unstack(out.gender);
|
||||
const ageAndGenderTensors = ages.map((ageTensor, i) => ({
|
||||
ageTensor,
|
||||
genderTensor: genders[i]
|
||||
}))
|
||||
genderTensor: genders[i],
|
||||
}));
|
||||
|
||||
const predictionsByBatch = await Promise.all(
|
||||
ageAndGenderTensors.map(async ({ ageTensor, genderTensor }) => {
|
||||
const age = (await ageTensor.data())[0]
|
||||
const probMale = (await genderTensor.data())[0]
|
||||
const isMale = probMale > 0.5
|
||||
const gender = isMale ? Gender.MALE : Gender.FEMALE
|
||||
const genderProbability = isMale ? probMale : (1 - probMale)
|
||||
const age = (await ageTensor.data())[0];
|
||||
const probMale = (await genderTensor.data())[0];
|
||||
const isMale = probMale > 0.5;
|
||||
const gender = isMale ? Gender.MALE : Gender.FEMALE;
|
||||
const genderProbability = isMale ? probMale : (1 - probMale);
|
||||
|
||||
ageTensor.dispose()
|
||||
genderTensor.dispose()
|
||||
return { age, gender, genderProbability }
|
||||
})
|
||||
)
|
||||
out.age.dispose()
|
||||
out.gender.dispose()
|
||||
ageTensor.dispose();
|
||||
genderTensor.dispose();
|
||||
return { age, gender, genderProbability };
|
||||
}),
|
||||
);
|
||||
out.age.dispose();
|
||||
out.gender.dispose();
|
||||
|
||||
return netInput.isBatchInput ? predictionsByBatch as AgeAndGenderPrediction[] : predictionsByBatch[0] as AgeAndGenderPrediction
|
||||
return netInput.isBatchInput ? predictionsByBatch as AgeAndGenderPrediction[] : predictionsByBatch[0] as AgeAndGenderPrediction;
|
||||
}
|
||||
|
||||
protected getDefaultModelName(): string {
|
||||
return 'age_gender_model'
|
||||
return 'age_gender_model';
|
||||
}
|
||||
|
||||
public dispose(throwOnRedispose: boolean = true) {
|
||||
this.faceFeatureExtractor.dispose(throwOnRedispose)
|
||||
super.dispose(throwOnRedispose)
|
||||
this.faceFeatureExtractor.dispose(throwOnRedispose);
|
||||
super.dispose(throwOnRedispose);
|
||||
}
|
||||
|
||||
public loadClassifierParams(weights: Float32Array) {
|
||||
const { params, paramMappings } = this.extractClassifierParams(weights)
|
||||
this._params = params
|
||||
this._paramMappings = paramMappings
|
||||
const { params, paramMappings } = this.extractClassifierParams(weights);
|
||||
this._params = params;
|
||||
this._paramMappings = paramMappings;
|
||||
}
|
||||
|
||||
public extractClassifierParams(weights: Float32Array) {
|
||||
return extractParams(weights)
|
||||
return extractParams(weights);
|
||||
}
|
||||
|
||||
protected extractParamsFromWeigthMap(weightMap: tf.NamedTensorMap) {
|
||||
const { featureExtractorMap, classifierMap } = seperateWeightMaps(weightMap);
|
||||
|
||||
const { featureExtractorMap, classifierMap } = seperateWeightMaps(weightMap)
|
||||
this.faceFeatureExtractor.loadFromWeightMap(featureExtractorMap);
|
||||
|
||||
this.faceFeatureExtractor.loadFromWeightMap(featureExtractorMap)
|
||||
|
||||
return extractParamsFromWeigthMap(classifierMap)
|
||||
return extractParamsFromWeigthMap(classifierMap);
|
||||
}
|
||||
|
||||
protected extractParams(weights: Float32Array) {
|
||||
const classifierWeightSize = (512 * 1 + 1) + (512 * 2 + 2);
|
||||
|
||||
const classifierWeightSize = (512 * 1 + 1) + (512 * 2 + 2)
|
||||
const featureExtractorWeights = weights.slice(0, weights.length - classifierWeightSize);
|
||||
const classifierWeights = weights.slice(weights.length - classifierWeightSize);
|
||||
|
||||
const featureExtractorWeights = weights.slice(0, weights.length - classifierWeightSize)
|
||||
const classifierWeights = weights.slice(weights.length - classifierWeightSize)
|
||||
|
||||
this.faceFeatureExtractor.extractWeights(featureExtractorWeights)
|
||||
return this.extractClassifierParams(classifierWeights)
|
||||
this.faceFeatureExtractor.extractWeights(featureExtractorWeights);
|
||||
return this.extractClassifierParams(classifierWeights);
|
||||
}
|
||||
}
|
|
@ -2,25 +2,24 @@ import { extractFCParamsFactory, extractWeightsFactory, ParamMapping } from '../
|
|||
import { NetParams } from './types';
|
||||
|
||||
export function extractParams(weights: Float32Array): { params: NetParams, paramMappings: ParamMapping[] } {
|
||||
|
||||
const paramMappings: ParamMapping[] = []
|
||||
const paramMappings: ParamMapping[] = [];
|
||||
|
||||
const {
|
||||
extractWeights,
|
||||
getRemainingWeights
|
||||
} = extractWeightsFactory(weights)
|
||||
getRemainingWeights,
|
||||
} = extractWeightsFactory(weights);
|
||||
|
||||
const extractFCParams = extractFCParamsFactory(extractWeights, paramMappings)
|
||||
const extractFCParams = extractFCParamsFactory(extractWeights, paramMappings);
|
||||
|
||||
const age = extractFCParams(512, 1, 'fc/age')
|
||||
const gender = extractFCParams(512, 2, 'fc/gender')
|
||||
const age = extractFCParams(512, 1, 'fc/age');
|
||||
const gender = extractFCParams(512, 2, 'fc/gender');
|
||||
|
||||
if (getRemainingWeights().length !== 0) {
|
||||
throw new Error(`weights remaing after extract: ${getRemainingWeights().length}`)
|
||||
throw new Error(`weights remaing after extract: ${getRemainingWeights().length}`);
|
||||
}
|
||||
|
||||
return {
|
||||
paramMappings,
|
||||
params: { fc: { age, gender } }
|
||||
}
|
||||
params: { fc: { age, gender } },
|
||||
};
|
||||
}
|
|
@ -1,30 +1,31 @@
|
|||
import * as tf from '../../dist/tfjs.esm.js';
|
||||
|
||||
import { disposeUnusedWeightTensors, extractWeightEntryFactory, FCParams, ParamMapping } from '../common/index';
|
||||
import {
|
||||
disposeUnusedWeightTensors, extractWeightEntryFactory, FCParams, ParamMapping,
|
||||
} from '../common/index';
|
||||
import { NetParams } from './types';
|
||||
|
||||
export function extractParamsFromWeigthMap(
|
||||
weightMap: tf.NamedTensorMap
|
||||
weightMap: tf.NamedTensorMap,
|
||||
): { params: NetParams, paramMappings: ParamMapping[] } {
|
||||
const paramMappings: ParamMapping[] = [];
|
||||
|
||||
const paramMappings: ParamMapping[] = []
|
||||
|
||||
const extractWeightEntry = extractWeightEntryFactory(weightMap, paramMappings)
|
||||
const extractWeightEntry = extractWeightEntryFactory(weightMap, paramMappings);
|
||||
|
||||
function extractFcParams(prefix: string): FCParams {
|
||||
const weights = extractWeightEntry<tf.Tensor2D>(`${prefix}/weights`, 2)
|
||||
const bias = extractWeightEntry<tf.Tensor1D>(`${prefix}/bias`, 1)
|
||||
return { weights, bias }
|
||||
const weights = extractWeightEntry(`${prefix}/weights`, 2);
|
||||
const bias = extractWeightEntry(`${prefix}/bias`, 1);
|
||||
return { weights, bias };
|
||||
}
|
||||
|
||||
const params = {
|
||||
fc: {
|
||||
age: extractFcParams('fc/age'),
|
||||
gender: extractFcParams('fc/gender')
|
||||
}
|
||||
}
|
||||
gender: extractFcParams('fc/gender'),
|
||||
},
|
||||
};
|
||||
|
||||
disposeUnusedWeightTensors(weightMap, paramMappings)
|
||||
disposeUnusedWeightTensors(weightMap, paramMappings);
|
||||
|
||||
return { params, paramMappings }
|
||||
return { params, paramMappings };
|
||||
}
|
|
@ -2,17 +2,20 @@ import * as tf from '../../dist/tfjs.esm.js';
|
|||
|
||||
import { FCParams } from '../common/index';
|
||||
|
||||
// eslint-disable-next-line no-shadow
|
||||
export enum Gender {
|
||||
// eslint-disable-next-line no-unused-vars
|
||||
FEMALE = 'female',
|
||||
// eslint-disable-next-line no-unused-vars
|
||||
MALE = 'male'
|
||||
}
|
||||
|
||||
export type AgeAndGenderPrediction = {
|
||||
age: number
|
||||
gender: Gender
|
||||
genderProbability: number
|
||||
}
|
||||
|
||||
export enum Gender {
|
||||
FEMALE = 'female',
|
||||
MALE = 'male'
|
||||
}
|
||||
|
||||
export type NetOutput = { age: tf.Tensor1D, gender: tf.Tensor2D }
|
||||
|
||||
export type NetParams = {
|
||||
|
|
|
@ -9,6 +9,8 @@ export interface IBoundingBox {
|
|||
|
||||
export class BoundingBox extends Box<BoundingBox> implements IBoundingBox {
|
||||
constructor(left: number, top: number, right: number, bottom: number, allowNegativeDimensions: boolean = false) {
|
||||
super({ left, top, right, bottom }, allowNegativeDimensions)
|
||||
super({
|
||||
left, top, right, bottom,
|
||||
}, allowNegativeDimensions);
|
||||
}
|
||||
}
|
|
@ -5,163 +5,197 @@ import { Point } from './Point';
|
|||
import { IRect } from './Rect';
|
||||
|
||||
export class Box<BoxType = any> implements IBoundingBox, IRect {
|
||||
|
||||
public static isRect(rect: any): boolean {
|
||||
return !!rect && [rect.x, rect.y, rect.width, rect.height].every(isValidNumber)
|
||||
return !!rect && [rect.x, rect.y, rect.width, rect.height].every(isValidNumber);
|
||||
}
|
||||
|
||||
public static assertIsValidBox(box: any, callee: string, allowNegativeDimensions: boolean = false) {
|
||||
if (!Box.isRect(box)) {
|
||||
throw new Error(`${callee} - invalid box: ${JSON.stringify(box)}, expected object with properties x, y, width, height`)
|
||||
throw new Error(`${callee} - invalid box: ${JSON.stringify(box)}, expected object with properties x, y, width, height`);
|
||||
}
|
||||
|
||||
if (!allowNegativeDimensions && (box.width < 0 || box.height < 0)) {
|
||||
throw new Error(`${callee} - width (${box.width}) and height (${box.height}) must be positive numbers`)
|
||||
throw new Error(`${callee} - width (${box.width}) and height (${box.height}) must be positive numbers`);
|
||||
}
|
||||
}
|
||||
|
||||
private _x: number
|
||||
|
||||
private _y: number
|
||||
|
||||
private _width: number
|
||||
|
||||
private _height: number
|
||||
|
||||
constructor(_box: IBoundingBox | IRect, allowNegativeDimensions: boolean = true) {
|
||||
const box = (_box || {}) as any
|
||||
const box = (_box || {}) as any;
|
||||
|
||||
const isBbox = [box.left, box.top, box.right, box.bottom].every(isValidNumber)
|
||||
const isRect = [box.x, box.y, box.width, box.height].every(isValidNumber)
|
||||
const isBbox = [box.left, box.top, box.right, box.bottom].every(isValidNumber);
|
||||
const isRect = [box.x, box.y, box.width, box.height].every(isValidNumber);
|
||||
|
||||
if (!isRect && !isBbox) {
|
||||
throw new Error(`Box.constructor - expected box to be IBoundingBox | IRect, instead have ${JSON.stringify(box)}`)
|
||||
throw new Error(`Box.constructor - expected box to be IBoundingBox | IRect, instead have ${JSON.stringify(box)}`);
|
||||
}
|
||||
|
||||
const [x, y, width, height] = isRect
|
||||
? [box.x, box.y, box.width, box.height]
|
||||
: [box.left, box.top, box.right - box.left, box.bottom - box.top]
|
||||
: [box.left, box.top, box.right - box.left, box.bottom - box.top];
|
||||
|
||||
Box.assertIsValidBox({ x, y, width, height }, 'Box.constructor', allowNegativeDimensions)
|
||||
Box.assertIsValidBox({
|
||||
x, y, width, height,
|
||||
}, 'Box.constructor', allowNegativeDimensions);
|
||||
|
||||
this._x = x
|
||||
this._y = y
|
||||
this._width = width
|
||||
this._height = height
|
||||
this._x = x;
|
||||
this._y = y;
|
||||
this._width = width;
|
||||
this._height = height;
|
||||
}
|
||||
|
||||
public get x(): number { return this._x }
|
||||
public get y(): number { return this._y }
|
||||
public get width(): number { return this._width }
|
||||
public get height(): number { return this._height }
|
||||
public get left(): number { return this.x }
|
||||
public get top(): number { return this.y }
|
||||
public get right(): number { return this.x + this.width }
|
||||
public get bottom(): number { return this.y + this.height }
|
||||
public get area(): number { return this.width * this.height }
|
||||
public get topLeft(): Point { return new Point(this.left, this.top) }
|
||||
public get topRight(): Point { return new Point(this.right, this.top) }
|
||||
public get bottomLeft(): Point { return new Point(this.left, this.bottom) }
|
||||
public get bottomRight(): Point { return new Point(this.right, this.bottom) }
|
||||
public get x(): number { return this._x; }
|
||||
|
||||
public get y(): number { return this._y; }
|
||||
|
||||
public get width(): number { return this._width; }
|
||||
|
||||
public get height(): number { return this._height; }
|
||||
|
||||
public get left(): number { return this.x; }
|
||||
|
||||
public get top(): number { return this.y; }
|
||||
|
||||
public get right(): number { return this.x + this.width; }
|
||||
|
||||
public get bottom(): number { return this.y + this.height; }
|
||||
|
||||
public get area(): number { return this.width * this.height; }
|
||||
|
||||
public get topLeft(): Point { return new Point(this.left, this.top); }
|
||||
|
||||
public get topRight(): Point { return new Point(this.right, this.top); }
|
||||
|
||||
public get bottomLeft(): Point { return new Point(this.left, this.bottom); }
|
||||
|
||||
public get bottomRight(): Point { return new Point(this.right, this.bottom); }
|
||||
|
||||
public round(): Box<BoxType> {
|
||||
const [x, y, width, height] = [this.x, this.y, this.width, this.height]
|
||||
.map(val => Math.round(val))
|
||||
return new Box({ x, y, width, height })
|
||||
.map((val) => Math.round(val));
|
||||
return new Box({
|
||||
x, y, width, height,
|
||||
});
|
||||
}
|
||||
|
||||
public floor(): Box<BoxType> {
|
||||
const [x, y, width, height] = [this.x, this.y, this.width, this.height]
|
||||
.map(val => Math.floor(val))
|
||||
return new Box({ x, y, width, height })
|
||||
.map((val) => Math.floor(val));
|
||||
return new Box({
|
||||
x, y, width, height,
|
||||
});
|
||||
}
|
||||
|
||||
public toSquare(): Box<BoxType> {
|
||||
let { x, y, width, height } = this
|
||||
const diff = Math.abs(width - height)
|
||||
let {
|
||||
x, y, width, height,
|
||||
} = this;
|
||||
const diff = Math.abs(width - height);
|
||||
if (width < height) {
|
||||
x -= (diff / 2)
|
||||
width += diff
|
||||
x -= (diff / 2);
|
||||
width += diff;
|
||||
}
|
||||
if (height < width) {
|
||||
y -= (diff / 2)
|
||||
height += diff
|
||||
y -= (diff / 2);
|
||||
height += diff;
|
||||
}
|
||||
|
||||
return new Box({ x, y, width, height })
|
||||
return new Box({
|
||||
x, y, width, height,
|
||||
});
|
||||
}
|
||||
|
||||
public rescale(s: IDimensions | number): Box<BoxType> {
|
||||
const scaleX = isDimensions(s) ? (s as IDimensions).width : s as number
|
||||
const scaleY = isDimensions(s) ? (s as IDimensions).height : s as number
|
||||
const scaleX = isDimensions(s) ? (s as IDimensions).width : s as number;
|
||||
const scaleY = isDimensions(s) ? (s as IDimensions).height : s as number;
|
||||
return new Box({
|
||||
x: this.x * scaleX,
|
||||
y: this.y * scaleY,
|
||||
width: this.width * scaleX,
|
||||
height: this.height * scaleY
|
||||
})
|
||||
height: this.height * scaleY,
|
||||
});
|
||||
}
|
||||
|
||||
public pad(padX: number, padY: number): Box<BoxType> {
|
||||
let [x, y, width, height] = [
|
||||
const [x, y, width, height] = [
|
||||
this.x - (padX / 2),
|
||||
this.y - (padY / 2),
|
||||
this.width + padX,
|
||||
this.height + padY
|
||||
]
|
||||
return new Box({ x, y, width, height })
|
||||
this.height + padY,
|
||||
];
|
||||
return new Box({
|
||||
x, y, width, height,
|
||||
});
|
||||
}
|
||||
|
||||
public clipAtImageBorders(imgWidth: number, imgHeight: number): Box<BoxType> {
|
||||
const { x, y, right, bottom } = this
|
||||
const clippedX = Math.max(x, 0)
|
||||
const clippedY = Math.max(y, 0)
|
||||
const {
|
||||
x, y, right, bottom,
|
||||
} = this;
|
||||
const clippedX = Math.max(x, 0);
|
||||
const clippedY = Math.max(y, 0);
|
||||
|
||||
const newWidth = right - clippedX
|
||||
const newHeight = bottom - clippedY
|
||||
const clippedWidth = Math.min(newWidth, imgWidth - clippedX)
|
||||
const clippedHeight = Math.min(newHeight, imgHeight - clippedY)
|
||||
const newWidth = right - clippedX;
|
||||
const newHeight = bottom - clippedY;
|
||||
const clippedWidth = Math.min(newWidth, imgWidth - clippedX);
|
||||
const clippedHeight = Math.min(newHeight, imgHeight - clippedY);
|
||||
|
||||
return (new Box({ x: clippedX, y: clippedY, width: clippedWidth, height: clippedHeight})).floor()
|
||||
return (new Box({
|
||||
x: clippedX, y: clippedY, width: clippedWidth, height: clippedHeight,
|
||||
})).floor();
|
||||
}
|
||||
|
||||
public shift(sx: number, sy: number): Box<BoxType> {
|
||||
const { width, height } = this
|
||||
const x = this.x + sx
|
||||
const y = this.y + sy
|
||||
const { width, height } = this;
|
||||
const x = this.x + sx;
|
||||
const y = this.y + sy;
|
||||
|
||||
return new Box({ x, y, width, height })
|
||||
return new Box({
|
||||
x, y, width, height,
|
||||
});
|
||||
}
|
||||
|
||||
public padAtBorders(imageHeight: number, imageWidth: number) {
|
||||
const w = this.width + 1
|
||||
const h = this.height + 1
|
||||
const w = this.width + 1;
|
||||
const h = this.height + 1;
|
||||
|
||||
let dx = 1
|
||||
let dy = 1
|
||||
let edx = w
|
||||
let edy = h
|
||||
const dx = 1;
|
||||
const dy = 1;
|
||||
let edx = w;
|
||||
let edy = h;
|
||||
|
||||
let x = this.left
|
||||
let y = this.top
|
||||
let ex = this.right
|
||||
let ey = this.bottom
|
||||
let x = this.left;
|
||||
let y = this.top;
|
||||
let ex = this.right;
|
||||
let ey = this.bottom;
|
||||
|
||||
if (ex > imageWidth) {
|
||||
edx = -ex + imageWidth + w
|
||||
ex = imageWidth
|
||||
edx = -ex + imageWidth + w;
|
||||
ex = imageWidth;
|
||||
}
|
||||
if (ey > imageHeight) {
|
||||
edy = -ey + imageHeight + h
|
||||
ey = imageHeight
|
||||
edy = -ey + imageHeight + h;
|
||||
ey = imageHeight;
|
||||
}
|
||||
if (x < 1) {
|
||||
edy = 2 - x
|
||||
x = 1
|
||||
edy = 2 - x;
|
||||
x = 1;
|
||||
}
|
||||
if (y < 1) {
|
||||
edy = 2 - y
|
||||
y = 1
|
||||
edy = 2 - y;
|
||||
y = 1;
|
||||
}
|
||||
|
||||
return { dy, edy, dx, edx, y, ey, x, ex, w, h }
|
||||
return {
|
||||
dy, edy, dx, edx, y, ey, x, ex, w, h,
|
||||
};
|
||||
}
|
||||
|
||||
public calibrate(region: Box) {
|
||||
|
@ -169,7 +203,7 @@ export class Box<BoxType = any> implements IBoundingBox, IRect {
|
|||
left: this.left + (region.left * this.width),
|
||||
top: this.top + (region.top * this.height),
|
||||
right: this.right + (region.right * this.width),
|
||||
bottom: this.bottom + (region.bottom * this.height)
|
||||
}).toSquare().round()
|
||||
bottom: this.bottom + (region.bottom * this.height),
|
||||
}).toSquare().round();
|
||||
}
|
||||
}
|
|
@ -6,23 +6,24 @@ export interface IDimensions {
|
|||
}
|
||||
|
||||
export class Dimensions implements IDimensions {
|
||||
|
||||
private _width: number
|
||||
|
||||
private _height: number
|
||||
|
||||
constructor(width: number, height: number) {
|
||||
if (!isValidNumber(width) || !isValidNumber(height)) {
|
||||
throw new Error(`Dimensions.constructor - expected width and height to be valid numbers, instead have ${JSON.stringify({ width, height })}`)
|
||||
throw new Error(`Dimensions.constructor - expected width and height to be valid numbers, instead have ${JSON.stringify({ width, height })}`);
|
||||
}
|
||||
|
||||
this._width = width
|
||||
this._height = height
|
||||
this._width = width;
|
||||
this._height = height;
|
||||
}
|
||||
|
||||
public get width(): number { return this._width }
|
||||
public get height(): number { return this._height }
|
||||
public get width(): number { return this._width; }
|
||||
|
||||
public get height(): number { return this._height; }
|
||||
|
||||
public reverse(): Dimensions {
|
||||
return new Dimensions(1 / this.width, 1 / this.height)
|
||||
return new Dimensions(1 / this.width, 1 / this.height);
|
||||
}
|
||||
}
|
|
@ -12,13 +12,13 @@ export class FaceDetection extends ObjectDetection implements IFaceDetecion {
|
|||
constructor(
|
||||
score: number,
|
||||
relativeBox: Rect,
|
||||
imageDims: IDimensions
|
||||
imageDims: IDimensions,
|
||||
) {
|
||||
super(score, score, '', relativeBox, imageDims)
|
||||
super(score, score, '', relativeBox, imageDims);
|
||||
}
|
||||
|
||||
public forSize(width: number, height: number): FaceDetection {
|
||||
const { score, relativeBox, imageDims } = super.forSize(width, height)
|
||||
return new FaceDetection(score, relativeBox, imageDims)
|
||||
const { score, relativeBox, imageDims } = super.forSize(width, height);
|
||||
return new FaceDetection(score, relativeBox, imageDims);
|
||||
}
|
||||
}
|
|
@ -8,9 +8,9 @@ import { Point } from './Point';
|
|||
import { IRect, Rect } from './Rect';
|
||||
|
||||
// face alignment constants
|
||||
const relX = 0.5
|
||||
const relY = 0.43
|
||||
const relScale = 0.45
|
||||
const relX = 0.5;
|
||||
const relY = 0.43;
|
||||
const relScale = 0.45;
|
||||
|
||||
export interface IFaceLandmarks {
|
||||
positions: Point[]
|
||||
|
@ -19,49 +19,55 @@ export interface IFaceLandmarks {
|
|||
|
||||
export class FaceLandmarks implements IFaceLandmarks {
|
||||
protected _shift: Point
|
||||
|
||||
protected _positions: Point[]
|
||||
|
||||
protected _imgDims: Dimensions
|
||||
|
||||
constructor(
|
||||
relativeFaceLandmarkPositions: Point[],
|
||||
imgDims: IDimensions,
|
||||
shift: Point = new Point(0, 0)
|
||||
shift: Point = new Point(0, 0),
|
||||
) {
|
||||
const { width, height } = imgDims
|
||||
this._imgDims = new Dimensions(width, height)
|
||||
this._shift = shift
|
||||
const { width, height } = imgDims;
|
||||
this._imgDims = new Dimensions(width, height);
|
||||
this._shift = shift;
|
||||
this._positions = relativeFaceLandmarkPositions.map(
|
||||
pt => pt.mul(new Point(width, height)).add(shift)
|
||||
)
|
||||
(pt) => pt.mul(new Point(width, height)).add(shift),
|
||||
);
|
||||
}
|
||||
|
||||
public get shift(): Point { return new Point(this._shift.x, this._shift.y) }
|
||||
public get imageWidth(): number { return this._imgDims.width }
|
||||
public get imageHeight(): number { return this._imgDims.height }
|
||||
public get positions(): Point[] { return this._positions }
|
||||
public get shift(): Point { return new Point(this._shift.x, this._shift.y); }
|
||||
|
||||
public get imageWidth(): number { return this._imgDims.width; }
|
||||
|
||||
public get imageHeight(): number { return this._imgDims.height; }
|
||||
|
||||
public get positions(): Point[] { return this._positions; }
|
||||
|
||||
public get relativePositions(): Point[] {
|
||||
return this._positions.map(
|
||||
pt => pt.sub(this._shift).div(new Point(this.imageWidth, this.imageHeight))
|
||||
)
|
||||
(pt) => pt.sub(this._shift).div(new Point(this.imageWidth, this.imageHeight)),
|
||||
);
|
||||
}
|
||||
|
||||
public forSize<T extends FaceLandmarks>(width: number, height: number): T {
|
||||
return new (this.constructor as any)(
|
||||
this.relativePositions,
|
||||
{ width, height }
|
||||
)
|
||||
{ width, height },
|
||||
);
|
||||
}
|
||||
|
||||
public shiftBy<T extends FaceLandmarks>(x: number, y: number): T {
|
||||
return new (this.constructor as any)(
|
||||
this.relativePositions,
|
||||
this._imgDims,
|
||||
new Point(x, y)
|
||||
)
|
||||
new Point(x, y),
|
||||
);
|
||||
}
|
||||
|
||||
public shiftByPoint<T extends FaceLandmarks>(pt: Point): T {
|
||||
return this.shiftBy(pt.x, pt.y)
|
||||
return this.shiftBy(pt.x, pt.y);
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -77,49 +83,48 @@ export class FaceLandmarks implements IFaceLandmarks {
|
|||
*/
|
||||
public align(
|
||||
detection?: FaceDetection | IRect | IBoundingBox | null,
|
||||
options: { useDlibAlignment?: boolean, minBoxPadding?: number } = { }
|
||||
options: { useDlibAlignment?: boolean, minBoxPadding?: number } = { },
|
||||
): Box {
|
||||
if (detection) {
|
||||
const box = detection instanceof FaceDetection
|
||||
? detection.box.floor()
|
||||
: new Box(detection)
|
||||
: new Box(detection);
|
||||
|
||||
return this.shiftBy(box.x, box.y).align(null, options)
|
||||
return this.shiftBy(box.x, box.y).align(null, options);
|
||||
}
|
||||
|
||||
const { useDlibAlignment, minBoxPadding } = Object.assign({}, { useDlibAlignment: false, minBoxPadding: 0.2 }, options)
|
||||
const { useDlibAlignment, minBoxPadding } = { useDlibAlignment: false, minBoxPadding: 0.2, ...options };
|
||||
|
||||
if (useDlibAlignment) {
|
||||
return this.alignDlib()
|
||||
return this.alignDlib();
|
||||
}
|
||||
|
||||
return this.alignMinBbox(minBoxPadding)
|
||||
return this.alignMinBbox(minBoxPadding);
|
||||
}
|
||||
|
||||
private alignDlib(): Box {
|
||||
const centers = this.getRefPointsForAlignment();
|
||||
|
||||
const centers = this.getRefPointsForAlignment()
|
||||
const [leftEyeCenter, rightEyeCenter, mouthCenter] = centers;
|
||||
const distToMouth = (pt: Point) => mouthCenter.sub(pt).magnitude();
|
||||
const eyeToMouthDist = (distToMouth(leftEyeCenter) + distToMouth(rightEyeCenter)) / 2;
|
||||
|
||||
const [leftEyeCenter, rightEyeCenter, mouthCenter] = centers
|
||||
const distToMouth = (pt: Point) => mouthCenter.sub(pt).magnitude()
|
||||
const eyeToMouthDist = (distToMouth(leftEyeCenter) + distToMouth(rightEyeCenter)) / 2
|
||||
const size = Math.floor(eyeToMouthDist / relScale);
|
||||
|
||||
const size = Math.floor(eyeToMouthDist / relScale)
|
||||
|
||||
const refPoint = getCenterPoint(centers)
|
||||
const refPoint = getCenterPoint(centers);
|
||||
// TODO: pad in case rectangle is out of image bounds
|
||||
const x = Math.floor(Math.max(0, refPoint.x - (relX * size)))
|
||||
const y = Math.floor(Math.max(0, refPoint.y - (relY * size)))
|
||||
const x = Math.floor(Math.max(0, refPoint.x - (relX * size)));
|
||||
const y = Math.floor(Math.max(0, refPoint.y - (relY * size)));
|
||||
|
||||
return new Rect(x, y, Math.min(size, this.imageWidth + x), Math.min(size, this.imageHeight + y))
|
||||
return new Rect(x, y, Math.min(size, this.imageWidth + x), Math.min(size, this.imageHeight + y));
|
||||
}
|
||||
|
||||
private alignMinBbox(padding: number): Box {
|
||||
const box = minBbox(this.positions)
|
||||
return box.pad(box.width * padding, box.height * padding)
|
||||
const box = minBbox(this.positions);
|
||||
return box.pad(box.width * padding, box.height * padding);
|
||||
}
|
||||
|
||||
protected getRefPointsForAlignment(): Point[] {
|
||||
throw new Error('getRefPointsForAlignment not implemented by base class')
|
||||
throw new Error('getRefPointsForAlignment not implemented by base class');
|
||||
}
|
||||
}
|
|
@ -2,15 +2,13 @@ import { getCenterPoint } from '../utils/index';
|
|||
import { FaceLandmarks } from './FaceLandmarks';
|
||||
import { Point } from './Point';
|
||||
|
||||
|
||||
export class FaceLandmarks5 extends FaceLandmarks {
|
||||
|
||||
protected getRefPointsForAlignment(): Point[] {
|
||||
const pts = this.positions
|
||||
const pts = this.positions;
|
||||
return [
|
||||
pts[0],
|
||||
pts[1],
|
||||
getCenterPoint([pts[3], pts[4]])
|
||||
]
|
||||
getCenterPoint([pts[3], pts[4]]),
|
||||
];
|
||||
}
|
||||
}
|
|
@ -4,38 +4,38 @@ import { Point } from './Point';
|
|||
|
||||
export class FaceLandmarks68 extends FaceLandmarks {
|
||||
public getJawOutline(): Point[] {
|
||||
return this.positions.slice(0, 17)
|
||||
return this.positions.slice(0, 17);
|
||||
}
|
||||
|
||||
public getLeftEyeBrow(): Point[] {
|
||||
return this.positions.slice(17, 22)
|
||||
return this.positions.slice(17, 22);
|
||||
}
|
||||
|
||||
public getRightEyeBrow(): Point[] {
|
||||
return this.positions.slice(22, 27)
|
||||
return this.positions.slice(22, 27);
|
||||
}
|
||||
|
||||
public getNose(): Point[] {
|
||||
return this.positions.slice(27, 36)
|
||||
return this.positions.slice(27, 36);
|
||||
}
|
||||
|
||||
public getLeftEye(): Point[] {
|
||||
return this.positions.slice(36, 42)
|
||||
return this.positions.slice(36, 42);
|
||||
}
|
||||
|
||||
public getRightEye(): Point[] {
|
||||
return this.positions.slice(42, 48)
|
||||
return this.positions.slice(42, 48);
|
||||
}
|
||||
|
||||
public getMouth(): Point[] {
|
||||
return this.positions.slice(48, 68)
|
||||
return this.positions.slice(48, 68);
|
||||
}
|
||||
|
||||
protected getRefPointsForAlignment(): Point[] {
|
||||
return [
|
||||
this.getLeftEye(),
|
||||
this.getRightEye(),
|
||||
this.getMouth()
|
||||
].map(getCenterPoint)
|
||||
this.getMouth(),
|
||||
].map(getCenterPoint);
|
||||
}
|
||||
}
|
|
@ -7,17 +7,19 @@ export interface IFaceMatch {
|
|||
|
||||
export class FaceMatch implements IFaceMatch {
|
||||
private _label: string
|
||||
|
||||
private _distance: number
|
||||
|
||||
constructor(label: string, distance: number) {
|
||||
this._label = label
|
||||
this._distance = distance
|
||||
this._label = label;
|
||||
this._distance = distance;
|
||||
}
|
||||
|
||||
public get label(): string { return this._label }
|
||||
public get distance(): number { return this._distance }
|
||||
public get label(): string { return this._label; }
|
||||
|
||||
public get distance(): number { return this._distance; }
|
||||
|
||||
public toString(withDistance: boolean = true): string {
|
||||
return `${this.label}${withDistance ? ` (${round(this.distance)})` : ''}`
|
||||
return `${this.label}${withDistance ? ` (${round(this.distance)})` : ''}`;
|
||||
}
|
||||
}
|
|
@ -4,22 +4,20 @@ import { Box } from './Box';
|
|||
import { IRect } from './Rect';
|
||||
|
||||
export class LabeledBox extends Box<LabeledBox> {
|
||||
|
||||
public static assertIsValidLabeledBox(box: any, callee: string) {
|
||||
Box.assertIsValidBox(box, callee)
|
||||
Box.assertIsValidBox(box, callee);
|
||||
|
||||
if (!isValidNumber(box.label)) {
|
||||
throw new Error(`${callee} - expected property label (${box.label}) to be a number`)
|
||||
throw new Error(`${callee} - expected property label (${box.label}) to be a number`);
|
||||
}
|
||||
}
|
||||
|
||||
private _label: number
|
||||
|
||||
constructor(box: IBoundingBox | IRect | any, label: number) {
|
||||
super(box)
|
||||
this._label = label
|
||||
super(box);
|
||||
this._label = label;
|
||||
}
|
||||
|
||||
public get label(): number { return this._label }
|
||||
|
||||
public get label(): number { return this._label; }
|
||||
}
|
|
@ -1,35 +1,34 @@
|
|||
export class LabeledFaceDescriptors {
|
||||
private _label: string
|
||||
|
||||
private _descriptors: Float32Array[]
|
||||
|
||||
constructor(label: string, descriptors: Float32Array[]) {
|
||||
if (!(typeof label === 'string')) {
|
||||
throw new Error('LabeledFaceDescriptors - constructor expected label to be a string')
|
||||
throw new Error('LabeledFaceDescriptors - constructor expected label to be a string');
|
||||
}
|
||||
|
||||
if (!Array.isArray(descriptors) || descriptors.some(desc => !(desc instanceof Float32Array))) {
|
||||
throw new Error('LabeledFaceDescriptors - constructor expected descriptors to be an array of Float32Array')
|
||||
if (!Array.isArray(descriptors) || descriptors.some((desc) => !(desc instanceof Float32Array))) {
|
||||
throw new Error('LabeledFaceDescriptors - constructor expected descriptors to be an array of Float32Array');
|
||||
}
|
||||
|
||||
this._label = label
|
||||
this._descriptors = descriptors
|
||||
this._label = label;
|
||||
this._descriptors = descriptors;
|
||||
}
|
||||
|
||||
public get label(): string { return this._label }
|
||||
public get descriptors(): Float32Array[] { return this._descriptors }
|
||||
public get label(): string { return this._label; }
|
||||
|
||||
public get descriptors(): Float32Array[] { return this._descriptors; }
|
||||
|
||||
public toJSON(): any {
|
||||
return {
|
||||
label: this.label,
|
||||
descriptors: this.descriptors.map((d) => Array.from(d))
|
||||
descriptors: this.descriptors.map((d) => Array.from(d)),
|
||||
};
|
||||
}
|
||||
|
||||
public static fromJSON(json: any): LabeledFaceDescriptors {
|
||||
const descriptors = json.descriptors.map((d: any) => {
|
||||
return new Float32Array(d);
|
||||
});
|
||||
const descriptors = json.descriptors.map((d: any) => new Float32Array(d));
|
||||
return new LabeledFaceDescriptors(json.label, descriptors);
|
||||
}
|
||||
|
||||
}
|
|
@ -4,9 +4,13 @@ import { IRect, Rect } from './Rect';
|
|||
|
||||
export class ObjectDetection {
|
||||
private _score: number
|
||||
|
||||
private _classScore: number
|
||||
|
||||
private _className: string
|
||||
|
||||
private _box: Rect
|
||||
|
||||
private _imageDims: Dimensions
|
||||
|
||||
constructor(
|
||||
|
@ -14,23 +18,30 @@ export class ObjectDetection {
|
|||
classScore: number,
|
||||
className: string,
|
||||
relativeBox: IRect,
|
||||
imageDims: IDimensions
|
||||
imageDims: IDimensions,
|
||||
) {
|
||||
this._imageDims = new Dimensions(imageDims.width, imageDims.height)
|
||||
this._score = score
|
||||
this._classScore = classScore
|
||||
this._className = className
|
||||
this._box = new Box(relativeBox).rescale(this._imageDims)
|
||||
this._imageDims = new Dimensions(imageDims.width, imageDims.height);
|
||||
this._score = score;
|
||||
this._classScore = classScore;
|
||||
this._className = className;
|
||||
this._box = new Box(relativeBox).rescale(this._imageDims);
|
||||
}
|
||||
|
||||
public get score(): number { return this._score }
|
||||
public get classScore(): number { return this._classScore }
|
||||
public get className(): string { return this._className }
|
||||
public get box(): Box { return this._box }
|
||||
public get imageDims(): Dimensions { return this._imageDims }
|
||||
public get imageWidth(): number { return this.imageDims.width }
|
||||
public get imageHeight(): number { return this.imageDims.height }
|
||||
public get relativeBox(): Box { return new Box(this._box).rescale(this.imageDims.reverse()) }
|
||||
public get score(): number { return this._score; }
|
||||
|
||||
public get classScore(): number { return this._classScore; }
|
||||
|
||||
public get className(): string { return this._className; }
|
||||
|
||||
public get box(): Box { return this._box; }
|
||||
|
||||
public get imageDims(): Dimensions { return this._imageDims; }
|
||||
|
||||
public get imageWidth(): number { return this.imageDims.width; }
|
||||
|
||||
public get imageHeight(): number { return this.imageDims.height; }
|
||||
|
||||
public get relativeBox(): Box { return new Box(this._box).rescale(this.imageDims.reverse()); }
|
||||
|
||||
public forSize(width: number, height: number): ObjectDetection {
|
||||
return new ObjectDetection(
|
||||
|
@ -38,7 +49,7 @@ export class ObjectDetection {
|
|||
this.classScore,
|
||||
this.className,
|
||||
this.relativeBox,
|
||||
{ width, height}
|
||||
)
|
||||
{ width, height },
|
||||
);
|
||||
}
|
||||
}
|
|
@ -5,41 +5,43 @@ export interface IPoint {
|
|||
|
||||
export class Point implements IPoint {
|
||||
private _x: number
|
||||
|
||||
private _y: number
|
||||
|
||||
constructor(x: number, y: number) {
|
||||
this._x = x
|
||||
this._y = y
|
||||
this._x = x;
|
||||
this._y = y;
|
||||
}
|
||||
|
||||
get x(): number { return this._x }
|
||||
get y(): number { return this._y }
|
||||
get x(): number { return this._x; }
|
||||
|
||||
get y(): number { return this._y; }
|
||||
|
||||
public add(pt: IPoint): Point {
|
||||
return new Point(this.x + pt.x, this.y + pt.y)
|
||||
return new Point(this.x + pt.x, this.y + pt.y);
|
||||
}
|
||||
|
||||
public sub(pt: IPoint): Point {
|
||||
return new Point(this.x - pt.x, this.y - pt.y)
|
||||
return new Point(this.x - pt.x, this.y - pt.y);
|
||||
}
|
||||
|
||||
public mul(pt: IPoint): Point {
|
||||
return new Point(this.x * pt.x, this.y * pt.y)
|
||||
return new Point(this.x * pt.x, this.y * pt.y);
|
||||
}
|
||||
|
||||
public div(pt: IPoint): Point {
|
||||
return new Point(this.x / pt.x, this.y / pt.y)
|
||||
return new Point(this.x / pt.x, this.y / pt.y);
|
||||
}
|
||||
|
||||
public abs(): Point {
|
||||
return new Point(Math.abs(this.x), Math.abs(this.y))
|
||||
return new Point(Math.abs(this.x), Math.abs(this.y));
|
||||
}
|
||||
|
||||
public magnitude(): number {
|
||||
return Math.sqrt(Math.pow(this.x, 2) + Math.pow(this.y, 2))
|
||||
return Math.sqrt((this.x ** 2) + (this.y ** 2));
|
||||
}
|
||||
|
||||
public floor(): Point {
|
||||
return new Point(Math.floor(this.x), Math.floor(this.y))
|
||||
return new Point(Math.floor(this.x), Math.floor(this.y));
|
||||
}
|
||||
}
|
|
@ -4,28 +4,28 @@ import { LabeledBox } from './LabeledBox';
|
|||
import { IRect } from './Rect';
|
||||
|
||||
export class PredictedBox extends LabeledBox {
|
||||
|
||||
public static assertIsValidPredictedBox(box: any, callee: string) {
|
||||
LabeledBox.assertIsValidLabeledBox(box, callee)
|
||||
LabeledBox.assertIsValidLabeledBox(box, callee);
|
||||
|
||||
if (
|
||||
!isValidProbablitiy(box.score)
|
||||
|| !isValidProbablitiy(box.classScore)
|
||||
) {
|
||||
throw new Error(`${callee} - expected properties score (${box.score}) and (${box.classScore}) to be a number between [0, 1]`)
|
||||
throw new Error(`${callee} - expected properties score (${box.score}) and (${box.classScore}) to be a number between [0, 1]`);
|
||||
}
|
||||
}
|
||||
|
||||
private _score: number
|
||||
|
||||
private _classScore: number
|
||||
|
||||
constructor(box: IBoundingBox | IRect | any, label: number, score: number, classScore: number) {
|
||||
super(box, label)
|
||||
this._score = score
|
||||
this._classScore = classScore
|
||||
super(box, label);
|
||||
this._score = score;
|
||||
this._classScore = classScore;
|
||||
}
|
||||
|
||||
public get score(): number { return this._score }
|
||||
public get classScore(): number { return this._classScore }
|
||||
public get score(): number { return this._score; }
|
||||
|
||||
public get classScore(): number { return this._classScore; }
|
||||
}
|
|
@ -9,6 +9,8 @@ export interface IRect {
|
|||
|
||||
export class Rect extends Box<Rect> implements IRect {
|
||||
constructor(x: number, y: number, width: number, height: number, allowNegativeDimensions: boolean = false) {
|
||||
super({ x, y, width, height }, allowNegativeDimensions)
|
||||
super({
|
||||
x, y, width, height,
|
||||
}, allowNegativeDimensions);
|
||||
}
|
||||
}
|
|
@ -1,14 +1,14 @@
|
|||
export * from './BoundingBox'
|
||||
export * from './Box'
|
||||
export * from './Dimensions'
|
||||
export * from './BoundingBox';
|
||||
export * from './Box';
|
||||
export * from './Dimensions';
|
||||
export * from './FaceDetection';
|
||||
export * from './FaceLandmarks';
|
||||
export * from './FaceLandmarks5';
|
||||
export * from './FaceLandmarks68';
|
||||
export * from './FaceMatch';
|
||||
export * from './LabeledBox'
|
||||
export * from './LabeledBox';
|
||||
export * from './LabeledFaceDescriptors';
|
||||
export * from './ObjectDetection'
|
||||
export * from './Point'
|
||||
export * from './PredictedBox'
|
||||
export * from './Rect'
|
||||
export * from './ObjectDetection';
|
||||
export * from './Point';
|
||||
export * from './PredictedBox';
|
||||
export * from './Rect';
|
||||
|
|
|
@ -6,14 +6,14 @@ export function convLayer(
|
|||
x: tf.Tensor4D,
|
||||
params: ConvParams,
|
||||
padding: 'valid' | 'same' = 'same',
|
||||
withRelu: boolean = false
|
||||
withRelu: boolean = false,
|
||||
): tf.Tensor4D {
|
||||
return tf.tidy(() => {
|
||||
const out = tf.add(
|
||||
tf.conv2d(x, params.filters, [1, 1], padding),
|
||||
params.bias
|
||||
) as tf.Tensor4D
|
||||
params.bias,
|
||||
) as tf.Tensor4D;
|
||||
|
||||
return withRelu ? tf.relu(out) : out
|
||||
})
|
||||
return withRelu ? tf.relu(out) : out;
|
||||
});
|
||||
}
|
|
@ -5,11 +5,11 @@ import { SeparableConvParams } from './types';
|
|||
export function depthwiseSeparableConv(
|
||||
x: tf.Tensor4D,
|
||||
params: SeparableConvParams,
|
||||
stride: [number, number]
|
||||
stride: [number, number],
|
||||
): tf.Tensor4D {
|
||||
return tf.tidy(() => {
|
||||
let out = tf.separableConv2d(x, params.depthwise_filter, params.pointwise_filter, stride, 'same')
|
||||
out = tf.add(out, params.bias)
|
||||
return out
|
||||
})
|
||||
let out = tf.separableConv2d(x, params.depthwise_filter, params.pointwise_filter, stride, 'same');
|
||||
out = tf.add(out, params.bias);
|
||||
return out;
|
||||
});
|
||||
}
|
|
@ -1,9 +1,9 @@
|
|||
import { ParamMapping } from './types';
|
||||
|
||||
export function disposeUnusedWeightTensors(weightMap: any, paramMappings: ParamMapping[]) {
|
||||
Object.keys(weightMap).forEach(path => {
|
||||
if (!paramMappings.some(pm => pm.originalPath === path)) {
|
||||
weightMap[path].dispose()
|
||||
Object.keys(weightMap).forEach((path) => {
|
||||
if (!paramMappings.some((pm) => pm.originalPath === path)) {
|
||||
weightMap[path].dispose();
|
||||
}
|
||||
})
|
||||
});
|
||||
}
|
||||
|
|
|
@ -4,28 +4,25 @@ import { ConvParams, ExtractWeightsFunction, ParamMapping } from './types';
|
|||
|
||||
export function extractConvParamsFactory(
|
||||
extractWeights: ExtractWeightsFunction,
|
||||
paramMappings: ParamMapping[]
|
||||
paramMappings: ParamMapping[],
|
||||
) {
|
||||
|
||||
return function(
|
||||
return (
|
||||
channelsIn: number,
|
||||
channelsOut: number,
|
||||
filterSize: number,
|
||||
mappedPrefix: string
|
||||
): ConvParams {
|
||||
|
||||
mappedPrefix: string,
|
||||
): ConvParams => {
|
||||
const filters = tf.tensor4d(
|
||||
extractWeights(channelsIn * channelsOut * filterSize * filterSize),
|
||||
[filterSize, filterSize, channelsIn, channelsOut]
|
||||
)
|
||||
const bias = tf.tensor1d(extractWeights(channelsOut))
|
||||
[filterSize, filterSize, channelsIn, channelsOut],
|
||||
);
|
||||
const bias = tf.tensor1d(extractWeights(channelsOut));
|
||||
|
||||
paramMappings.push(
|
||||
{ paramPath: `${mappedPrefix}/filters` },
|
||||
{ paramPath: `${mappedPrefix}/bias` }
|
||||
)
|
||||
|
||||
return { filters, bias }
|
||||
}
|
||||
{ paramPath: `${mappedPrefix}/bias` },
|
||||
);
|
||||
|
||||
return { filters, bias };
|
||||
};
|
||||
}
|
||||
|
|
|
@ -2,30 +2,26 @@ import * as tf from '../../dist/tfjs.esm.js';
|
|||
|
||||
import { ExtractWeightsFunction, FCParams, ParamMapping } from './types';
|
||||
|
||||
|
||||
export function extractFCParamsFactory(
|
||||
extractWeights: ExtractWeightsFunction,
|
||||
paramMappings: ParamMapping[]
|
||||
paramMappings: ParamMapping[],
|
||||
) {
|
||||
|
||||
return function(
|
||||
return (
|
||||
channelsIn: number,
|
||||
channelsOut: number,
|
||||
mappedPrefix: string
|
||||
): FCParams {
|
||||
|
||||
const fc_weights = tf.tensor2d(extractWeights(channelsIn * channelsOut), [channelsIn, channelsOut])
|
||||
const fc_bias = tf.tensor1d(extractWeights(channelsOut))
|
||||
mappedPrefix: string,
|
||||
): FCParams => {
|
||||
const fc_weights = tf.tensor2d(extractWeights(channelsIn * channelsOut), [channelsIn, channelsOut]);
|
||||
const fc_bias = tf.tensor1d(extractWeights(channelsOut));
|
||||
|
||||
paramMappings.push(
|
||||
{ paramPath: `${mappedPrefix}/weights` },
|
||||
{ paramPath: `${mappedPrefix}/bias` }
|
||||
)
|
||||
{ paramPath: `${mappedPrefix}/bias` },
|
||||
);
|
||||
|
||||
return {
|
||||
weights: fc_weights,
|
||||
bias: fc_bias
|
||||
}
|
||||
}
|
||||
|
||||
bias: fc_bias,
|
||||
};
|
||||
};
|
||||
}
|
||||
|
|
|
@ -4,43 +4,40 @@ import { ExtractWeightsFunction, ParamMapping, SeparableConvParams } from './typ
|
|||
|
||||
export function extractSeparableConvParamsFactory(
|
||||
extractWeights: ExtractWeightsFunction,
|
||||
paramMappings: ParamMapping[]
|
||||
paramMappings: ParamMapping[],
|
||||
) {
|
||||
|
||||
return function(channelsIn: number, channelsOut: number, mappedPrefix: string): SeparableConvParams {
|
||||
const depthwise_filter = tf.tensor4d(extractWeights(3 * 3 * channelsIn), [3, 3, channelsIn, 1])
|
||||
const pointwise_filter = tf.tensor4d(extractWeights(channelsIn * channelsOut), [1, 1, channelsIn, channelsOut])
|
||||
const bias = tf.tensor1d(extractWeights(channelsOut))
|
||||
return (channelsIn: number, channelsOut: number, mappedPrefix: string): SeparableConvParams => {
|
||||
const depthwise_filter = tf.tensor4d(extractWeights(3 * 3 * channelsIn), [3, 3, channelsIn, 1]);
|
||||
const pointwise_filter = tf.tensor4d(extractWeights(channelsIn * channelsOut), [1, 1, channelsIn, channelsOut]);
|
||||
const bias = tf.tensor1d(extractWeights(channelsOut));
|
||||
|
||||
paramMappings.push(
|
||||
{ paramPath: `${mappedPrefix}/depthwise_filter` },
|
||||
{ paramPath: `${mappedPrefix}/pointwise_filter` },
|
||||
{ paramPath: `${mappedPrefix}/bias` }
|
||||
)
|
||||
{ paramPath: `${mappedPrefix}/bias` },
|
||||
);
|
||||
|
||||
return new SeparableConvParams(
|
||||
depthwise_filter,
|
||||
pointwise_filter,
|
||||
bias
|
||||
)
|
||||
}
|
||||
|
||||
bias,
|
||||
);
|
||||
};
|
||||
}
|
||||
|
||||
export function loadSeparableConvParamsFactory(
|
||||
extractWeightEntry: <T>(originalPath: string, paramRank: number) => T
|
||||
// eslint-disable-next-line no-unused-vars
|
||||
extractWeightEntry: <T>(originalPath: string, paramRank: number) => T,
|
||||
) {
|
||||
|
||||
return function (prefix: string): SeparableConvParams {
|
||||
const depthwise_filter = extractWeightEntry<tf.Tensor4D>(`${prefix}/depthwise_filter`, 4)
|
||||
const pointwise_filter = extractWeightEntry<tf.Tensor4D>(`${prefix}/pointwise_filter`, 4)
|
||||
const bias = extractWeightEntry<tf.Tensor1D>(`${prefix}/bias`, 1)
|
||||
return (prefix: string): SeparableConvParams => {
|
||||
const depthwise_filter = extractWeightEntry<tf.Tensor4D>(`${prefix}/depthwise_filter`, 4);
|
||||
const pointwise_filter = extractWeightEntry<tf.Tensor4D>(`${prefix}/pointwise_filter`, 4);
|
||||
const bias = extractWeightEntry<tf.Tensor1D>(`${prefix}/bias`, 1);
|
||||
|
||||
return new SeparableConvParams(
|
||||
depthwise_filter,
|
||||
pointwise_filter,
|
||||
bias
|
||||
)
|
||||
}
|
||||
|
||||
bias,
|
||||
);
|
||||
};
|
||||
}
|
||||
|
|
|
@ -2,19 +2,17 @@ import { isTensor } from '../utils/index';
|
|||
import { ParamMapping } from './types';
|
||||
|
||||
export function extractWeightEntryFactory(weightMap: any, paramMappings: ParamMapping[]) {
|
||||
|
||||
return function<T> (originalPath: string, paramRank: number, mappedPath?: string): T {
|
||||
const tensor = weightMap[originalPath]
|
||||
return (originalPath: string, paramRank: number, mappedPath?: string) => {
|
||||
const tensor = weightMap[originalPath];
|
||||
|
||||
if (!isTensor(tensor, paramRank)) {
|
||||
throw new Error(`expected weightMap[${originalPath}] to be a Tensor${paramRank}D, instead have ${tensor}`)
|
||||
throw new Error(`expected weightMap[${originalPath}] to be a Tensor${paramRank}D, instead have ${tensor}`);
|
||||
}
|
||||
|
||||
paramMappings.push(
|
||||
{ originalPath, paramPath: mappedPath || originalPath }
|
||||
)
|
||||
|
||||
return tensor
|
||||
}
|
||||
{ originalPath, paramPath: mappedPath || originalPath },
|
||||
);
|
||||
|
||||
return tensor;
|
||||
};
|
||||
}
|
||||
|
|
|
@ -1,18 +1,18 @@
|
|||
export function extractWeightsFactory(weights: Float32Array) {
|
||||
let remainingWeights = weights
|
||||
let remainingWeights = weights;
|
||||
|
||||
function extractWeights(numWeights: number): Float32Array {
|
||||
const ret = remainingWeights.slice(0, numWeights)
|
||||
remainingWeights = remainingWeights.slice(numWeights)
|
||||
return ret
|
||||
const ret = remainingWeights.slice(0, numWeights);
|
||||
remainingWeights = remainingWeights.slice(numWeights);
|
||||
return ret;
|
||||
}
|
||||
|
||||
function getRemainingWeights(): Float32Array {
|
||||
return remainingWeights
|
||||
return remainingWeights;
|
||||
}
|
||||
|
||||
return {
|
||||
extractWeights,
|
||||
getRemainingWeights
|
||||
}
|
||||
getRemainingWeights,
|
||||
};
|
||||
}
|
|
@ -4,12 +4,10 @@ import { FCParams } from './types';
|
|||
|
||||
export function fullyConnectedLayer(
|
||||
x: tf.Tensor2D,
|
||||
params: FCParams
|
||||
params: FCParams,
|
||||
): tf.Tensor2D {
|
||||
return tf.tidy(() =>
|
||||
tf.add(
|
||||
return tf.tidy(() => tf.add(
|
||||
tf.matMul(x, params.weights),
|
||||
params.bias
|
||||
)
|
||||
)
|
||||
params.bias,
|
||||
));
|
||||
}
|
|
@ -1,33 +1,34 @@
|
|||
export function getModelUris(uri: string | undefined, defaultModelName: string) {
|
||||
const defaultManifestFilename = `${defaultModelName}-weights_manifest.json`
|
||||
const defaultManifestFilename = `${defaultModelName}-weights_manifest.json`;
|
||||
|
||||
if (!uri) {
|
||||
return {
|
||||
modelBaseUri: '',
|
||||
manifestUri: defaultManifestFilename
|
||||
}
|
||||
manifestUri: defaultManifestFilename,
|
||||
};
|
||||
}
|
||||
|
||||
if (uri === '/') {
|
||||
return {
|
||||
modelBaseUri: '/',
|
||||
manifestUri: `/${defaultManifestFilename}`
|
||||
}
|
||||
manifestUri: `/${defaultManifestFilename}`,
|
||||
};
|
||||
}
|
||||
// eslint-disable-next-line no-nested-ternary
|
||||
const protocol = uri.startsWith('http://') ? 'http://' : uri.startsWith('https://') ? 'https://' : '';
|
||||
uri = uri.replace(protocol, '');
|
||||
|
||||
const parts = uri.split('/').filter(s => s)
|
||||
const parts = uri.split('/').filter((s) => s);
|
||||
|
||||
const manifestFile = uri.endsWith('.json')
|
||||
? parts[parts.length - 1]
|
||||
: defaultManifestFilename
|
||||
: defaultManifestFilename;
|
||||
|
||||
let modelBaseUri = protocol + (uri.endsWith('.json') ? parts.slice(0, parts.length - 1) : parts).join('/')
|
||||
modelBaseUri = uri.startsWith('/') ? `/${modelBaseUri}` : modelBaseUri
|
||||
let modelBaseUri = protocol + (uri.endsWith('.json') ? parts.slice(0, parts.length - 1) : parts).join('/');
|
||||
modelBaseUri = uri.startsWith('/') ? `/${modelBaseUri}` : modelBaseUri;
|
||||
|
||||
return {
|
||||
modelBaseUri,
|
||||
manifestUri: modelBaseUri === '/' ? `/${manifestFile}` : `${modelBaseUri}/${manifestFile}`
|
||||
}
|
||||
manifestUri: modelBaseUri === '/' ? `/${manifestFile}` : `${modelBaseUri}/${manifestFile}`,
|
||||
};
|
||||
}
|
|
@ -1,10 +1,10 @@
|
|||
export * from './convLayer'
|
||||
export * from './depthwiseSeparableConv'
|
||||
export * from './disposeUnusedWeightTensors'
|
||||
export * from './extractConvParamsFactory'
|
||||
export * from './extractFCParamsFactory'
|
||||
export * from './extractSeparableConvParamsFactory'
|
||||
export * from './extractWeightEntryFactory'
|
||||
export * from './extractWeightsFactory'
|
||||
export * from './getModelUris'
|
||||
export * from './types'
|
||||
export * from './convLayer';
|
||||
export * from './depthwiseSeparableConv';
|
||||
export * from './disposeUnusedWeightTensors';
|
||||
export * from './extractConvParamsFactory';
|
||||
export * from './extractFCParamsFactory';
|
||||
export * from './extractSeparableConvParamsFactory';
|
||||
export * from './extractWeightEntryFactory';
|
||||
export * from './extractWeightsFactory';
|
||||
export * from './getModelUris';
|
||||
export * from './types';
|
||||
|
|
|
@ -2,11 +2,12 @@ import * as tf from '../../dist/tfjs.esm.js';
|
|||
|
||||
import { ConvParams } from './types';
|
||||
|
||||
// eslint-disable-next-line no-unused-vars
|
||||
export function loadConvParamsFactory(extractWeightEntry: <T>(originalPath: string, paramRank: number) => T) {
|
||||
return function(prefix: string): ConvParams {
|
||||
const filters = extractWeightEntry<tf.Tensor4D>(`${prefix}/filters`, 4)
|
||||
const bias = extractWeightEntry<tf.Tensor1D>(`${prefix}/bias`, 1)
|
||||
return (prefix: string): ConvParams => {
|
||||
const filters = extractWeightEntry<tf.Tensor4D>(`${prefix}/filters`, 4);
|
||||
const bias = extractWeightEntry<tf.Tensor1D>(`${prefix}/bias`, 1);
|
||||
|
||||
return { filters, bias }
|
||||
}
|
||||
return { filters, bias };
|
||||
};
|
||||
}
|
|
@ -1,5 +1,6 @@
|
|||
import * as tf from '../../dist/tfjs.esm.js';
|
||||
|
||||
// eslint-disable-next-line no-unused-vars
|
||||
export type ExtractWeightsFunction = (numWeights: number) => Float32Array
|
||||
|
||||
export type ParamMapping = {
|
||||
|
@ -18,9 +19,14 @@ export type FCParams = {
|
|||
}
|
||||
|
||||
export class SeparableConvParams {
|
||||
// eslint-disable-next-line no-useless-constructor
|
||||
constructor(
|
||||
// eslint-disable-next-line no-unused-vars
|
||||
public depthwise_filter: tf.Tensor4D,
|
||||
// eslint-disable-next-line no-unused-vars
|
||||
public pointwise_filter: tf.Tensor4D,
|
||||
public bias: tf.Tensor1D
|
||||
// eslint-disable-next-line no-unused-vars
|
||||
public bias: tf.Tensor1D,
|
||||
// eslint-disable-next-line no-empty-function
|
||||
) {}
|
||||
}
|
|
@ -3,110 +3,115 @@ import * as tf from '../../dist/tfjs.esm.js';
|
|||
import { Dimensions } from '../classes/Dimensions';
|
||||
import { env } from '../env/index';
|
||||
import { padToSquare } from '../ops/padToSquare';
|
||||
import { computeReshapedDimensions, isTensor3D, isTensor4D, range } from '../utils/index';
|
||||
import {
|
||||
computeReshapedDimensions, isTensor3D, isTensor4D, range,
|
||||
} from '../utils/index';
|
||||
import { createCanvasFromMedia } from './createCanvas';
|
||||
import { imageToSquare } from './imageToSquare';
|
||||
import { TResolvedNetInput } from './types';
|
||||
|
||||
export class NetInput {
|
||||
private _imageTensors: Array<tf.Tensor3D | tf.Tensor4D> = []
|
||||
|
||||
private _canvases: HTMLCanvasElement[] = []
|
||||
|
||||
private _batchSize: number
|
||||
|
||||
private _treatAsBatchInput: boolean = false
|
||||
|
||||
private _inputDimensions: number[][] = []
|
||||
|
||||
private _inputSize: number
|
||||
|
||||
constructor(
|
||||
inputs: Array<TResolvedNetInput>,
|
||||
treatAsBatchInput: boolean = false
|
||||
treatAsBatchInput: boolean = false,
|
||||
) {
|
||||
if (!Array.isArray(inputs)) {
|
||||
throw new Error(`NetInput.constructor - expected inputs to be an Array of TResolvedNetInput or to be instanceof tf.Tensor4D, instead have ${inputs}`)
|
||||
throw new Error(`NetInput.constructor - expected inputs to be an Array of TResolvedNetInput or to be instanceof tf.Tensor4D, instead have ${inputs}`);
|
||||
}
|
||||
|
||||
this._treatAsBatchInput = treatAsBatchInput
|
||||
this._batchSize = inputs.length
|
||||
this._treatAsBatchInput = treatAsBatchInput;
|
||||
this._batchSize = inputs.length;
|
||||
|
||||
inputs.forEach((input, idx) => {
|
||||
|
||||
if (isTensor3D(input)) {
|
||||
this._imageTensors[idx] = input
|
||||
this._inputDimensions[idx] = input.shape
|
||||
return
|
||||
this._imageTensors[idx] = input;
|
||||
this._inputDimensions[idx] = input.shape;
|
||||
return;
|
||||
}
|
||||
|
||||
if (isTensor4D(input)) {
|
||||
const batchSize = (input as any).shape[0]
|
||||
const batchSize = (input as any).shape[0];
|
||||
if (batchSize !== 1) {
|
||||
throw new Error(`NetInput - tf.Tensor4D with batchSize ${batchSize} passed, but not supported in input array`)
|
||||
throw new Error(`NetInput - tf.Tensor4D with batchSize ${batchSize} passed, but not supported in input array`);
|
||||
}
|
||||
|
||||
this._imageTensors[idx] = input
|
||||
this._inputDimensions[idx] = (input as any).shape.slice(1)
|
||||
return
|
||||
this._imageTensors[idx] = input;
|
||||
this._inputDimensions[idx] = (input as any).shape.slice(1);
|
||||
return;
|
||||
}
|
||||
|
||||
const canvas = (input as any) instanceof env.getEnv().Canvas ? input : createCanvasFromMedia(input)
|
||||
this._canvases[idx] = canvas
|
||||
this._inputDimensions[idx] = [canvas.height, canvas.width, 3]
|
||||
})
|
||||
const canvas = (input as any) instanceof env.getEnv().Canvas ? input : createCanvasFromMedia(input);
|
||||
this._canvases[idx] = canvas;
|
||||
this._inputDimensions[idx] = [canvas.height, canvas.width, 3];
|
||||
});
|
||||
}
|
||||
|
||||
public get imageTensors(): Array<tf.Tensor3D | tf.Tensor4D> {
|
||||
return this._imageTensors
|
||||
return this._imageTensors;
|
||||
}
|
||||
|
||||
public get canvases(): HTMLCanvasElement[] {
|
||||
return this._canvases
|
||||
return this._canvases;
|
||||
}
|
||||
|
||||
public get isBatchInput(): boolean {
|
||||
return this.batchSize > 1 || this._treatAsBatchInput
|
||||
return this.batchSize > 1 || this._treatAsBatchInput;
|
||||
}
|
||||
|
||||
public get batchSize(): number {
|
||||
return this._batchSize
|
||||
return this._batchSize;
|
||||
}
|
||||
|
||||
public get inputDimensions(): number[][] {
|
||||
return this._inputDimensions
|
||||
return this._inputDimensions;
|
||||
}
|
||||
|
||||
public get inputSize(): number | undefined {
|
||||
return this._inputSize
|
||||
return this._inputSize;
|
||||
}
|
||||
|
||||
public get reshapedInputDimensions(): Dimensions[] {
|
||||
return range(this.batchSize, 0, 1).map(
|
||||
(_, batchIdx) => this.getReshapedInputDimensions(batchIdx)
|
||||
)
|
||||
(_, batchIdx) => this.getReshapedInputDimensions(batchIdx),
|
||||
);
|
||||
}
|
||||
|
||||
public getInput(batchIdx: number): tf.Tensor3D | tf.Tensor4D | HTMLCanvasElement {
|
||||
return this.canvases[batchIdx] || this.imageTensors[batchIdx]
|
||||
return this.canvases[batchIdx] || this.imageTensors[batchIdx];
|
||||
}
|
||||
|
||||
public getInputDimensions(batchIdx: number): number[] {
|
||||
return this._inputDimensions[batchIdx]
|
||||
return this._inputDimensions[batchIdx];
|
||||
}
|
||||
|
||||
public getInputHeight(batchIdx: number): number {
|
||||
return this._inputDimensions[batchIdx][0]
|
||||
return this._inputDimensions[batchIdx][0];
|
||||
}
|
||||
|
||||
public getInputWidth(batchIdx: number): number {
|
||||
return this._inputDimensions[batchIdx][1]
|
||||
return this._inputDimensions[batchIdx][1];
|
||||
}
|
||||
|
||||
public getReshapedInputDimensions(batchIdx: number): Dimensions {
|
||||
if (typeof this.inputSize !== 'number') {
|
||||
throw new Error('getReshapedInputDimensions - inputSize not set, toBatchTensor has not been called yet')
|
||||
throw new Error('getReshapedInputDimensions - inputSize not set, toBatchTensor has not been called yet');
|
||||
}
|
||||
|
||||
const width = this.getInputWidth(batchIdx)
|
||||
const height = this.getInputHeight(batchIdx)
|
||||
return computeReshapedDimensions({ width, height }, this.inputSize)
|
||||
const width = this.getInputWidth(batchIdx);
|
||||
const height = this.getInputHeight(batchIdx);
|
||||
return computeReshapedDimensions({ width, height }, this.inputSize);
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -119,39 +124,37 @@ export class NetInput {
|
|||
* @returns The batch tensor.
|
||||
*/
|
||||
public toBatchTensor(inputSize: number, isCenterInputs: boolean = true): tf.Tensor4D {
|
||||
|
||||
this._inputSize = inputSize
|
||||
this._inputSize = inputSize;
|
||||
|
||||
return tf.tidy(() => {
|
||||
|
||||
const inputTensors = range(this.batchSize, 0, 1).map(batchIdx => {
|
||||
const input = this.getInput(batchIdx)
|
||||
const inputTensors = range(this.batchSize, 0, 1).map((batchIdx) => {
|
||||
const input = this.getInput(batchIdx);
|
||||
|
||||
if (input instanceof tf.Tensor) {
|
||||
// @ts-ignore: error TS2344: Type 'Rank.R4' does not satisfy the constraint 'Tensor<Rank>'.
|
||||
let imgTensor = isTensor4D(input) ? input : input.expandDims<tf.Rank.R4>()
|
||||
let imgTensor = isTensor4D(input) ? input : input.expandDims<tf.Rank.R4>();
|
||||
// @ts-ignore: error TS2344: Type 'Rank.R4' does not satisfy the constraint 'Tensor<Rank>'.
|
||||
imgTensor = padToSquare(imgTensor, isCenterInputs)
|
||||
imgTensor = padToSquare(imgTensor, isCenterInputs);
|
||||
|
||||
if (imgTensor.shape[1] !== inputSize || imgTensor.shape[2] !== inputSize) {
|
||||
imgTensor = tf.image.resizeBilinear(imgTensor, [inputSize, inputSize])
|
||||
imgTensor = tf.image.resizeBilinear(imgTensor, [inputSize, inputSize]);
|
||||
}
|
||||
|
||||
return imgTensor.as3D(inputSize, inputSize, 3)
|
||||
return imgTensor.as3D(inputSize, inputSize, 3);
|
||||
}
|
||||
|
||||
if (input instanceof env.getEnv().Canvas) {
|
||||
return tf.browser.fromPixels(imageToSquare(input, inputSize, isCenterInputs))
|
||||
return tf.browser.fromPixels(imageToSquare(input, inputSize, isCenterInputs));
|
||||
}
|
||||
|
||||
throw new Error(`toBatchTensor - at batchIdx ${batchIdx}, expected input to be instanceof tf.Tensor or instanceof HTMLCanvasElement, instead have ${input}`)
|
||||
})
|
||||
throw new Error(`toBatchTensor - at batchIdx ${batchIdx}, expected input to be instanceof tf.Tensor or instanceof HTMLCanvasElement, instead have ${input}`);
|
||||
});
|
||||
|
||||
// const batchTensor = tf.stack(inputTensors.map(t => t.toFloat())).as4D(this.batchSize, inputSize, inputSize, 3)
|
||||
const batchTensor = tf.stack(inputTensors.map(t => tf.cast(t, 'float32'))).as4D(this.batchSize, inputSize, inputSize, 3)
|
||||
const batchTensor = tf.stack(inputTensors.map((t) => tf.cast(t, 'float32'))).as4D(this.batchSize, inputSize, inputSize, 3);
|
||||
// const batchTensor = tf.stack(inputTensors.map(t => tf.Tensor.as4D(tf.cast(t, 'float32'))), this.batchSize, inputSize, inputSize, 3);
|
||||
|
||||
return batchTensor
|
||||
})
|
||||
return batchTensor;
|
||||
});
|
||||
}
|
||||
}
|
|
@ -2,27 +2,28 @@ import { env } from '../env/index';
|
|||
import { isMediaLoaded } from './isMediaLoaded';
|
||||
|
||||
export function awaitMediaLoaded(media: HTMLImageElement | HTMLVideoElement | HTMLCanvasElement) {
|
||||
|
||||
// eslint-disable-next-line consistent-return
|
||||
return new Promise((resolve, reject) => {
|
||||
if (media instanceof env.getEnv().Canvas || isMediaLoaded(media)) {
|
||||
return resolve(null)
|
||||
}
|
||||
|
||||
function onLoad(e: Event) {
|
||||
if (!e.currentTarget) return
|
||||
e.currentTarget.removeEventListener('load', onLoad)
|
||||
e.currentTarget.removeEventListener('error', onError)
|
||||
resolve(e)
|
||||
return resolve(null);
|
||||
}
|
||||
|
||||
function onError(e: Event) {
|
||||
if (!e.currentTarget) return
|
||||
e.currentTarget.removeEventListener('load', onLoad)
|
||||
e.currentTarget.removeEventListener('error', onError)
|
||||
reject(e)
|
||||
if (!e.currentTarget) return;
|
||||
// eslint-disable-next-line no-use-before-define
|
||||
e.currentTarget.removeEventListener('load', onLoad);
|
||||
e.currentTarget.removeEventListener('error', onError);
|
||||
reject(e);
|
||||
}
|
||||
|
||||
media.addEventListener('load', onLoad)
|
||||
media.addEventListener('error', onError)
|
||||
})
|
||||
function onLoad(e: Event) {
|
||||
if (!e.currentTarget) return;
|
||||
e.currentTarget.removeEventListener('load', onLoad);
|
||||
e.currentTarget.removeEventListener('error', onError);
|
||||
resolve(e);
|
||||
}
|
||||
|
||||
media.addEventListener('load', onLoad);
|
||||
media.addEventListener('error', onError);
|
||||
});
|
||||
}
|
|
@ -2,22 +2,16 @@ import { env } from '../env/index';
|
|||
|
||||
export function bufferToImage(buf: Blob): Promise<HTMLImageElement> {
|
||||
return new Promise((resolve, reject) => {
|
||||
if (!(buf instanceof Blob)) {
|
||||
return reject('bufferToImage - expected buf to be of type: Blob')
|
||||
}
|
||||
|
||||
const reader = new FileReader()
|
||||
if (!(buf instanceof Blob)) reject(new Error('bufferToImage - expected buf to be of type: Blob'));
|
||||
const reader = new FileReader();
|
||||
reader.onload = () => {
|
||||
if (typeof reader.result !== 'string') {
|
||||
return reject('bufferToImage - expected reader.result to be a string, in onload')
|
||||
}
|
||||
|
||||
const img = env.getEnv().createImageElement()
|
||||
img.onload = () => resolve(img)
|
||||
img.onerror = reject
|
||||
img.src = reader.result
|
||||
}
|
||||
reader.onerror = reject
|
||||
reader.readAsDataURL(buf)
|
||||
})
|
||||
if (typeof reader.result !== 'string') reject(new Error('bufferToImage - expected reader.result to be a string, in onload'));
|
||||
const img = env.getEnv().createImageElement();
|
||||
img.onload = () => resolve(img);
|
||||
img.onerror = reject;
|
||||
img.src = reader.result as string;
|
||||
};
|
||||
reader.onerror = reject;
|
||||
reader.readAsDataURL(buf);
|
||||
});
|
||||
}
|
|
@ -5,29 +5,27 @@ import { getMediaDimensions } from './getMediaDimensions';
|
|||
import { isMediaLoaded } from './isMediaLoaded';
|
||||
|
||||
export function createCanvas({ width, height }: IDimensions): HTMLCanvasElement {
|
||||
|
||||
const { createCanvasElement } = env.getEnv()
|
||||
const canvas = createCanvasElement()
|
||||
canvas.width = width
|
||||
canvas.height = height
|
||||
return canvas
|
||||
const { createCanvasElement } = env.getEnv();
|
||||
const canvas = createCanvasElement();
|
||||
canvas.width = width;
|
||||
canvas.height = height;
|
||||
return canvas;
|
||||
}
|
||||
|
||||
export function createCanvasFromMedia(media: HTMLImageElement | HTMLVideoElement | ImageData, dims?: IDimensions): HTMLCanvasElement {
|
||||
|
||||
const { ImageData } = env.getEnv()
|
||||
const { ImageData } = env.getEnv();
|
||||
|
||||
if (!(media instanceof ImageData) && !isMediaLoaded(media)) {
|
||||
throw new Error('createCanvasFromMedia - media has not finished loading yet')
|
||||
throw new Error('createCanvasFromMedia - media has not finished loading yet');
|
||||
}
|
||||
|
||||
const { width, height } = dims || getMediaDimensions(media)
|
||||
const canvas = createCanvas({ width, height })
|
||||
const { width, height } = dims || getMediaDimensions(media);
|
||||
const canvas = createCanvas({ width, height });
|
||||
|
||||
if (media instanceof ImageData) {
|
||||
getContext2dOrThrow(canvas).putImageData(media, 0, 0)
|
||||
getContext2dOrThrow(canvas).putImageData(media, 0, 0);
|
||||
} else {
|
||||
getContext2dOrThrow(canvas).drawImage(media, 0, 0, width, height)
|
||||
getContext2dOrThrow(canvas).drawImage(media, 0, 0, width, height);
|
||||
}
|
||||
return canvas
|
||||
return canvas;
|
||||
}
|
|
@ -16,31 +16,30 @@ import { isTensor3D, isTensor4D } from '../utils/index';
|
|||
*/
|
||||
export async function extractFaceTensors(
|
||||
imageTensor: tf.Tensor3D | tf.Tensor4D,
|
||||
detections: Array<FaceDetection | Rect>
|
||||
detections: Array<FaceDetection | Rect>,
|
||||
): Promise<tf.Tensor3D[]> {
|
||||
|
||||
if (!isTensor3D(imageTensor) && !isTensor4D(imageTensor)) {
|
||||
throw new Error('extractFaceTensors - expected image tensor to be 3D or 4D')
|
||||
throw new Error('extractFaceTensors - expected image tensor to be 3D or 4D');
|
||||
}
|
||||
|
||||
if (isTensor4D(imageTensor) && imageTensor.shape[0] > 1) {
|
||||
throw new Error('extractFaceTensors - batchSize > 1 not supported')
|
||||
throw new Error('extractFaceTensors - batchSize > 1 not supported');
|
||||
}
|
||||
|
||||
return tf.tidy(() => {
|
||||
const [imgHeight, imgWidth, numChannels] = imageTensor.shape.slice(isTensor4D(imageTensor) ? 1 : 0)
|
||||
const [imgHeight, imgWidth, numChannels] = imageTensor.shape.slice(isTensor4D(imageTensor) ? 1 : 0);
|
||||
|
||||
const boxes = detections.map(
|
||||
det => det instanceof FaceDetection
|
||||
(det) => (det instanceof FaceDetection
|
||||
? det.forSize(imgWidth, imgHeight).box
|
||||
: det
|
||||
: det),
|
||||
)
|
||||
.map(box => box.clipAtImageBorders(imgWidth, imgHeight))
|
||||
.map((box) => box.clipAtImageBorders(imgWidth, imgHeight));
|
||||
|
||||
const faceTensors = boxes.map(({ x, y, width, height }) =>
|
||||
tf.slice3d(imageTensor.as3D(imgHeight, imgWidth, numChannels), [y, x, 0], [height, width, numChannels])
|
||||
)
|
||||
const faceTensors = boxes.map(({
|
||||
x, y, width, height,
|
||||
}) => tf.slice3d(imageTensor.as3D(imgHeight, imgWidth, numChannels), [y, x, 0], [height, width, numChannels]));
|
||||
|
||||
return faceTensors
|
||||
})
|
||||
return faceTensors;
|
||||
});
|
||||
}
|
|
@ -16,38 +16,39 @@ import { TNetInput } from './types';
|
|||
*/
|
||||
export async function extractFaces(
|
||||
input: TNetInput,
|
||||
detections: Array<FaceDetection | Rect>
|
||||
detections: Array<FaceDetection | Rect>,
|
||||
): Promise<HTMLCanvasElement[]> {
|
||||
const { Canvas } = env.getEnv();
|
||||
|
||||
const { Canvas } = env.getEnv()
|
||||
|
||||
let canvas = input as HTMLCanvasElement
|
||||
let canvas = input as HTMLCanvasElement;
|
||||
|
||||
if (!(input instanceof Canvas)) {
|
||||
const netInput = await toNetInput(input)
|
||||
const netInput = await toNetInput(input);
|
||||
|
||||
if (netInput.batchSize > 1) {
|
||||
throw new Error('extractFaces - batchSize > 1 not supported')
|
||||
throw new Error('extractFaces - batchSize > 1 not supported');
|
||||
}
|
||||
|
||||
const tensorOrCanvas = netInput.getInput(0)
|
||||
const tensorOrCanvas = netInput.getInput(0);
|
||||
canvas = tensorOrCanvas instanceof Canvas
|
||||
? tensorOrCanvas
|
||||
: await imageTensorToCanvas(tensorOrCanvas)
|
||||
: await imageTensorToCanvas(tensorOrCanvas);
|
||||
}
|
||||
|
||||
const ctx = getContext2dOrThrow(canvas)
|
||||
const ctx = getContext2dOrThrow(canvas);
|
||||
const boxes = detections.map(
|
||||
det => det instanceof FaceDetection
|
||||
(det) => (det instanceof FaceDetection
|
||||
? det.forSize(canvas.width, canvas.height).box.floor()
|
||||
: det
|
||||
: det),
|
||||
)
|
||||
.map(box => box.clipAtImageBorders(canvas.width, canvas.height))
|
||||
.map((box) => box.clipAtImageBorders(canvas.width, canvas.height));
|
||||
|
||||
return boxes.map(({ x, y, width, height }) => {
|
||||
const faceImg = createCanvas({ width, height })
|
||||
return boxes.map(({
|
||||
x, y, width, height,
|
||||
}) => {
|
||||
const faceImg = createCanvas({ width, height });
|
||||
getContext2dOrThrow(faceImg)
|
||||
.putImageData(ctx.getImageData(x, y, width, height), 0, 0)
|
||||
return faceImg
|
||||
})
|
||||
.putImageData(ctx.getImageData(x, y, width, height), 0, 0);
|
||||
return faceImg;
|
||||
});
|
||||
}
|
|
@ -2,11 +2,11 @@ import { bufferToImage } from './bufferToImage';
|
|||
import { fetchOrThrow } from './fetchOrThrow';
|
||||
|
||||
export async function fetchImage(uri: string): Promise<HTMLImageElement> {
|
||||
const res = await fetchOrThrow(uri)
|
||||
const blob = await (res).blob()
|
||||
const res = await fetchOrThrow(uri);
|
||||
const blob = await (res).blob();
|
||||
|
||||
if (!blob.type.startsWith('image/')) {
|
||||
throw new Error(`fetchImage - expected blob type to be of type image/*, instead have: ${blob.type}, for url: ${res.url}`)
|
||||
throw new Error(`fetchImage - expected blob type to be of type image/*, instead have: ${blob.type}, for url: ${res.url}`);
|
||||
}
|
||||
return bufferToImage(blob)
|
||||
return bufferToImage(blob);
|
||||
}
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
import { fetchOrThrow } from './fetchOrThrow';
|
||||
|
||||
export async function fetchJson<T>(uri: string): Promise<T> {
|
||||
return (await fetchOrThrow(uri)).json()
|
||||
return (await fetchOrThrow(uri)).json();
|
||||
}
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
import { fetchOrThrow } from './fetchOrThrow';
|
||||
|
||||
export async function fetchNetWeights(uri: string): Promise<Float32Array> {
|
||||
return new Float32Array(await (await fetchOrThrow(uri)).arrayBuffer())
|
||||
return new Float32Array(await (await fetchOrThrow(uri)).arrayBuffer());
|
||||
}
|
||||
|
|
|
@ -2,13 +2,13 @@ import { env } from '../env/index';
|
|||
|
||||
export async function fetchOrThrow(
|
||||
url: string,
|
||||
init?: RequestInit
|
||||
// eslint-disable-next-line no-undef
|
||||
init?: RequestInit,
|
||||
): Promise<Response> {
|
||||
|
||||
const fetch = env.getEnv().fetch
|
||||
const res = await fetch(url, init)
|
||||
const { fetch } = env.getEnv();
|
||||
const res = await fetch(url, init);
|
||||
if (!(res.status < 400)) {
|
||||
throw new Error(`failed to fetch: (${res.status}) ${res.statusText}, from url: ${res.url}`)
|
||||
throw new Error(`failed to fetch: (${res.status}) ${res.statusText}, from url: ${res.url}`);
|
||||
}
|
||||
return res
|
||||
return res;
|
||||
}
|
|
@ -2,23 +2,22 @@ import { env } from '../env/index';
|
|||
import { resolveInput } from './resolveInput';
|
||||
|
||||
export function getContext2dOrThrow(canvasArg: string | HTMLCanvasElement | CanvasRenderingContext2D): CanvasRenderingContext2D {
|
||||
|
||||
const { Canvas, CanvasRenderingContext2D } = env.getEnv()
|
||||
const { Canvas, CanvasRenderingContext2D } = env.getEnv();
|
||||
|
||||
if (canvasArg instanceof CanvasRenderingContext2D) {
|
||||
return canvasArg
|
||||
return canvasArg;
|
||||
}
|
||||
|
||||
const canvas = resolveInput(canvasArg)
|
||||
const canvas = resolveInput(canvasArg);
|
||||
|
||||
if (!(canvas instanceof Canvas)) {
|
||||
throw new Error('resolveContext2d - expected canvas to be of instance of Canvas')
|
||||
throw new Error('resolveContext2d - expected canvas to be of instance of Canvas');
|
||||
}
|
||||
|
||||
const ctx = canvas.getContext('2d')
|
||||
const ctx = canvas.getContext('2d');
|
||||
if (!ctx) {
|
||||
throw new Error('resolveContext2d - canvas 2d context is null')
|
||||
throw new Error('resolveContext2d - canvas 2d context is null');
|
||||
}
|
||||
|
||||
return ctx
|
||||
return ctx;
|
||||
}
|
|
@ -2,14 +2,13 @@ import { Dimensions, IDimensions } from '../classes/Dimensions';
|
|||
import { env } from '../env/index';
|
||||
|
||||
export function getMediaDimensions(input: HTMLImageElement | HTMLCanvasElement | HTMLVideoElement | IDimensions): Dimensions {
|
||||
|
||||
const { Image, Video } = env.getEnv()
|
||||
const { Image, Video } = env.getEnv();
|
||||
|
||||
if (input instanceof Image) {
|
||||
return new Dimensions(input.naturalWidth, input.naturalHeight)
|
||||
return new Dimensions(input.naturalWidth, input.naturalHeight);
|
||||
}
|
||||
if (input instanceof Video) {
|
||||
return new Dimensions(input.videoWidth, input.videoHeight)
|
||||
return new Dimensions(input.videoWidth, input.videoHeight);
|
||||
}
|
||||
return new Dimensions(input.width, input.height)
|
||||
return new Dimensions(input.width, input.height);
|
||||
}
|
||||
|
|
|
@ -5,16 +5,15 @@ import { isTensor4D } from '../utils/index';
|
|||
|
||||
export async function imageTensorToCanvas(
|
||||
imgTensor: tf.Tensor,
|
||||
canvas?: HTMLCanvasElement
|
||||
canvas?: HTMLCanvasElement,
|
||||
): Promise<HTMLCanvasElement> {
|
||||
const targetCanvas = canvas || env.getEnv().createCanvasElement();
|
||||
|
||||
const targetCanvas = canvas || env.getEnv().createCanvasElement()
|
||||
const [height, width, numChannels] = imgTensor.shape.slice(isTensor4D(imgTensor) ? 1 : 0);
|
||||
const imgTensor3D = tf.tidy(() => imgTensor.as3D(height, width, numChannels).toInt());
|
||||
await tf.browser.toPixels(imgTensor3D, targetCanvas);
|
||||
|
||||
const [height, width, numChannels] = imgTensor.shape.slice(isTensor4D(imgTensor) ? 1 : 0)
|
||||
const imgTensor3D = tf.tidy(() => imgTensor.as3D(height, width, numChannels).toInt())
|
||||
await tf.browser.toPixels(imgTensor3D, targetCanvas)
|
||||
imgTensor3D.dispose();
|
||||
|
||||
imgTensor3D.dispose()
|
||||
|
||||
return targetCanvas
|
||||
return targetCanvas;
|
||||
}
|
|
@ -4,25 +4,24 @@ import { getContext2dOrThrow } from './getContext2dOrThrow';
|
|||
import { getMediaDimensions } from './getMediaDimensions';
|
||||
|
||||
export function imageToSquare(input: HTMLImageElement | HTMLCanvasElement, inputSize: number, centerImage: boolean = false) {
|
||||
|
||||
const { Image, Canvas } = env.getEnv()
|
||||
const { Image, Canvas } = env.getEnv();
|
||||
|
||||
if (!(input instanceof Image || input instanceof Canvas)) {
|
||||
throw new Error('imageToSquare - expected arg0 to be HTMLImageElement | HTMLCanvasElement')
|
||||
throw new Error('imageToSquare - expected arg0 to be HTMLImageElement | HTMLCanvasElement');
|
||||
}
|
||||
|
||||
const dims = getMediaDimensions(input)
|
||||
const scale = inputSize / Math.max(dims.height, dims.width)
|
||||
const width = scale * dims.width
|
||||
const height = scale * dims.height
|
||||
const dims = getMediaDimensions(input);
|
||||
const scale = inputSize / Math.max(dims.height, dims.width);
|
||||
const width = scale * dims.width;
|
||||
const height = scale * dims.height;
|
||||
|
||||
const targetCanvas = createCanvas({ width: inputSize, height: inputSize })
|
||||
const inputCanvas = input instanceof Canvas ? input : createCanvasFromMedia(input)
|
||||
const targetCanvas = createCanvas({ width: inputSize, height: inputSize });
|
||||
const inputCanvas = input instanceof Canvas ? input : createCanvasFromMedia(input);
|
||||
|
||||
const offset = Math.abs(width - height) / 2
|
||||
const dx = centerImage && width < height ? offset : 0
|
||||
const dy = centerImage && height < width ? offset : 0
|
||||
getContext2dOrThrow(targetCanvas).drawImage(inputCanvas, dx, dy, width, height)
|
||||
const offset = Math.abs(width - height) / 2;
|
||||
const dx = centerImage && width < height ? offset : 0;
|
||||
const dy = centerImage && height < width ? offset : 0;
|
||||
getContext2dOrThrow(targetCanvas).drawImage(inputCanvas, dx, dy, width, height);
|
||||
|
||||
return targetCanvas
|
||||
return targetCanvas;
|
||||
}
|
|
@ -1,21 +1,21 @@
|
|||
export * from './awaitMediaLoaded'
|
||||
export * from './bufferToImage'
|
||||
export * from './createCanvas'
|
||||
export * from './extractFaces'
|
||||
export * from './extractFaceTensors'
|
||||
export * from './fetchImage'
|
||||
export * from './fetchJson'
|
||||
export * from './fetchNetWeights'
|
||||
export * from './fetchOrThrow'
|
||||
export * from './getContext2dOrThrow'
|
||||
export * from './getMediaDimensions'
|
||||
export * from './imageTensorToCanvas'
|
||||
export * from './imageToSquare'
|
||||
export * from './isMediaElement'
|
||||
export * from './isMediaLoaded'
|
||||
export * from './loadWeightMap'
|
||||
export * from './matchDimensions'
|
||||
export * from './NetInput'
|
||||
export * from './resolveInput'
|
||||
export * from './toNetInput'
|
||||
export * from './types'
|
||||
export * from './awaitMediaLoaded';
|
||||
export * from './bufferToImage';
|
||||
export * from './createCanvas';
|
||||
export * from './extractFaces';
|
||||
export * from './extractFaceTensors';
|
||||
export * from './fetchImage';
|
||||
export * from './fetchJson';
|
||||
export * from './fetchNetWeights';
|
||||
export * from './fetchOrThrow';
|
||||
export * from './getContext2dOrThrow';
|
||||
export * from './getMediaDimensions';
|
||||
export * from './imageTensorToCanvas';
|
||||
export * from './imageToSquare';
|
||||
export * from './isMediaElement';
|
||||
export * from './isMediaLoaded';
|
||||
export * from './loadWeightMap';
|
||||
export * from './matchDimensions';
|
||||
export * from './NetInput';
|
||||
export * from './resolveInput';
|
||||
export * from './toNetInput';
|
||||
export * from './types';
|
||||
|
|
|
@ -1,10 +1,9 @@
|
|||
import { env } from '../env/index';
|
||||
|
||||
export function isMediaElement(input: any) {
|
||||
|
||||
const { Image, Canvas, Video } = env.getEnv()
|
||||
const { Image, Canvas, Video } = env.getEnv();
|
||||
|
||||
return input instanceof Image
|
||||
|| input instanceof Canvas
|
||||
|| input instanceof Video
|
||||
|| input instanceof Video;
|
||||
}
|
|
@ -1,9 +1,8 @@
|
|||
import { env } from '../env/index';
|
||||
|
||||
export function isMediaLoaded(media: HTMLImageElement | HTMLVideoElement) : boolean {
|
||||
|
||||
const { Image, Video } = env.getEnv()
|
||||
const { Image, Video } = env.getEnv();
|
||||
|
||||
return (media instanceof Image && media.complete)
|
||||
|| (media instanceof Video && media.readyState >= 3)
|
||||
|| (media instanceof Video && media.readyState >= 3);
|
||||
}
|
||||
|
|
|
@ -7,8 +7,8 @@ export async function loadWeightMap(
|
|||
uri: string | undefined,
|
||||
defaultModelName: string,
|
||||
): Promise<tf.NamedTensorMap> {
|
||||
const { manifestUri, modelBaseUri } = getModelUris(uri, defaultModelName)
|
||||
let manifest = await fetchJson<tf.io.WeightsManifestConfig>(manifestUri)
|
||||
const { manifestUri, modelBaseUri } = getModelUris(uri, defaultModelName);
|
||||
const manifest = await fetchJson<tf.io.WeightsManifestConfig>(manifestUri);
|
||||
// if (manifest['weightsManifest']) manifest = manifest['weightsManifest'];
|
||||
return tf.io.loadWeights(manifest, modelBaseUri)
|
||||
return tf.io.loadWeights(manifest, modelBaseUri);
|
||||
}
|
|
@ -4,8 +4,8 @@ import { getMediaDimensions } from './getMediaDimensions';
|
|||
export function matchDimensions(input: IDimensions, reference: IDimensions, useMediaDimensions: boolean = false) {
|
||||
const { width, height } = useMediaDimensions
|
||||
? getMediaDimensions(reference)
|
||||
: reference
|
||||
input.width = width
|
||||
input.height = height
|
||||
return { width, height }
|
||||
: reference;
|
||||
input.width = width;
|
||||
input.height = height;
|
||||
return { width, height };
|
||||
}
|
|
@ -2,7 +2,7 @@ import { env } from '../env/index';
|
|||
|
||||
export function resolveInput(arg: string | any) {
|
||||
if (!env.isNodejs() && typeof arg === 'string') {
|
||||
return document.getElementById(arg)
|
||||
return document.getElementById(arg);
|
||||
}
|
||||
return arg
|
||||
return arg;
|
||||
}
|
|
@ -14,44 +14,43 @@ import { TNetInput } from './types';
|
|||
*/
|
||||
export async function toNetInput(inputs: TNetInput): Promise<NetInput> {
|
||||
if (inputs instanceof NetInput) {
|
||||
return inputs
|
||||
return inputs;
|
||||
}
|
||||
|
||||
let inputArgArray = Array.isArray(inputs)
|
||||
const inputArgArray = Array.isArray(inputs)
|
||||
? inputs
|
||||
: [inputs]
|
||||
: [inputs];
|
||||
|
||||
if (!inputArgArray.length) {
|
||||
throw new Error('toNetInput - empty array passed as input')
|
||||
throw new Error('toNetInput - empty array passed as input');
|
||||
}
|
||||
|
||||
const getIdxHint = (idx: number) => Array.isArray(inputs) ? ` at input index ${idx}:` : ''
|
||||
const getIdxHint = (idx: number) => (Array.isArray(inputs) ? ` at input index ${idx}:` : '');
|
||||
|
||||
const inputArray = inputArgArray.map(resolveInput)
|
||||
const inputArray = inputArgArray.map(resolveInput);
|
||||
|
||||
inputArray.forEach((input, i) => {
|
||||
if (!isMediaElement(input) && !isTensor3D(input) && !isTensor4D(input)) {
|
||||
|
||||
if (typeof inputArgArray[i] === 'string') {
|
||||
throw new Error(`toNetInput -${getIdxHint(i)} string passed, but could not resolve HTMLElement for element id ${inputArgArray[i]}`)
|
||||
throw new Error(`toNetInput -${getIdxHint(i)} string passed, but could not resolve HTMLElement for element id ${inputArgArray[i]}`);
|
||||
}
|
||||
|
||||
throw new Error(`toNetInput -${getIdxHint(i)} expected media to be of type HTMLImageElement | HTMLVideoElement | HTMLCanvasElement | tf.Tensor3D, or to be an element id`)
|
||||
throw new Error(`toNetInput -${getIdxHint(i)} expected media to be of type HTMLImageElement | HTMLVideoElement | HTMLCanvasElement | tf.Tensor3D, or to be an element id`);
|
||||
}
|
||||
|
||||
if (isTensor4D(input)) {
|
||||
// if tf.Tensor4D is passed in the input array, the batch size has to be 1
|
||||
const batchSize = input.shape[0]
|
||||
const batchSize = input.shape[0];
|
||||
if (batchSize !== 1) {
|
||||
throw new Error(`toNetInput -${getIdxHint(i)} tf.Tensor4D with batchSize ${batchSize} passed, but not supported in input array`)
|
||||
throw new Error(`toNetInput -${getIdxHint(i)} tf.Tensor4D with batchSize ${batchSize} passed, but not supported in input array`);
|
||||
}
|
||||
}
|
||||
})
|
||||
});
|
||||
|
||||
// wait for all media elements being loaded
|
||||
await Promise.all(
|
||||
inputArray.map(input => isMediaElement(input) && awaitMediaLoaded(input))
|
||||
)
|
||||
inputArray.map((input) => isMediaElement(input) && awaitMediaLoaded(input)),
|
||||
);
|
||||
|
||||
return new NetInput(inputArray, Array.isArray(inputs))
|
||||
return new NetInput(inputArray, Array.isArray(inputs));
|
||||
}
|
|
@ -1,6 +1,9 @@
|
|||
/* eslint-disable max-classes-per-file */
|
||||
import { Box, IBoundingBox, IRect } from '../classes/index';
|
||||
import { getContext2dOrThrow } from '../dom/getContext2dOrThrow';
|
||||
import { AnchorPosition, DrawTextField, DrawTextFieldOptions, IDrawTextFieldOptions } from './DrawTextField';
|
||||
import {
|
||||
AnchorPosition, DrawTextField, DrawTextFieldOptions, IDrawTextFieldOptions,
|
||||
} from './DrawTextField';
|
||||
|
||||
export interface IDrawBoxOptions {
|
||||
boxColor?: string
|
||||
|
@ -11,49 +14,57 @@ export interface IDrawBoxOptions {
|
|||
|
||||
export class DrawBoxOptions {
|
||||
public boxColor: string
|
||||
|
||||
public lineWidth: number
|
||||
|
||||
public drawLabelOptions: DrawTextFieldOptions
|
||||
|
||||
public label?: string
|
||||
|
||||
constructor(options: IDrawBoxOptions = {}) {
|
||||
const { boxColor, lineWidth, label, drawLabelOptions } = options
|
||||
this.boxColor = boxColor || 'rgba(0, 0, 255, 1)'
|
||||
this.lineWidth = lineWidth || 2
|
||||
this.label = label
|
||||
const {
|
||||
boxColor, lineWidth, label, drawLabelOptions,
|
||||
} = options;
|
||||
this.boxColor = boxColor || 'rgba(0, 0, 255, 1)';
|
||||
this.lineWidth = lineWidth || 2;
|
||||
this.label = label;
|
||||
|
||||
const defaultDrawLabelOptions = {
|
||||
anchorPosition: AnchorPosition.BOTTOM_LEFT,
|
||||
backgroundColor: this.boxColor
|
||||
}
|
||||
this.drawLabelOptions = new DrawTextFieldOptions(Object.assign({}, defaultDrawLabelOptions, drawLabelOptions))
|
||||
backgroundColor: this.boxColor,
|
||||
};
|
||||
this.drawLabelOptions = new DrawTextFieldOptions({ ...defaultDrawLabelOptions, ...drawLabelOptions });
|
||||
}
|
||||
}
|
||||
|
||||
export class DrawBox {
|
||||
public box: Box
|
||||
|
||||
public options: DrawBoxOptions
|
||||
|
||||
constructor(
|
||||
box: IBoundingBox | IRect,
|
||||
options: IDrawBoxOptions = {}
|
||||
options: IDrawBoxOptions = {},
|
||||
) {
|
||||
this.box = new Box(box)
|
||||
this.options = new DrawBoxOptions(options)
|
||||
this.box = new Box(box);
|
||||
this.options = new DrawBoxOptions(options);
|
||||
}
|
||||
|
||||
draw(canvasArg: string | HTMLCanvasElement | CanvasRenderingContext2D) {
|
||||
const ctx = getContext2dOrThrow(canvasArg)
|
||||
const ctx = getContext2dOrThrow(canvasArg);
|
||||
|
||||
const { boxColor, lineWidth } = this.options
|
||||
const { boxColor, lineWidth } = this.options;
|
||||
|
||||
const { x, y, width, height } = this.box
|
||||
ctx.strokeStyle = boxColor
|
||||
ctx.lineWidth = lineWidth
|
||||
ctx.strokeRect(x, y, width, height)
|
||||
const {
|
||||
x, y, width, height,
|
||||
} = this.box;
|
||||
ctx.strokeStyle = boxColor;
|
||||
ctx.lineWidth = lineWidth;
|
||||
ctx.strokeRect(x, y, width, height);
|
||||
|
||||
const { label } = this.options
|
||||
const { label } = this.options;
|
||||
if (label) {
|
||||
new DrawTextField([label], { x: x - (lineWidth / 2), y }, this.options.drawLabelOptions).draw(canvasArg)
|
||||
new DrawTextField([label], { x: x - (lineWidth / 2), y }, this.options.drawLabelOptions).draw(canvasArg);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,3 +1,4 @@
|
|||
/* eslint-disable max-classes-per-file */
|
||||
import { IPoint } from '../classes/index';
|
||||
import { FaceLandmarks } from '../classes/FaceLandmarks';
|
||||
import { FaceLandmarks68 } from '../classes/FaceLandmarks68';
|
||||
|
@ -17,62 +18,72 @@ export interface IDrawFaceLandmarksOptions {
|
|||
|
||||
export class DrawFaceLandmarksOptions {
|
||||
public drawLines: boolean
|
||||
|
||||
public drawPoints: boolean
|
||||
|
||||
public lineWidth: number
|
||||
|
||||
public pointSize: number
|
||||
|
||||
public lineColor: string
|
||||
|
||||
public pointColor: string
|
||||
|
||||
constructor(options: IDrawFaceLandmarksOptions = {}) {
|
||||
const { drawLines = true, drawPoints = true, lineWidth, lineColor, pointSize, pointColor } = options
|
||||
this.drawLines = drawLines
|
||||
this.drawPoints = drawPoints
|
||||
this.lineWidth = lineWidth || 1
|
||||
this.pointSize = pointSize || 2
|
||||
this.lineColor = lineColor || 'rgba(0, 255, 255, 1)'
|
||||
this.pointColor = pointColor || 'rgba(255, 0, 255, 1)'
|
||||
const {
|
||||
drawLines = true, drawPoints = true, lineWidth, lineColor, pointSize, pointColor,
|
||||
} = options;
|
||||
this.drawLines = drawLines;
|
||||
this.drawPoints = drawPoints;
|
||||
this.lineWidth = lineWidth || 1;
|
||||
this.pointSize = pointSize || 2;
|
||||
this.lineColor = lineColor || 'rgba(0, 255, 255, 1)';
|
||||
this.pointColor = pointColor || 'rgba(255, 0, 255, 1)';
|
||||
}
|
||||
}
|
||||
|
||||
export class DrawFaceLandmarks {
|
||||
public faceLandmarks: FaceLandmarks
|
||||
|
||||
public options: DrawFaceLandmarksOptions
|
||||
|
||||
constructor(
|
||||
faceLandmarks: FaceLandmarks,
|
||||
options: IDrawFaceLandmarksOptions = {}
|
||||
options: IDrawFaceLandmarksOptions = {},
|
||||
) {
|
||||
this.faceLandmarks = faceLandmarks
|
||||
this.options = new DrawFaceLandmarksOptions(options)
|
||||
this.faceLandmarks = faceLandmarks;
|
||||
this.options = new DrawFaceLandmarksOptions(options);
|
||||
}
|
||||
|
||||
draw(canvasArg: string | HTMLCanvasElement | CanvasRenderingContext2D) {
|
||||
const ctx = getContext2dOrThrow(canvasArg)
|
||||
const ctx = getContext2dOrThrow(canvasArg);
|
||||
|
||||
const { drawLines, drawPoints, lineWidth, lineColor, pointSize, pointColor } = this.options
|
||||
const {
|
||||
drawLines, drawPoints, lineWidth, lineColor, pointSize, pointColor,
|
||||
} = this.options;
|
||||
|
||||
if (drawLines && this.faceLandmarks instanceof FaceLandmarks68) {
|
||||
ctx.strokeStyle = lineColor
|
||||
ctx.lineWidth = lineWidth
|
||||
drawContour(ctx, this.faceLandmarks.getJawOutline())
|
||||
drawContour(ctx, this.faceLandmarks.getLeftEyeBrow())
|
||||
drawContour(ctx, this.faceLandmarks.getRightEyeBrow())
|
||||
drawContour(ctx, this.faceLandmarks.getNose())
|
||||
drawContour(ctx, this.faceLandmarks.getLeftEye(), true)
|
||||
drawContour(ctx, this.faceLandmarks.getRightEye(), true)
|
||||
drawContour(ctx, this.faceLandmarks.getMouth(), true)
|
||||
ctx.strokeStyle = lineColor;
|
||||
ctx.lineWidth = lineWidth;
|
||||
drawContour(ctx, this.faceLandmarks.getJawOutline());
|
||||
drawContour(ctx, this.faceLandmarks.getLeftEyeBrow());
|
||||
drawContour(ctx, this.faceLandmarks.getRightEyeBrow());
|
||||
drawContour(ctx, this.faceLandmarks.getNose());
|
||||
drawContour(ctx, this.faceLandmarks.getLeftEye(), true);
|
||||
drawContour(ctx, this.faceLandmarks.getRightEye(), true);
|
||||
drawContour(ctx, this.faceLandmarks.getMouth(), true);
|
||||
}
|
||||
|
||||
if (drawPoints) {
|
||||
ctx.strokeStyle = pointColor
|
||||
ctx.fillStyle = pointColor
|
||||
ctx.strokeStyle = pointColor;
|
||||
ctx.fillStyle = pointColor;
|
||||
|
||||
const drawPoint = (pt: IPoint) => {
|
||||
ctx.beginPath()
|
||||
ctx.arc(pt.x, pt.y, pointSize, 0, 2 * Math.PI)
|
||||
ctx.fill()
|
||||
}
|
||||
this.faceLandmarks.positions.forEach(drawPoint)
|
||||
ctx.beginPath();
|
||||
ctx.arc(pt.x, pt.y, pointSize, 0, 2 * Math.PI);
|
||||
ctx.fill();
|
||||
};
|
||||
this.faceLandmarks.positions.forEach(drawPoint);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -81,17 +92,18 @@ export type DrawFaceLandmarksInput = FaceLandmarks | WithFaceLandmarks<WithFaceD
|
|||
|
||||
export function drawFaceLandmarks(
|
||||
canvasArg: string | HTMLCanvasElement,
|
||||
faceLandmarks: DrawFaceLandmarksInput | Array<DrawFaceLandmarksInput>
|
||||
faceLandmarks: DrawFaceLandmarksInput | Array<DrawFaceLandmarksInput>,
|
||||
) {
|
||||
const faceLandmarksArray = Array.isArray(faceLandmarks) ? faceLandmarks : [faceLandmarks]
|
||||
faceLandmarksArray.forEach(f => {
|
||||
const faceLandmarksArray = Array.isArray(faceLandmarks) ? faceLandmarks : [faceLandmarks];
|
||||
faceLandmarksArray.forEach((f) => {
|
||||
// eslint-disable-next-line no-nested-ternary
|
||||
const landmarks = f instanceof FaceLandmarks
|
||||
? f
|
||||
: (isWithFaceLandmarks(f) ? f.landmarks : undefined)
|
||||
: (isWithFaceLandmarks(f) ? f.landmarks : undefined);
|
||||
if (!landmarks) {
|
||||
throw new Error('drawFaceLandmarks - expected faceExpressions to be FaceLandmarks | WithFaceLandmarks<WithFaceDetection<{}>> or array thereof')
|
||||
throw new Error('drawFaceLandmarks - expected faceExpressions to be FaceLandmarks | WithFaceLandmarks<WithFaceDetection<{}>> or array thereof');
|
||||
}
|
||||
|
||||
new DrawFaceLandmarks(landmarks).draw(canvasArg)
|
||||
})
|
||||
new DrawFaceLandmarks(landmarks).draw(canvasArg);
|
||||
});
|
||||
}
|
|
@ -1,11 +1,17 @@
|
|||
/* eslint-disable max-classes-per-file */
|
||||
import { IDimensions, IPoint } from '../classes/index';
|
||||
import { getContext2dOrThrow } from '../dom/getContext2dOrThrow';
|
||||
import { resolveInput } from '../dom/resolveInput';
|
||||
|
||||
// eslint-disable-next-line no-shadow
|
||||
export enum AnchorPosition {
|
||||
// eslint-disable-next-line no-unused-vars
|
||||
TOP_LEFT = 'TOP_LEFT',
|
||||
// eslint-disable-next-line no-unused-vars
|
||||
TOP_RIGHT = 'TOP_RIGHT',
|
||||
// eslint-disable-next-line no-unused-vars
|
||||
BOTTOM_LEFT = 'BOTTOM_LEFT',
|
||||
// eslint-disable-next-line no-unused-vars
|
||||
BOTTOM_RIGHT = 'BOTTOM_RIGHT'
|
||||
}
|
||||
|
||||
|
@ -20,89 +26,101 @@ export interface IDrawTextFieldOptions {
|
|||
|
||||
export class DrawTextFieldOptions implements IDrawTextFieldOptions {
|
||||
public anchorPosition: AnchorPosition
|
||||
|
||||
public backgroundColor: string
|
||||
|
||||
public fontColor: string
|
||||
|
||||
public fontSize: number
|
||||
|
||||
public fontStyle: string
|
||||
|
||||
public padding: number
|
||||
|
||||
constructor(options: IDrawTextFieldOptions = {}) {
|
||||
const { anchorPosition, backgroundColor, fontColor, fontSize, fontStyle, padding } = options
|
||||
this.anchorPosition = anchorPosition || AnchorPosition.TOP_LEFT
|
||||
this.backgroundColor = backgroundColor || 'rgba(0, 0, 0, 0.5)'
|
||||
this.fontColor = fontColor || 'rgba(255, 255, 255, 1)'
|
||||
this.fontSize = fontSize || 14
|
||||
this.fontStyle = fontStyle || 'Georgia'
|
||||
this.padding = padding || 4
|
||||
const {
|
||||
anchorPosition, backgroundColor, fontColor, fontSize, fontStyle, padding,
|
||||
} = options;
|
||||
this.anchorPosition = anchorPosition || AnchorPosition.TOP_LEFT;
|
||||
this.backgroundColor = backgroundColor || 'rgba(0, 0, 0, 0.5)';
|
||||
this.fontColor = fontColor || 'rgba(255, 255, 255, 1)';
|
||||
this.fontSize = fontSize || 14;
|
||||
this.fontStyle = fontStyle || 'Georgia';
|
||||
this.padding = padding || 4;
|
||||
}
|
||||
}
|
||||
|
||||
export class DrawTextField {
|
||||
public text: string[]
|
||||
|
||||
public anchor : IPoint
|
||||
|
||||
public options: DrawTextFieldOptions
|
||||
|
||||
constructor(
|
||||
text: string | string[] | DrawTextField,
|
||||
anchor: IPoint,
|
||||
options: IDrawTextFieldOptions = {}
|
||||
options: IDrawTextFieldOptions = {},
|
||||
) {
|
||||
// eslint-disable-next-line no-nested-ternary
|
||||
this.text = typeof text === 'string'
|
||||
? [text]
|
||||
: (text instanceof DrawTextField ? text.text : text)
|
||||
this.anchor = anchor
|
||||
this.options = new DrawTextFieldOptions(options)
|
||||
: (text instanceof DrawTextField ? text.text : text);
|
||||
this.anchor = anchor;
|
||||
this.options = new DrawTextFieldOptions(options);
|
||||
}
|
||||
|
||||
measureWidth(ctx: CanvasRenderingContext2D): number {
|
||||
const { padding } = this.options
|
||||
return this.text.map(l => ctx.measureText(l).width).reduce((w0, w1) => w0 < w1 ? w1 : w0, 0) + (2 * padding)
|
||||
const { padding } = this.options;
|
||||
return this.text.map((l) => ctx.measureText(l).width).reduce((w0, w1) => (w0 < w1 ? w1 : w0), 0) + (2 * padding);
|
||||
}
|
||||
|
||||
measureHeight(): number {
|
||||
const { fontSize, padding } = this.options
|
||||
return this.text.length * fontSize + (2 * padding)
|
||||
const { fontSize, padding } = this.options;
|
||||
return this.text.length * fontSize + (2 * padding);
|
||||
}
|
||||
|
||||
getUpperLeft(ctx: CanvasRenderingContext2D, canvasDims?: IDimensions): IPoint {
|
||||
const { anchorPosition } = this.options
|
||||
const isShiftLeft = anchorPosition === AnchorPosition.BOTTOM_RIGHT || anchorPosition === AnchorPosition.TOP_RIGHT
|
||||
const isShiftTop = anchorPosition === AnchorPosition.BOTTOM_LEFT || anchorPosition === AnchorPosition.BOTTOM_RIGHT
|
||||
const { anchorPosition } = this.options;
|
||||
const isShiftLeft = anchorPosition === AnchorPosition.BOTTOM_RIGHT || anchorPosition === AnchorPosition.TOP_RIGHT;
|
||||
const isShiftTop = anchorPosition === AnchorPosition.BOTTOM_LEFT || anchorPosition === AnchorPosition.BOTTOM_RIGHT;
|
||||
|
||||
const textFieldWidth = this.measureWidth(ctx)
|
||||
const textFieldHeight = this.measureHeight()
|
||||
const x = (isShiftLeft ? this.anchor.x - textFieldWidth : this.anchor.x)
|
||||
const y = isShiftTop ? this.anchor.y - textFieldHeight : this.anchor.y
|
||||
const textFieldWidth = this.measureWidth(ctx);
|
||||
const textFieldHeight = this.measureHeight();
|
||||
const x = (isShiftLeft ? this.anchor.x - textFieldWidth : this.anchor.x);
|
||||
const y = isShiftTop ? this.anchor.y - textFieldHeight : this.anchor.y;
|
||||
|
||||
// adjust anchor if text box exceeds canvas borders
|
||||
if (canvasDims) {
|
||||
const { width, height } = canvasDims
|
||||
const newX = Math.max(Math.min(x, width - textFieldWidth), 0)
|
||||
const newY = Math.max(Math.min(y, height - textFieldHeight), 0)
|
||||
return { x: newX, y: newY }
|
||||
const { width, height } = canvasDims;
|
||||
const newX = Math.max(Math.min(x, width - textFieldWidth), 0);
|
||||
const newY = Math.max(Math.min(y, height - textFieldHeight), 0);
|
||||
return { x: newX, y: newY };
|
||||
}
|
||||
return { x, y }
|
||||
return { x, y };
|
||||
}
|
||||
|
||||
draw(canvasArg: string | HTMLCanvasElement | CanvasRenderingContext2D) {
|
||||
const canvas = resolveInput(canvasArg)
|
||||
const ctx = getContext2dOrThrow(canvas)
|
||||
const canvas = resolveInput(canvasArg);
|
||||
const ctx = getContext2dOrThrow(canvas);
|
||||
|
||||
const { backgroundColor, fontColor, fontSize, fontStyle, padding } = this.options
|
||||
const {
|
||||
backgroundColor, fontColor, fontSize, fontStyle, padding,
|
||||
} = this.options;
|
||||
|
||||
ctx.font = `${fontSize}px ${fontStyle}`
|
||||
const maxTextWidth = this.measureWidth(ctx)
|
||||
const textHeight = this.measureHeight()
|
||||
ctx.font = `${fontSize}px ${fontStyle}`;
|
||||
const maxTextWidth = this.measureWidth(ctx);
|
||||
const textHeight = this.measureHeight();
|
||||
|
||||
ctx.fillStyle = backgroundColor
|
||||
const upperLeft = this.getUpperLeft(ctx, canvas)
|
||||
ctx.fillRect(upperLeft.x, upperLeft.y, maxTextWidth, textHeight)
|
||||
ctx.fillStyle = backgroundColor;
|
||||
const upperLeft = this.getUpperLeft(ctx, canvas);
|
||||
ctx.fillRect(upperLeft.x, upperLeft.y, maxTextWidth, textHeight);
|
||||
|
||||
ctx.fillStyle = fontColor;
|
||||
this.text.forEach((textLine, i) => {
|
||||
const x = padding + upperLeft.x
|
||||
const y = padding + upperLeft.y + ((i + 1) * fontSize)
|
||||
ctx.fillText(textLine, x, y)
|
||||
})
|
||||
const x = padding + upperLeft.x;
|
||||
const y = padding + upperLeft.y + ((i + 1) * fontSize);
|
||||
ctx.fillText(textLine, x, y);
|
||||
});
|
||||
}
|
||||
}
|
|
@ -3,26 +3,26 @@ import { Point } from '../classes/index';
|
|||
export function drawContour(
|
||||
ctx: CanvasRenderingContext2D,
|
||||
points: Point[],
|
||||
isClosed: boolean = false
|
||||
isClosed: boolean = false,
|
||||
) {
|
||||
ctx.beginPath()
|
||||
ctx.beginPath();
|
||||
|
||||
points.slice(1).forEach(({ x, y }, prevIdx) => {
|
||||
const from = points[prevIdx]
|
||||
ctx.moveTo(from.x, from.y)
|
||||
ctx.lineTo(x, y)
|
||||
})
|
||||
const from = points[prevIdx];
|
||||
ctx.moveTo(from.x, from.y);
|
||||
ctx.lineTo(x, y);
|
||||
});
|
||||
|
||||
if (isClosed) {
|
||||
const from = points[points.length - 1]
|
||||
const to = points[0]
|
||||
const from = points[points.length - 1];
|
||||
const to = points[0];
|
||||
if (!from || !to) {
|
||||
return
|
||||
return;
|
||||
}
|
||||
|
||||
ctx.moveTo(from.x, from.y)
|
||||
ctx.lineTo(to.x, to.y)
|
||||
ctx.moveTo(from.x, from.y);
|
||||
ctx.lineTo(to.x, to.y);
|
||||
}
|
||||
|
||||
ctx.stroke()
|
||||
ctx.stroke();
|
||||
}
|
|
@ -8,20 +8,22 @@ export type TDrawDetectionsInput = IRect | IBoundingBox | FaceDetection | WithFa
|
|||
|
||||
export function drawDetections(
|
||||
canvasArg: string | HTMLCanvasElement,
|
||||
detections: TDrawDetectionsInput | Array<TDrawDetectionsInput>
|
||||
detections: TDrawDetectionsInput | Array<TDrawDetectionsInput>,
|
||||
) {
|
||||
const detectionsArray = Array.isArray(detections) ? detections : [detections]
|
||||
const detectionsArray = Array.isArray(detections) ? detections : [detections];
|
||||
|
||||
detectionsArray.forEach(det => {
|
||||
detectionsArray.forEach((det) => {
|
||||
// eslint-disable-next-line no-nested-ternary
|
||||
const score = det instanceof FaceDetection
|
||||
? det.score
|
||||
: (isWithFaceDetection(det) ? det.detection.score : undefined)
|
||||
: (isWithFaceDetection(det) ? det.detection.score : undefined);
|
||||
|
||||
// eslint-disable-next-line no-nested-ternary
|
||||
const box = det instanceof FaceDetection
|
||||
? det.box
|
||||
: (isWithFaceDetection(det) ? det.detection.box : new Box(det))
|
||||
: (isWithFaceDetection(det) ? det.detection.box : new Box(det));
|
||||
|
||||
const label = score ? `${round(score)}` : undefined
|
||||
new DrawBox(box, { label }).draw(canvasArg)
|
||||
})
|
||||
const label = score ? `${round(score)}` : undefined;
|
||||
new DrawBox(box, { label }).draw(canvasArg);
|
||||
});
|
||||
}
|
|
@ -11,29 +11,30 @@ export function drawFaceExpressions(
|
|||
canvasArg: string | HTMLCanvasElement,
|
||||
faceExpressions: DrawFaceExpressionsInput | Array<DrawFaceExpressionsInput>,
|
||||
minConfidence = 0.1,
|
||||
textFieldAnchor?: IPoint
|
||||
textFieldAnchor?: IPoint,
|
||||
) {
|
||||
const faceExpressionsArray = Array.isArray(faceExpressions) ? faceExpressions : [faceExpressions]
|
||||
const faceExpressionsArray = Array.isArray(faceExpressions) ? faceExpressions : [faceExpressions];
|
||||
|
||||
faceExpressionsArray.forEach(e => {
|
||||
faceExpressionsArray.forEach((e) => {
|
||||
// eslint-disable-next-line no-nested-ternary
|
||||
const expr = e instanceof FaceExpressions
|
||||
? e
|
||||
: (isWithFaceExpressions(e) ? e.expressions : undefined)
|
||||
: (isWithFaceExpressions(e) ? e.expressions : undefined);
|
||||
if (!expr) {
|
||||
throw new Error('drawFaceExpressions - expected faceExpressions to be FaceExpressions | WithFaceExpressions<{}> or array thereof')
|
||||
throw new Error('drawFaceExpressions - expected faceExpressions to be FaceExpressions | WithFaceExpressions<{}> or array thereof');
|
||||
}
|
||||
|
||||
const sorted = expr.asSortedArray()
|
||||
const resultsToDisplay = sorted.filter(expr => expr.probability > minConfidence)
|
||||
const sorted = expr.asSortedArray();
|
||||
const resultsToDisplay = sorted.filter((exprLocal) => exprLocal.probability > minConfidence);
|
||||
|
||||
const anchor = isWithFaceDetection(e)
|
||||
? e.detection.box.bottomLeft
|
||||
: (textFieldAnchor || new Point(0, 0))
|
||||
: (textFieldAnchor || new Point(0, 0));
|
||||
|
||||
const drawTextField = new DrawTextField(
|
||||
resultsToDisplay.map(expr => `${expr.expression} (${round(expr.probability)})`),
|
||||
anchor
|
||||
)
|
||||
drawTextField.draw(canvasArg)
|
||||
})
|
||||
resultsToDisplay.map((exprLocal) => `${exprLocal.expression} (${round(exprLocal.probability)})`),
|
||||
anchor,
|
||||
);
|
||||
drawTextField.draw(canvasArg);
|
||||
});
|
||||
}
|
|
@ -1,6 +1,6 @@
|
|||
export * from './drawContour'
|
||||
export * from './drawDetections'
|
||||
export * from './drawFaceExpressions'
|
||||
export * from './DrawBox'
|
||||
export * from './DrawFaceLandmarks'
|
||||
export * from './DrawTextField'
|
||||
export * from './drawContour';
|
||||
export * from './drawDetections';
|
||||
export * from './drawFaceExpressions';
|
||||
export * from './DrawBox';
|
||||
export * from './DrawFaceLandmarks';
|
||||
export * from './DrawTextField';
|
||||
|
|
|
@ -1,24 +1,22 @@
|
|||
import { Environment } from './types';
|
||||
|
||||
export function createBrowserEnv(): Environment {
|
||||
const fetch = window.fetch;
|
||||
if (!fetch) throw new Error('fetch - missing fetch implementation for browser environment');
|
||||
|
||||
const fetch = window['fetch'] || function() {
|
||||
throw new Error('fetch - missing fetch implementation for browser environment')
|
||||
}
|
||||
|
||||
const readFile = function() {
|
||||
throw new Error('readFile - filesystem not available for browser environment')
|
||||
}
|
||||
const readFile = () => {
|
||||
throw new Error('readFile - filesystem not available for browser environment');
|
||||
};
|
||||
|
||||
return {
|
||||
Canvas: HTMLCanvasElement,
|
||||
CanvasRenderingContext2D: CanvasRenderingContext2D,
|
||||
CanvasRenderingContext2D,
|
||||
Image: HTMLImageElement,
|
||||
ImageData: ImageData,
|
||||
ImageData,
|
||||
Video: HTMLVideoElement,
|
||||
createCanvasElement: () => document.createElement('canvas'),
|
||||
createImageElement: () => document.createElement('img'),
|
||||
fetch,
|
||||
readFile
|
||||
}
|
||||
readFile,
|
||||
};
|
||||
}
|
|
@ -1,30 +1,26 @@
|
|||
import { FileSystem } from './types';
|
||||
|
||||
export function createFileSystem(fs?: any): FileSystem {
|
||||
|
||||
let requireFsError = ''
|
||||
let requireFsError = '';
|
||||
|
||||
if (!fs) {
|
||||
try {
|
||||
fs = require('fs')
|
||||
// eslint-disable-next-line global-require
|
||||
fs = require('fs');
|
||||
} catch (err) {
|
||||
requireFsError = err.toString()
|
||||
requireFsError = err.toString();
|
||||
}
|
||||
}
|
||||
|
||||
const readFile = fs
|
||||
? function(filePath: string) {
|
||||
return new Promise<Buffer>((res, rej) => {
|
||||
fs.readFile(filePath, function(err: any, buffer: Buffer) {
|
||||
return err ? rej(err) : res(buffer)
|
||||
? (filePath: string) => new Promise<Buffer>((resolve, reject) => {
|
||||
fs.readFile(filePath, (err: any, buffer: Buffer) => (err ? reject(err) : resolve(buffer)));
|
||||
})
|
||||
})
|
||||
}
|
||||
: function() {
|
||||
throw new Error(`readFile - failed to require fs in nodejs environment with error: ${requireFsError}`)
|
||||
}
|
||||
: () => {
|
||||
throw new Error(`readFile - failed to require fs in nodejs environment with error: ${requireFsError}`);
|
||||
};
|
||||
|
||||
return {
|
||||
readFile
|
||||
}
|
||||
readFile,
|
||||
};
|
||||
}
|
|
@ -1,40 +1,36 @@
|
|||
/* eslint-disable max-classes-per-file */
|
||||
import { createFileSystem } from './createFileSystem';
|
||||
import { Environment } from './types';
|
||||
|
||||
export function createNodejsEnv(): Environment {
|
||||
// eslint-disable-next-line dot-notation
|
||||
const Canvas = global['Canvas'] || global.HTMLCanvasElement;
|
||||
const Image = global.Image || global.HTMLImageElement;
|
||||
|
||||
const Canvas = global['Canvas'] || global['HTMLCanvasElement']
|
||||
const Image = global['Image'] || global['HTMLImageElement']
|
||||
const createCanvasElement = () => {
|
||||
if (Canvas) return new Canvas();
|
||||
throw new Error('createCanvasElement - missing Canvas implementation for nodejs environment');
|
||||
};
|
||||
|
||||
const createCanvasElement = function() {
|
||||
if (Canvas) {
|
||||
return new Canvas()
|
||||
}
|
||||
throw new Error('createCanvasElement - missing Canvas implementation for nodejs environment')
|
||||
}
|
||||
const createImageElement = () => {
|
||||
if (Image) return new Image();
|
||||
throw new Error('createImageElement - missing Image implementation for nodejs environment');
|
||||
};
|
||||
|
||||
const createImageElement = function() {
|
||||
if (Image) {
|
||||
return new Image()
|
||||
}
|
||||
throw new Error('createImageElement - missing Image implementation for nodejs environment')
|
||||
}
|
||||
const fetch = global.fetch;
|
||||
// if (!fetch) throw new Error('fetch - missing fetch implementation for nodejs environment');
|
||||
|
||||
const fetch = global['fetch'] || function() {
|
||||
throw new Error('fetch - missing fetch implementation for nodejs environment')
|
||||
}
|
||||
|
||||
const fileSystem = createFileSystem()
|
||||
const fileSystem = createFileSystem();
|
||||
|
||||
return {
|
||||
Canvas: Canvas || class {},
|
||||
CanvasRenderingContext2D: global['CanvasRenderingContext2D'] || class {},
|
||||
CanvasRenderingContext2D: global.CanvasRenderingContext2D || class {},
|
||||
Image: Image || class {},
|
||||
ImageData: global['ImageData'] || class {},
|
||||
Video: global['HTMLVideoElement'] || class {},
|
||||
ImageData: global.ImageData || class {},
|
||||
Video: global.HTMLVideoElement || class {},
|
||||
createCanvasElement,
|
||||
createImageElement,
|
||||
fetch,
|
||||
...fileSystem
|
||||
}
|
||||
...fileSystem,
|
||||
};
|
||||
}
|
|
@ -5,49 +5,46 @@ import { isBrowser } from './isBrowser';
|
|||
import { isNodejs } from './isNodejs';
|
||||
import { Environment } from './types';
|
||||
|
||||
let environment: Environment | null
|
||||
let environment: Environment | null;
|
||||
|
||||
function getEnv(): Environment {
|
||||
if (!environment) {
|
||||
throw new Error('getEnv - environment is not defined, check isNodejs() and isBrowser()')
|
||||
throw new Error('getEnv - environment is not defined, check isNodejs() and isBrowser()');
|
||||
}
|
||||
return environment
|
||||
return environment;
|
||||
}
|
||||
|
||||
function setEnv(env: Environment) {
|
||||
environment = env
|
||||
environment = env;
|
||||
}
|
||||
|
||||
function initialize() {
|
||||
// check for isBrowser() first to prevent electron renderer process
|
||||
// to be initialized with wrong environment due to isNodejs() returning true
|
||||
if (isBrowser()) {
|
||||
return setEnv(createBrowserEnv())
|
||||
}
|
||||
if (isNodejs()) {
|
||||
return setEnv(createNodejsEnv())
|
||||
}
|
||||
if (isBrowser()) return setEnv(createBrowserEnv());
|
||||
if (isNodejs()) return setEnv(createNodejsEnv());
|
||||
return null;
|
||||
}
|
||||
|
||||
function monkeyPatch(env: Partial<Environment>) {
|
||||
if (!environment) {
|
||||
initialize()
|
||||
initialize();
|
||||
}
|
||||
|
||||
if (!environment) {
|
||||
throw new Error('monkeyPatch - environment is not defined, check isNodejs() and isBrowser()')
|
||||
throw new Error('monkeyPatch - environment is not defined, check isNodejs() and isBrowser()');
|
||||
}
|
||||
|
||||
const { Canvas = environment.Canvas, Image = environment.Image } = env
|
||||
environment.Canvas = Canvas
|
||||
environment.Image = Image
|
||||
environment.createCanvasElement = env.createCanvasElement || (() => new Canvas())
|
||||
environment.createImageElement = env.createImageElement || (() => new Image())
|
||||
const { Canvas = environment.Canvas, Image = environment.Image } = env;
|
||||
environment.Canvas = Canvas;
|
||||
environment.Image = Image;
|
||||
environment.createCanvasElement = env.createCanvasElement || (() => new Canvas());
|
||||
environment.createImageElement = env.createImageElement || (() => new Image());
|
||||
|
||||
environment.ImageData = env.ImageData || environment.ImageData
|
||||
environment.Video = env.Video || environment.Video
|
||||
environment.fetch = env.fetch || environment.fetch
|
||||
environment.readFile = env.readFile || environment.readFile
|
||||
environment.ImageData = env.ImageData || environment.ImageData;
|
||||
environment.Video = env.Video || environment.Video;
|
||||
environment.fetch = env.fetch || environment.fetch;
|
||||
environment.readFile = env.readFile || environment.readFile;
|
||||
}
|
||||
|
||||
export const env = {
|
||||
|
@ -59,9 +56,9 @@ export const env = {
|
|||
createNodejsEnv,
|
||||
monkeyPatch,
|
||||
isBrowser,
|
||||
isNodejs
|
||||
}
|
||||
isNodejs,
|
||||
};
|
||||
|
||||
initialize()
|
||||
initialize();
|
||||
|
||||
export * from './types'
|
||||
export * from './types';
|
||||
|
|
|
@ -5,5 +5,5 @@ export function isBrowser(): boolean {
|
|||
&& typeof HTMLCanvasElement !== 'undefined'
|
||||
&& typeof HTMLVideoElement !== 'undefined'
|
||||
&& typeof ImageData !== 'undefined'
|
||||
&& typeof CanvasRenderingContext2D !== 'undefined'
|
||||
&& typeof CanvasRenderingContext2D !== 'undefined';
|
||||
}
|
|
@ -4,5 +4,5 @@ export function isNodejs(): boolean {
|
|||
&& typeof module !== 'undefined'
|
||||
// issues with gatsby.js: module.exports is undefined
|
||||
// && !!module.exports
|
||||
&& typeof process !== 'undefined' && !!process.version
|
||||
&& typeof process !== 'undefined' && !!process.version;
|
||||
}
|
|
@ -1,4 +1,5 @@
|
|||
export type FileSystem = {
|
||||
// eslint-disable-next-line no-unused-vars
|
||||
readFile: (filePath: string) => Promise<Buffer>
|
||||
}
|
||||
|
||||
|
@ -10,5 +11,6 @@ export type Environment = FileSystem & {
|
|||
Video: typeof HTMLVideoElement
|
||||
createCanvasElement: () => HTMLCanvasElement
|
||||
createImageElement: () => HTMLImageElement
|
||||
// eslint-disable-next-line no-undef, no-unused-vars
|
||||
fetch: (url: string, init?: RequestInit) => Promise<Response>
|
||||
}
|
||||
|
|
|
@ -7,46 +7,45 @@ import { FaceProcessor } from '../faceProcessor/FaceProcessor';
|
|||
import { FaceExpressions } from './FaceExpressions';
|
||||
|
||||
export class FaceExpressionNet extends FaceProcessor<FaceFeatureExtractorParams> {
|
||||
|
||||
constructor(faceFeatureExtractor: FaceFeatureExtractor = new FaceFeatureExtractor()) {
|
||||
super('FaceExpressionNet', faceFeatureExtractor)
|
||||
super('FaceExpressionNet', faceFeatureExtractor);
|
||||
}
|
||||
|
||||
public forwardInput(input: NetInput | tf.Tensor4D): tf.Tensor2D {
|
||||
return tf.tidy(() => tf.softmax(this.runNet(input)))
|
||||
return tf.tidy(() => tf.softmax(this.runNet(input)));
|
||||
}
|
||||
|
||||
public async forward(input: TNetInput): Promise<tf.Tensor2D> {
|
||||
return this.forwardInput(await toNetInput(input))
|
||||
return this.forwardInput(await toNetInput(input));
|
||||
}
|
||||
|
||||
public async predictExpressions(input: TNetInput) {
|
||||
const netInput = await toNetInput(input)
|
||||
const out = await this.forwardInput(netInput)
|
||||
const probabilitesByBatch = await Promise.all(tf.unstack(out).map(async t => {
|
||||
const data = await t.data()
|
||||
t.dispose()
|
||||
return data
|
||||
}))
|
||||
out.dispose()
|
||||
const netInput = await toNetInput(input);
|
||||
const out = await this.forwardInput(netInput);
|
||||
const probabilitesByBatch = await Promise.all(tf.unstack(out).map(async (t) => {
|
||||
const data = await t.data();
|
||||
t.dispose();
|
||||
return data;
|
||||
}));
|
||||
out.dispose();
|
||||
|
||||
const predictionsByBatch = probabilitesByBatch
|
||||
.map(probabilites => new FaceExpressions(probabilites as Float32Array))
|
||||
.map((probabilites) => new FaceExpressions(probabilites as Float32Array));
|
||||
|
||||
return netInput.isBatchInput
|
||||
? predictionsByBatch
|
||||
: predictionsByBatch[0]
|
||||
: predictionsByBatch[0];
|
||||
}
|
||||
|
||||
protected getDefaultModelName(): string {
|
||||
return 'face_expression_model'
|
||||
return 'face_expression_model';
|
||||
}
|
||||
|
||||
protected getClassifierChannelsIn(): number {
|
||||
return 256
|
||||
return 256;
|
||||
}
|
||||
|
||||
protected getClassifierChannelsOut(): number {
|
||||
return 7
|
||||
return 7;
|
||||
}
|
||||
}
|
|
@ -1,27 +1,33 @@
|
|||
export const FACE_EXPRESSION_LABELS = ['neutral', 'happy', 'sad', 'angry', 'fearful', 'disgusted', 'surprised']
|
||||
export const FACE_EXPRESSION_LABELS = ['neutral', 'happy', 'sad', 'angry', 'fearful', 'disgusted', 'surprised'];
|
||||
|
||||
export class FaceExpressions {
|
||||
public neutral: number
|
||||
|
||||
public happy: number
|
||||
|
||||
public sad: number
|
||||
|
||||
public angry: number
|
||||
|
||||
public fearful: number
|
||||
|
||||
public disgusted: number
|
||||
|
||||
public surprised: number
|
||||
|
||||
constructor(probabilities: number[] | Float32Array) {
|
||||
if (probabilities.length !== 7) {
|
||||
throw new Error(`FaceExpressions.constructor - expected probabilities.length to be 7, have: ${probabilities.length}`)
|
||||
throw new Error(`FaceExpressions.constructor - expected probabilities.length to be 7, have: ${probabilities.length}`);
|
||||
}
|
||||
|
||||
FACE_EXPRESSION_LABELS.forEach((expression, idx) => {
|
||||
this[expression] = probabilities[idx]
|
||||
})
|
||||
this[expression] = probabilities[idx];
|
||||
});
|
||||
}
|
||||
|
||||
asSortedArray() {
|
||||
return FACE_EXPRESSION_LABELS
|
||||
.map(expression => ({ expression, probability: this[expression] as number }))
|
||||
.sort((e0, e1) => e1.probability - e0.probability)
|
||||
.map((expression) => ({ expression, probability: this[expression] as number }))
|
||||
.sort((e0, e1) => e1.probability - e0.probability);
|
||||
}
|
||||
}
|
|
@ -9,47 +9,45 @@ import { extractParamsFromWeigthMap } from './extractParamsFromWeigthMap';
|
|||
import { FaceFeatureExtractorParams, IFaceFeatureExtractor } from './types';
|
||||
|
||||
export class FaceFeatureExtractor extends NeuralNetwork<FaceFeatureExtractorParams> implements IFaceFeatureExtractor<FaceFeatureExtractorParams> {
|
||||
|
||||
constructor() {
|
||||
super('FaceFeatureExtractor')
|
||||
super('FaceFeatureExtractor');
|
||||
}
|
||||
|
||||
public forwardInput(input: NetInput): tf.Tensor4D {
|
||||
|
||||
const { params } = this
|
||||
const { params } = this;
|
||||
|
||||
if (!params) {
|
||||
throw new Error('FaceFeatureExtractor - load model before inference')
|
||||
throw new Error('FaceFeatureExtractor - load model before inference');
|
||||
}
|
||||
|
||||
return tf.tidy(() => {
|
||||
const batchTensor = tf.cast(input.toBatchTensor(112, true), 'float32');
|
||||
const meanRgb = [122.782, 117.001, 104.298]
|
||||
const normalized = normalize(batchTensor, meanRgb).div(tf.scalar(255)) as tf.Tensor4D
|
||||
const meanRgb = [122.782, 117.001, 104.298];
|
||||
const normalized = normalize(batchTensor, meanRgb).div(tf.scalar(255)) as tf.Tensor4D;
|
||||
|
||||
let out = denseBlock4(normalized, params.dense0, true)
|
||||
out = denseBlock4(out, params.dense1)
|
||||
out = denseBlock4(out, params.dense2)
|
||||
out = denseBlock4(out, params.dense3)
|
||||
out = tf.avgPool(out, [7, 7], [2, 2], 'valid')
|
||||
let out = denseBlock4(normalized, params.dense0, true);
|
||||
out = denseBlock4(out, params.dense1);
|
||||
out = denseBlock4(out, params.dense2);
|
||||
out = denseBlock4(out, params.dense3);
|
||||
out = tf.avgPool(out, [7, 7], [2, 2], 'valid');
|
||||
|
||||
return out
|
||||
})
|
||||
return out;
|
||||
});
|
||||
}
|
||||
|
||||
public async forward(input: TNetInput): Promise<tf.Tensor4D> {
|
||||
return this.forwardInput(await toNetInput(input))
|
||||
return this.forwardInput(await toNetInput(input));
|
||||
}
|
||||
|
||||
protected getDefaultModelName(): string {
|
||||
return 'face_feature_extractor_model'
|
||||
return 'face_feature_extractor_model';
|
||||
}
|
||||
|
||||
protected extractParamsFromWeigthMap(weightMap: tf.NamedTensorMap) {
|
||||
return extractParamsFromWeigthMap(weightMap)
|
||||
return extractParamsFromWeigthMap(weightMap);
|
||||
}
|
||||
|
||||
protected extractParams(weights: Float32Array) {
|
||||
return extractParams(weights)
|
||||
return extractParams(weights);
|
||||
}
|
||||
}
|
|
@ -9,46 +9,44 @@ import { extractParamsTiny } from './extractParamsTiny';
|
|||
import { IFaceFeatureExtractor, TinyFaceFeatureExtractorParams } from './types';
|
||||
|
||||
export class TinyFaceFeatureExtractor extends NeuralNetwork<TinyFaceFeatureExtractorParams> implements IFaceFeatureExtractor<TinyFaceFeatureExtractorParams> {
|
||||
|
||||
constructor() {
|
||||
super('TinyFaceFeatureExtractor')
|
||||
super('TinyFaceFeatureExtractor');
|
||||
}
|
||||
|
||||
public forwardInput(input: NetInput): tf.Tensor4D {
|
||||
|
||||
const { params } = this
|
||||
const { params } = this;
|
||||
|
||||
if (!params) {
|
||||
throw new Error('TinyFaceFeatureExtractor - load model before inference')
|
||||
throw new Error('TinyFaceFeatureExtractor - load model before inference');
|
||||
}
|
||||
|
||||
return tf.tidy(() => {
|
||||
const batchTensor = tf.cast(input.toBatchTensor(112, true), 'float32');
|
||||
const meanRgb = [122.782, 117.001, 104.298]
|
||||
const normalized = normalize(batchTensor, meanRgb).div(tf.scalar(255)) as tf.Tensor4D
|
||||
const meanRgb = [122.782, 117.001, 104.298];
|
||||
const normalized = normalize(batchTensor, meanRgb).div(tf.scalar(255)) as tf.Tensor4D;
|
||||
|
||||
let out = denseBlock3(normalized, params.dense0, true)
|
||||
out = denseBlock3(out, params.dense1)
|
||||
out = denseBlock3(out, params.dense2)
|
||||
out = tf.avgPool(out, [14, 14], [2, 2], 'valid')
|
||||
let out = denseBlock3(normalized, params.dense0, true);
|
||||
out = denseBlock3(out, params.dense1);
|
||||
out = denseBlock3(out, params.dense2);
|
||||
out = tf.avgPool(out, [14, 14], [2, 2], 'valid');
|
||||
|
||||
return out
|
||||
})
|
||||
return out;
|
||||
});
|
||||
}
|
||||
|
||||
public async forward(input: TNetInput): Promise<tf.Tensor4D> {
|
||||
return this.forwardInput(await toNetInput(input))
|
||||
return this.forwardInput(await toNetInput(input));
|
||||
}
|
||||
|
||||
protected getDefaultModelName(): string {
|
||||
return 'face_feature_extractor_tiny_model'
|
||||
return 'face_feature_extractor_tiny_model';
|
||||
}
|
||||
|
||||
protected extractParamsFromWeigthMap(weightMap: tf.NamedTensorMap) {
|
||||
return extractParamsFromWeigthMapTiny(weightMap)
|
||||
return extractParamsFromWeigthMapTiny(weightMap);
|
||||
}
|
||||
|
||||
protected extractParams(weights: Float32Array) {
|
||||
return extractParamsTiny(weights)
|
||||
return extractParamsTiny(weights);
|
||||
}
|
||||
}
|
|
@ -7,49 +7,49 @@ import { DenseBlock3Params, DenseBlock4Params } from './types';
|
|||
export function denseBlock3(
|
||||
x: tf.Tensor4D,
|
||||
denseBlockParams: DenseBlock3Params,
|
||||
isFirstLayer: boolean = false
|
||||
isFirstLayer: boolean = false,
|
||||
): tf.Tensor4D {
|
||||
return tf.tidy(() => {
|
||||
const out1 = tf.relu(
|
||||
isFirstLayer
|
||||
? tf.add(
|
||||
tf.conv2d(x, (denseBlockParams.conv0 as ConvParams).filters, [2, 2], 'same'),
|
||||
denseBlockParams.conv0.bias
|
||||
denseBlockParams.conv0.bias,
|
||||
)
|
||||
: depthwiseSeparableConv(x, denseBlockParams.conv0 as SeparableConvParams, [2, 2])
|
||||
) as tf.Tensor4D
|
||||
const out2 = depthwiseSeparableConv(out1, denseBlockParams.conv1, [1, 1])
|
||||
: depthwiseSeparableConv(x, denseBlockParams.conv0 as SeparableConvParams, [2, 2]),
|
||||
) as tf.Tensor4D;
|
||||
const out2 = depthwiseSeparableConv(out1, denseBlockParams.conv1, [1, 1]);
|
||||
|
||||
const in3 = tf.relu(tf.add(out1, out2)) as tf.Tensor4D
|
||||
const out3 = depthwiseSeparableConv(in3, denseBlockParams.conv2, [1, 1])
|
||||
const in3 = tf.relu(tf.add(out1, out2)) as tf.Tensor4D;
|
||||
const out3 = depthwiseSeparableConv(in3, denseBlockParams.conv2, [1, 1]);
|
||||
|
||||
return tf.relu(tf.add(out1, tf.add(out2, out3))) as tf.Tensor4D
|
||||
})
|
||||
return tf.relu(tf.add(out1, tf.add(out2, out3))) as tf.Tensor4D;
|
||||
});
|
||||
}
|
||||
|
||||
export function denseBlock4(
|
||||
x: tf.Tensor4D,
|
||||
denseBlockParams: DenseBlock4Params,
|
||||
isFirstLayer: boolean = false,
|
||||
isScaleDown: boolean = true
|
||||
isScaleDown: boolean = true,
|
||||
): tf.Tensor4D {
|
||||
return tf.tidy(() => {
|
||||
const out1 = tf.relu(
|
||||
isFirstLayer
|
||||
? tf.add(
|
||||
tf.conv2d(x, (denseBlockParams.conv0 as ConvParams).filters, isScaleDown ? [2, 2] : [1, 1], 'same'),
|
||||
denseBlockParams.conv0.bias
|
||||
denseBlockParams.conv0.bias,
|
||||
)
|
||||
: depthwiseSeparableConv(x, denseBlockParams.conv0 as SeparableConvParams, isScaleDown ? [2, 2] : [1, 1])
|
||||
) as tf.Tensor4D
|
||||
const out2 = depthwiseSeparableConv(out1, denseBlockParams.conv1, [1, 1])
|
||||
: depthwiseSeparableConv(x, denseBlockParams.conv0 as SeparableConvParams, isScaleDown ? [2, 2] : [1, 1]),
|
||||
) as tf.Tensor4D;
|
||||
const out2 = depthwiseSeparableConv(out1, denseBlockParams.conv1, [1, 1]);
|
||||
|
||||
const in3 = tf.relu(tf.add(out1, out2)) as tf.Tensor4D
|
||||
const out3 = depthwiseSeparableConv(in3, denseBlockParams.conv2, [1, 1])
|
||||
const in3 = tf.relu(tf.add(out1, out2)) as tf.Tensor4D;
|
||||
const out3 = depthwiseSeparableConv(in3, denseBlockParams.conv2, [1, 1]);
|
||||
|
||||
const in4 = tf.relu(tf.add(out1, tf.add(out2, out3))) as tf.Tensor4D
|
||||
const out4 = depthwiseSeparableConv(in4, denseBlockParams.conv3, [1, 1])
|
||||
const in4 = tf.relu(tf.add(out1, tf.add(out2, out3))) as tf.Tensor4D;
|
||||
const out4 = depthwiseSeparableConv(in4, denseBlockParams.conv3, [1, 1]);
|
||||
|
||||
return tf.relu(tf.add(out1, tf.add(out2, tf.add(out3, out4)))) as tf.Tensor4D
|
||||
})
|
||||
return tf.relu(tf.add(out1, tf.add(out2, tf.add(out3, out4)))) as tf.Tensor4D;
|
||||
});
|
||||
}
|
||||
|
|
|
@ -2,31 +2,31 @@ import { extractWeightsFactory, ParamMapping } from '../common/index';
|
|||
import { extractorsFactory } from './extractorsFactory';
|
||||
import { FaceFeatureExtractorParams } from './types';
|
||||
|
||||
|
||||
export function extractParams(weights: Float32Array): { params: FaceFeatureExtractorParams, paramMappings: ParamMapping[] } {
|
||||
|
||||
const paramMappings: ParamMapping[] = []
|
||||
const paramMappings: ParamMapping[] = [];
|
||||
|
||||
const {
|
||||
extractWeights,
|
||||
getRemainingWeights
|
||||
} = extractWeightsFactory(weights)
|
||||
getRemainingWeights,
|
||||
} = extractWeightsFactory(weights);
|
||||
|
||||
const {
|
||||
extractDenseBlock4Params
|
||||
} = extractorsFactory(extractWeights, paramMappings)
|
||||
extractDenseBlock4Params,
|
||||
} = extractorsFactory(extractWeights, paramMappings);
|
||||
|
||||
const dense0 = extractDenseBlock4Params(3, 32, 'dense0', true)
|
||||
const dense1 = extractDenseBlock4Params(32, 64, 'dense1')
|
||||
const dense2 = extractDenseBlock4Params(64, 128, 'dense2')
|
||||
const dense3 = extractDenseBlock4Params(128, 256, 'dense3')
|
||||
const dense0 = extractDenseBlock4Params(3, 32, 'dense0', true);
|
||||
const dense1 = extractDenseBlock4Params(32, 64, 'dense1');
|
||||
const dense2 = extractDenseBlock4Params(64, 128, 'dense2');
|
||||
const dense3 = extractDenseBlock4Params(128, 256, 'dense3');
|
||||
|
||||
if (getRemainingWeights().length !== 0) {
|
||||
throw new Error(`weights remaing after extract: ${getRemainingWeights().length}`)
|
||||
throw new Error(`weights remaing after extract: ${getRemainingWeights().length}`);
|
||||
}
|
||||
|
||||
return {
|
||||
paramMappings,
|
||||
params: { dense0, dense1, dense2, dense3 }
|
||||
}
|
||||
params: {
|
||||
dense0, dense1, dense2, dense3,
|
||||
},
|
||||
};
|
||||
}
|
|
@ -5,23 +5,22 @@ import { loadParamsFactory } from './loadParamsFactory';
|
|||
import { FaceFeatureExtractorParams } from './types';
|
||||
|
||||
export function extractParamsFromWeigthMap(
|
||||
weightMap: tf.NamedTensorMap
|
||||
weightMap: tf.NamedTensorMap,
|
||||
): { params: FaceFeatureExtractorParams, paramMappings: ParamMapping[] } {
|
||||
|
||||
const paramMappings: ParamMapping[] = []
|
||||
const paramMappings: ParamMapping[] = [];
|
||||
|
||||
const {
|
||||
extractDenseBlock4Params
|
||||
} = loadParamsFactory(weightMap, paramMappings)
|
||||
extractDenseBlock4Params,
|
||||
} = loadParamsFactory(weightMap, paramMappings);
|
||||
|
||||
const params = {
|
||||
dense0: extractDenseBlock4Params('dense0', true),
|
||||
dense1: extractDenseBlock4Params('dense1'),
|
||||
dense2: extractDenseBlock4Params('dense2'),
|
||||
dense3: extractDenseBlock4Params('dense3')
|
||||
}
|
||||
dense3: extractDenseBlock4Params('dense3'),
|
||||
};
|
||||
|
||||
disposeUnusedWeightTensors(weightMap, paramMappings)
|
||||
disposeUnusedWeightTensors(weightMap, paramMappings);
|
||||
|
||||
return { params, paramMappings }
|
||||
return { params, paramMappings };
|
||||
}
|
|
@ -5,22 +5,21 @@ import { loadParamsFactory } from './loadParamsFactory';
|
|||
import { TinyFaceFeatureExtractorParams } from './types';
|
||||
|
||||
export function extractParamsFromWeigthMapTiny(
|
||||
weightMap: tf.NamedTensorMap
|
||||
weightMap: tf.NamedTensorMap,
|
||||
): { params: TinyFaceFeatureExtractorParams, paramMappings: ParamMapping[] } {
|
||||
|
||||
const paramMappings: ParamMapping[] = []
|
||||
const paramMappings: ParamMapping[] = [];
|
||||
|
||||
const {
|
||||
extractDenseBlock3Params
|
||||
} = loadParamsFactory(weightMap, paramMappings)
|
||||
extractDenseBlock3Params,
|
||||
} = loadParamsFactory(weightMap, paramMappings);
|
||||
|
||||
const params = {
|
||||
dense0: extractDenseBlock3Params('dense0', true),
|
||||
dense1: extractDenseBlock3Params('dense1'),
|
||||
dense2: extractDenseBlock3Params('dense2')
|
||||
}
|
||||
dense2: extractDenseBlock3Params('dense2'),
|
||||
};
|
||||
|
||||
disposeUnusedWeightTensors(weightMap, paramMappings)
|
||||
disposeUnusedWeightTensors(weightMap, paramMappings);
|
||||
|
||||
return { params, paramMappings }
|
||||
return { params, paramMappings };
|
||||
}
|
|
@ -2,31 +2,28 @@ import { extractWeightsFactory, ParamMapping } from '../common/index';
|
|||
import { extractorsFactory } from './extractorsFactory';
|
||||
import { TinyFaceFeatureExtractorParams } from './types';
|
||||
|
||||
|
||||
|
||||
export function extractParamsTiny(weights: Float32Array): { params: TinyFaceFeatureExtractorParams, paramMappings: ParamMapping[] } {
|
||||
|
||||
const paramMappings: ParamMapping[] = []
|
||||
const paramMappings: ParamMapping[] = [];
|
||||
|
||||
const {
|
||||
extractWeights,
|
||||
getRemainingWeights
|
||||
} = extractWeightsFactory(weights)
|
||||
getRemainingWeights,
|
||||
} = extractWeightsFactory(weights);
|
||||
|
||||
const {
|
||||
extractDenseBlock3Params
|
||||
} = extractorsFactory(extractWeights, paramMappings)
|
||||
extractDenseBlock3Params,
|
||||
} = extractorsFactory(extractWeights, paramMappings);
|
||||
|
||||
const dense0 = extractDenseBlock3Params(3, 32, 'dense0', true)
|
||||
const dense1 = extractDenseBlock3Params(32, 64, 'dense1')
|
||||
const dense2 = extractDenseBlock3Params(64, 128, 'dense2')
|
||||
const dense0 = extractDenseBlock3Params(3, 32, 'dense0', true);
|
||||
const dense1 = extractDenseBlock3Params(32, 64, 'dense1');
|
||||
const dense2 = extractDenseBlock3Params(64, 128, 'dense2');
|
||||
|
||||
if (getRemainingWeights().length !== 0) {
|
||||
throw new Error(`weights remaing after extract: ${getRemainingWeights().length}`)
|
||||
throw new Error(`weights remaing after extract: ${getRemainingWeights().length}`);
|
||||
}
|
||||
|
||||
return {
|
||||
paramMappings,
|
||||
params: { dense0, dense1, dense2 }
|
||||
}
|
||||
params: { dense0, dense1, dense2 },
|
||||
};
|
||||
}
|
|
@ -7,32 +7,30 @@ import {
|
|||
import { DenseBlock3Params, DenseBlock4Params } from './types';
|
||||
|
||||
export function extractorsFactory(extractWeights: ExtractWeightsFunction, paramMappings: ParamMapping[]) {
|
||||
|
||||
const extractConvParams = extractConvParamsFactory(extractWeights, paramMappings)
|
||||
const extractSeparableConvParams = extractSeparableConvParamsFactory(extractWeights, paramMappings)
|
||||
const extractConvParams = extractConvParamsFactory(extractWeights, paramMappings);
|
||||
const extractSeparableConvParams = extractSeparableConvParamsFactory(extractWeights, paramMappings);
|
||||
|
||||
function extractDenseBlock3Params(channelsIn: number, channelsOut: number, mappedPrefix: string, isFirstLayer: boolean = false): DenseBlock3Params {
|
||||
|
||||
const conv0 = isFirstLayer
|
||||
? extractConvParams(channelsIn, channelsOut, 3, `${mappedPrefix}/conv0`)
|
||||
: extractSeparableConvParams(channelsIn, channelsOut, `${mappedPrefix}/conv0`)
|
||||
const conv1 = extractSeparableConvParams(channelsOut, channelsOut, `${mappedPrefix}/conv1`)
|
||||
const conv2 = extractSeparableConvParams(channelsOut, channelsOut, `${mappedPrefix}/conv2`)
|
||||
: extractSeparableConvParams(channelsIn, channelsOut, `${mappedPrefix}/conv0`);
|
||||
const conv1 = extractSeparableConvParams(channelsOut, channelsOut, `${mappedPrefix}/conv1`);
|
||||
const conv2 = extractSeparableConvParams(channelsOut, channelsOut, `${mappedPrefix}/conv2`);
|
||||
|
||||
return { conv0, conv1, conv2 }
|
||||
return { conv0, conv1, conv2 };
|
||||
}
|
||||
|
||||
function extractDenseBlock4Params(channelsIn: number, channelsOut: number, mappedPrefix: string, isFirstLayer: boolean = false): DenseBlock4Params {
|
||||
const { conv0, conv1, conv2 } = extractDenseBlock3Params(channelsIn, channelsOut, mappedPrefix, isFirstLayer);
|
||||
const conv3 = extractSeparableConvParams(channelsOut, channelsOut, `${mappedPrefix}/conv3`);
|
||||
|
||||
const { conv0, conv1, conv2 } = extractDenseBlock3Params(channelsIn, channelsOut, mappedPrefix, isFirstLayer)
|
||||
const conv3 = extractSeparableConvParams(channelsOut, channelsOut, `${mappedPrefix}/conv3`)
|
||||
|
||||
return { conv0, conv1, conv2, conv3 }
|
||||
return {
|
||||
conv0, conv1, conv2, conv3,
|
||||
};
|
||||
}
|
||||
|
||||
return {
|
||||
extractDenseBlock3Params,
|
||||
extractDenseBlock4Params
|
||||
}
|
||||
|
||||
extractDenseBlock4Params,
|
||||
};
|
||||
}
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue