full re-lint and typings generation

pull/34/head
Vladimir Mandic 2020-12-23 11:26:55 -05:00
parent 2a7bf4080b
commit 0692fbf532
353 changed files with 7151 additions and 5729 deletions

47
.eslintrc.json Normal file
View File

@ -0,0 +1,47 @@
{
"globals": {},
"env": {
"browser": true,
"commonjs": true,
"es6": true,
"node": true,
"es2020": true
},
"parser": "@typescript-eslint/parser",
"parserOptions": { "ecmaVersion": 2020 },
"plugins": ["@typescript-eslint"],
"extends": [
"eslint:recommended",
"plugin:import/errors",
"plugin:import/warnings",
"plugin:import/typescript",
"plugin:node/recommended",
"plugin:promise/recommended",
"plugin:json/recommended-with-comments",
"airbnb-base"
],
"ignorePatterns": [ "node_modules", "types" ],
"settings": {
"import/resolver": {
"node": {
"extensions": [".js", ".ts"]
}
}
},
"rules": {
"max-len": [1, 275, 3],
"no-plusplus": "off",
"import/prefer-default-export": "off",
"node/no-unsupported-features/es-syntax": "off",
"import/no-cycle": "off",
"import/extensions": "off",
"node/no-missing-import": "off",
"no-underscore-dangle": "off",
"class-methods-use-this": "off",
"camelcase": "off",
"no-await-in-loop": "off",
"no-continue": "off",
"no-param-reassign": "off",
"prefer-destructuring": "off"
}
}

View File

@ -21,7 +21,8 @@ const tsconfig = {
noEmitOnError: false,
target: ts.ScriptTarget.ES2018,
module: ts.ModuleKind.ES2020,
outFile: "dist/face-api.d.ts",
// outFile: "dist/face-api.d.ts",
outDir: "types/",
declaration: true,
emitDeclarationOnly: true,
emitDecoratorMetadata: true,

2109
dist/face-api.d.ts vendored

File diff suppressed because it is too large Load Diff

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because it is too large Load Diff

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

401
dist/face-api.esm.json vendored

File diff suppressed because it is too large Load Diff

4
dist/face-api.js vendored

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

400
dist/face-api.json vendored

File diff suppressed because it is too large Load Diff

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because it is too large Load Diff

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because it is too large Load Diff

4
dist/tfjs.esm.js vendored

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

13
dist/tfjs.esm.json vendored
View File

@ -20799,7 +20799,7 @@
]
},
"src/tfjs/tf-browser.js": {
"bytes": 1784,
"bytes": 1888,
"imports": [
{
"path": "node_modules/@tensorflow/tfjs/dist/index.js"
@ -20814,7 +20814,7 @@
"dist/tfjs.esm.js.map": {
"imports": [],
"inputs": {},
"bytes": 1070617
"bytes": 1063970
},
"dist/tfjs.esm.js": {
"imports": [],
@ -23195,7 +23195,7 @@
"bytesInOutput": 29890
},
"node_modules/@tensorflow/tfjs-layers/dist/layers/convolutional_recurrent.js": {
"bytesInOutput": 9235
"bytesInOutput": 9196
},
"node_modules/@tensorflow/tfjs-layers/dist/layers/core.js": {
"bytesInOutput": 9961
@ -24070,11 +24070,8 @@
"node_modules/@tensorflow/tfjs-backend-webgl/dist/version.js": {
"bytesInOutput": 22
},
"node_modules/@tensorflow/tfjs-backend-webgl/dist/webgl.js": {
"bytesInOutput": 67
},
"node_modules/@tensorflow/tfjs-backend-webgl/dist/base.js": {
"bytesInOutput": 113
"bytesInOutput": 85
},
"node_modules/@tensorflow/tfjs-backend-webgl/dist/index.js": {
"bytesInOutput": 0
@ -25067,7 +25064,7 @@
"bytesInOutput": 0
}
},
"bytes": 1572551
"bytes": 1572418
}
}
}

1740
package-lock.json generated

File diff suppressed because it is too large Load Diff

View File

@ -5,13 +5,14 @@
"main": "dist/face-api.node.js",
"module": "dist/face-api.esm.js",
"browser": "dist/face-api.esm.js",
"typings": "dist/face-api.d.ts",
"types": "types/index.d.ts",
"engines": {
"node": ">=12.0.0"
},
"scripts": {
"start": "node --trace-warnings example/node.js",
"build": "rimraf dist/* && node ./build.js"
"start": "node --trace-warnings example/node-singleprocess.js",
"build": "rimraf dist/* && node ./build.js",
"lint": "eslint src/**/*"
},
"keywords": [
"tensorflow",
@ -44,7 +45,15 @@
"@tensorflow/tfjs-node": "^2.8.1",
"@tensorflow/tfjs-node-gpu": "^2.8.1",
"@types/node": "^14.14.14",
"esbuild": "^0.8.24",
"@typescript-eslint/eslint-plugin": "^4.11.0",
"@typescript-eslint/parser": "^4.11.0",
"esbuild": "^0.8.26",
"eslint": "^7.16.0",
"eslint-config-airbnb-base": "^14.2.1",
"eslint-plugin-import": "^2.22.1",
"eslint-plugin-json": "^2.1.2",
"eslint-plugin-node": "^11.1.0",
"eslint-plugin-promise": "^4.2.1",
"rimraf": "^3.0.2",
"tslib": "^2.0.3",
"typescript": "^4.1.3"

View File

@ -5,120 +5,118 @@ import { seperateWeightMaps } from '../faceProcessor/util';
import { TinyXception } from '../xception/TinyXception';
import { extractParams } from './extractParams';
import { extractParamsFromWeigthMap } from './extractParamsFromWeigthMap';
import { AgeAndGenderPrediction, Gender, NetOutput, NetParams } from './types';
import {
AgeAndGenderPrediction, Gender, NetOutput, NetParams,
} from './types';
import { NeuralNetwork } from '../NeuralNetwork';
import { NetInput, TNetInput, toNetInput } from '../dom/index';
export class AgeGenderNet extends NeuralNetwork<NetParams> {
private _faceFeatureExtractor: TinyXception
constructor(faceFeatureExtractor: TinyXception = new TinyXception(2)) {
super('AgeGenderNet')
this._faceFeatureExtractor = faceFeatureExtractor
super('AgeGenderNet');
this._faceFeatureExtractor = faceFeatureExtractor;
}
public get faceFeatureExtractor(): TinyXception {
return this._faceFeatureExtractor
return this._faceFeatureExtractor;
}
public runNet(input: NetInput | tf.Tensor4D): NetOutput {
const { params } = this
const { params } = this;
if (!params) {
throw new Error(`${this._name} - load model before inference`)
throw new Error(`${this._name} - load model before inference`);
}
return tf.tidy(() => {
const bottleneckFeatures = input instanceof NetInput
? this.faceFeatureExtractor.forwardInput(input)
: input
: input;
const pooled = tf.avgPool(bottleneckFeatures, [7, 7], [2, 2], 'valid').as2D(bottleneckFeatures.shape[0], -1)
const age = fullyConnectedLayer(pooled, params.fc.age).as1D()
const gender = fullyConnectedLayer(pooled, params.fc.gender)
return { age, gender }
})
const pooled = tf.avgPool(bottleneckFeatures, [7, 7], [2, 2], 'valid').as2D(bottleneckFeatures.shape[0], -1);
const age = fullyConnectedLayer(pooled, params.fc.age).as1D();
const gender = fullyConnectedLayer(pooled, params.fc.gender);
return { age, gender };
});
}
public forwardInput(input: NetInput | tf.Tensor4D): NetOutput {
return tf.tidy(() => {
const { age, gender } = this.runNet(input)
return { age, gender: tf.softmax(gender) }
})
const { age, gender } = this.runNet(input);
return { age, gender: tf.softmax(gender) };
});
}
public async forward(input: TNetInput): Promise<NetOutput> {
return this.forwardInput(await toNetInput(input))
return this.forwardInput(await toNetInput(input));
}
public async predictAgeAndGender(input: TNetInput): Promise<AgeAndGenderPrediction | AgeAndGenderPrediction[]> {
const netInput = await toNetInput(input)
const out = await this.forwardInput(netInput)
const netInput = await toNetInput(input);
const out = await this.forwardInput(netInput);
const ages = tf.unstack(out.age)
const genders = tf.unstack(out.gender)
const ages = tf.unstack(out.age);
const genders = tf.unstack(out.gender);
const ageAndGenderTensors = ages.map((ageTensor, i) => ({
ageTensor,
genderTensor: genders[i]
}))
genderTensor: genders[i],
}));
const predictionsByBatch = await Promise.all(
ageAndGenderTensors.map(async ({ ageTensor, genderTensor }) => {
const age = (await ageTensor.data())[0]
const probMale = (await genderTensor.data())[0]
const isMale = probMale > 0.5
const gender = isMale ? Gender.MALE : Gender.FEMALE
const genderProbability = isMale ? probMale : (1 - probMale)
const age = (await ageTensor.data())[0];
const probMale = (await genderTensor.data())[0];
const isMale = probMale > 0.5;
const gender = isMale ? Gender.MALE : Gender.FEMALE;
const genderProbability = isMale ? probMale : (1 - probMale);
ageTensor.dispose()
genderTensor.dispose()
return { age, gender, genderProbability }
})
)
out.age.dispose()
out.gender.dispose()
ageTensor.dispose();
genderTensor.dispose();
return { age, gender, genderProbability };
}),
);
out.age.dispose();
out.gender.dispose();
return netInput.isBatchInput ? predictionsByBatch as AgeAndGenderPrediction[] : predictionsByBatch[0] as AgeAndGenderPrediction
return netInput.isBatchInput ? predictionsByBatch as AgeAndGenderPrediction[] : predictionsByBatch[0] as AgeAndGenderPrediction;
}
protected getDefaultModelName(): string {
return 'age_gender_model'
return 'age_gender_model';
}
public dispose(throwOnRedispose: boolean = true) {
this.faceFeatureExtractor.dispose(throwOnRedispose)
super.dispose(throwOnRedispose)
this.faceFeatureExtractor.dispose(throwOnRedispose);
super.dispose(throwOnRedispose);
}
public loadClassifierParams(weights: Float32Array) {
const { params, paramMappings } = this.extractClassifierParams(weights)
this._params = params
this._paramMappings = paramMappings
const { params, paramMappings } = this.extractClassifierParams(weights);
this._params = params;
this._paramMappings = paramMappings;
}
public extractClassifierParams(weights: Float32Array) {
return extractParams(weights)
return extractParams(weights);
}
protected extractParamsFromWeigthMap(weightMap: tf.NamedTensorMap) {
const { featureExtractorMap, classifierMap } = seperateWeightMaps(weightMap);
const { featureExtractorMap, classifierMap } = seperateWeightMaps(weightMap)
this.faceFeatureExtractor.loadFromWeightMap(featureExtractorMap);
this.faceFeatureExtractor.loadFromWeightMap(featureExtractorMap)
return extractParamsFromWeigthMap(classifierMap)
return extractParamsFromWeigthMap(classifierMap);
}
protected extractParams(weights: Float32Array) {
const classifierWeightSize = (512 * 1 + 1) + (512 * 2 + 2);
const classifierWeightSize = (512 * 1 + 1) + (512 * 2 + 2)
const featureExtractorWeights = weights.slice(0, weights.length - classifierWeightSize);
const classifierWeights = weights.slice(weights.length - classifierWeightSize);
const featureExtractorWeights = weights.slice(0, weights.length - classifierWeightSize)
const classifierWeights = weights.slice(weights.length - classifierWeightSize)
this.faceFeatureExtractor.extractWeights(featureExtractorWeights)
return this.extractClassifierParams(classifierWeights)
this.faceFeatureExtractor.extractWeights(featureExtractorWeights);
return this.extractClassifierParams(classifierWeights);
}
}

View File

@ -2,25 +2,24 @@ import { extractFCParamsFactory, extractWeightsFactory, ParamMapping } from '../
import { NetParams } from './types';
export function extractParams(weights: Float32Array): { params: NetParams, paramMappings: ParamMapping[] } {
const paramMappings: ParamMapping[] = []
const paramMappings: ParamMapping[] = [];
const {
extractWeights,
getRemainingWeights
} = extractWeightsFactory(weights)
getRemainingWeights,
} = extractWeightsFactory(weights);
const extractFCParams = extractFCParamsFactory(extractWeights, paramMappings)
const extractFCParams = extractFCParamsFactory(extractWeights, paramMappings);
const age = extractFCParams(512, 1, 'fc/age')
const gender = extractFCParams(512, 2, 'fc/gender')
const age = extractFCParams(512, 1, 'fc/age');
const gender = extractFCParams(512, 2, 'fc/gender');
if (getRemainingWeights().length !== 0) {
throw new Error(`weights remaing after extract: ${getRemainingWeights().length}`)
throw new Error(`weights remaing after extract: ${getRemainingWeights().length}`);
}
return {
paramMappings,
params: { fc: { age, gender } }
}
params: { fc: { age, gender } },
};
}

View File

@ -1,30 +1,31 @@
import * as tf from '../../dist/tfjs.esm.js';
import { disposeUnusedWeightTensors, extractWeightEntryFactory, FCParams, ParamMapping } from '../common/index';
import {
disposeUnusedWeightTensors, extractWeightEntryFactory, FCParams, ParamMapping,
} from '../common/index';
import { NetParams } from './types';
export function extractParamsFromWeigthMap(
weightMap: tf.NamedTensorMap
weightMap: tf.NamedTensorMap,
): { params: NetParams, paramMappings: ParamMapping[] } {
const paramMappings: ParamMapping[] = [];
const paramMappings: ParamMapping[] = []
const extractWeightEntry = extractWeightEntryFactory(weightMap, paramMappings)
const extractWeightEntry = extractWeightEntryFactory(weightMap, paramMappings);
function extractFcParams(prefix: string): FCParams {
const weights = extractWeightEntry<tf.Tensor2D>(`${prefix}/weights`, 2)
const bias = extractWeightEntry<tf.Tensor1D>(`${prefix}/bias`, 1)
return { weights, bias }
const weights = extractWeightEntry(`${prefix}/weights`, 2);
const bias = extractWeightEntry(`${prefix}/bias`, 1);
return { weights, bias };
}
const params = {
fc: {
age: extractFcParams('fc/age'),
gender: extractFcParams('fc/gender')
}
}
gender: extractFcParams('fc/gender'),
},
};
disposeUnusedWeightTensors(weightMap, paramMappings)
disposeUnusedWeightTensors(weightMap, paramMappings);
return { params, paramMappings }
return { params, paramMappings };
}

View File

@ -2,17 +2,20 @@ import * as tf from '../../dist/tfjs.esm.js';
import { FCParams } from '../common/index';
// eslint-disable-next-line no-shadow
export enum Gender {
// eslint-disable-next-line no-unused-vars
FEMALE = 'female',
// eslint-disable-next-line no-unused-vars
MALE = 'male'
}
export type AgeAndGenderPrediction = {
age: number
gender: Gender
genderProbability: number
}
export enum Gender {
FEMALE = 'female',
MALE = 'male'
}
export type NetOutput = { age: tf.Tensor1D, gender: tf.Tensor2D }
export type NetParams = {

View File

@ -9,6 +9,8 @@ export interface IBoundingBox {
export class BoundingBox extends Box<BoundingBox> implements IBoundingBox {
constructor(left: number, top: number, right: number, bottom: number, allowNegativeDimensions: boolean = false) {
super({ left, top, right, bottom }, allowNegativeDimensions)
super({
left, top, right, bottom,
}, allowNegativeDimensions);
}
}

View File

@ -5,163 +5,197 @@ import { Point } from './Point';
import { IRect } from './Rect';
export class Box<BoxType = any> implements IBoundingBox, IRect {
public static isRect(rect: any): boolean {
return !!rect && [rect.x, rect.y, rect.width, rect.height].every(isValidNumber)
return !!rect && [rect.x, rect.y, rect.width, rect.height].every(isValidNumber);
}
public static assertIsValidBox(box: any, callee: string, allowNegativeDimensions: boolean = false) {
if (!Box.isRect(box)) {
throw new Error(`${callee} - invalid box: ${JSON.stringify(box)}, expected object with properties x, y, width, height`)
throw new Error(`${callee} - invalid box: ${JSON.stringify(box)}, expected object with properties x, y, width, height`);
}
if (!allowNegativeDimensions && (box.width < 0 || box.height < 0)) {
throw new Error(`${callee} - width (${box.width}) and height (${box.height}) must be positive numbers`)
throw new Error(`${callee} - width (${box.width}) and height (${box.height}) must be positive numbers`);
}
}
private _x: number
private _y: number
private _width: number
private _height: number
constructor(_box: IBoundingBox | IRect, allowNegativeDimensions: boolean = true) {
const box = (_box || {}) as any
const box = (_box || {}) as any;
const isBbox = [box.left, box.top, box.right, box.bottom].every(isValidNumber)
const isRect = [box.x, box.y, box.width, box.height].every(isValidNumber)
const isBbox = [box.left, box.top, box.right, box.bottom].every(isValidNumber);
const isRect = [box.x, box.y, box.width, box.height].every(isValidNumber);
if (!isRect && !isBbox) {
throw new Error(`Box.constructor - expected box to be IBoundingBox | IRect, instead have ${JSON.stringify(box)}`)
throw new Error(`Box.constructor - expected box to be IBoundingBox | IRect, instead have ${JSON.stringify(box)}`);
}
const [x, y, width, height] = isRect
? [box.x, box.y, box.width, box.height]
: [box.left, box.top, box.right - box.left, box.bottom - box.top]
: [box.left, box.top, box.right - box.left, box.bottom - box.top];
Box.assertIsValidBox({ x, y, width, height }, 'Box.constructor', allowNegativeDimensions)
Box.assertIsValidBox({
x, y, width, height,
}, 'Box.constructor', allowNegativeDimensions);
this._x = x
this._y = y
this._width = width
this._height = height
this._x = x;
this._y = y;
this._width = width;
this._height = height;
}
public get x(): number { return this._x }
public get y(): number { return this._y }
public get width(): number { return this._width }
public get height(): number { return this._height }
public get left(): number { return this.x }
public get top(): number { return this.y }
public get right(): number { return this.x + this.width }
public get bottom(): number { return this.y + this.height }
public get area(): number { return this.width * this.height }
public get topLeft(): Point { return new Point(this.left, this.top) }
public get topRight(): Point { return new Point(this.right, this.top) }
public get bottomLeft(): Point { return new Point(this.left, this.bottom) }
public get bottomRight(): Point { return new Point(this.right, this.bottom) }
public get x(): number { return this._x; }
public get y(): number { return this._y; }
public get width(): number { return this._width; }
public get height(): number { return this._height; }
public get left(): number { return this.x; }
public get top(): number { return this.y; }
public get right(): number { return this.x + this.width; }
public get bottom(): number { return this.y + this.height; }
public get area(): number { return this.width * this.height; }
public get topLeft(): Point { return new Point(this.left, this.top); }
public get topRight(): Point { return new Point(this.right, this.top); }
public get bottomLeft(): Point { return new Point(this.left, this.bottom); }
public get bottomRight(): Point { return new Point(this.right, this.bottom); }
public round(): Box<BoxType> {
const [x, y, width, height] = [this.x, this.y, this.width, this.height]
.map(val => Math.round(val))
return new Box({ x, y, width, height })
.map((val) => Math.round(val));
return new Box({
x, y, width, height,
});
}
public floor(): Box<BoxType> {
const [x, y, width, height] = [this.x, this.y, this.width, this.height]
.map(val => Math.floor(val))
return new Box({ x, y, width, height })
.map((val) => Math.floor(val));
return new Box({
x, y, width, height,
});
}
public toSquare(): Box<BoxType> {
let { x, y, width, height } = this
const diff = Math.abs(width - height)
let {
x, y, width, height,
} = this;
const diff = Math.abs(width - height);
if (width < height) {
x -= (diff / 2)
width += diff
x -= (diff / 2);
width += diff;
}
if (height < width) {
y -= (diff / 2)
height += diff
y -= (diff / 2);
height += diff;
}
return new Box({ x, y, width, height })
return new Box({
x, y, width, height,
});
}
public rescale(s: IDimensions | number): Box<BoxType> {
const scaleX = isDimensions(s) ? (s as IDimensions).width : s as number
const scaleY = isDimensions(s) ? (s as IDimensions).height : s as number
const scaleX = isDimensions(s) ? (s as IDimensions).width : s as number;
const scaleY = isDimensions(s) ? (s as IDimensions).height : s as number;
return new Box({
x: this.x * scaleX,
y: this.y * scaleY,
width: this.width * scaleX,
height: this.height * scaleY
})
height: this.height * scaleY,
});
}
public pad(padX: number, padY: number): Box<BoxType> {
let [x, y, width, height] = [
const [x, y, width, height] = [
this.x - (padX / 2),
this.y - (padY / 2),
this.width + padX,
this.height + padY
]
return new Box({ x, y, width, height })
this.height + padY,
];
return new Box({
x, y, width, height,
});
}
public clipAtImageBorders(imgWidth: number, imgHeight: number): Box<BoxType> {
const { x, y, right, bottom } = this
const clippedX = Math.max(x, 0)
const clippedY = Math.max(y, 0)
const {
x, y, right, bottom,
} = this;
const clippedX = Math.max(x, 0);
const clippedY = Math.max(y, 0);
const newWidth = right - clippedX
const newHeight = bottom - clippedY
const clippedWidth = Math.min(newWidth, imgWidth - clippedX)
const clippedHeight = Math.min(newHeight, imgHeight - clippedY)
const newWidth = right - clippedX;
const newHeight = bottom - clippedY;
const clippedWidth = Math.min(newWidth, imgWidth - clippedX);
const clippedHeight = Math.min(newHeight, imgHeight - clippedY);
return (new Box({ x: clippedX, y: clippedY, width: clippedWidth, height: clippedHeight})).floor()
return (new Box({
x: clippedX, y: clippedY, width: clippedWidth, height: clippedHeight,
})).floor();
}
public shift(sx: number, sy: number): Box<BoxType> {
const { width, height } = this
const x = this.x + sx
const y = this.y + sy
const { width, height } = this;
const x = this.x + sx;
const y = this.y + sy;
return new Box({ x, y, width, height })
return new Box({
x, y, width, height,
});
}
public padAtBorders(imageHeight: number, imageWidth: number) {
const w = this.width + 1
const h = this.height + 1
const w = this.width + 1;
const h = this.height + 1;
let dx = 1
let dy = 1
let edx = w
let edy = h
const dx = 1;
const dy = 1;
let edx = w;
let edy = h;
let x = this.left
let y = this.top
let ex = this.right
let ey = this.bottom
let x = this.left;
let y = this.top;
let ex = this.right;
let ey = this.bottom;
if (ex > imageWidth) {
edx = -ex + imageWidth + w
ex = imageWidth
edx = -ex + imageWidth + w;
ex = imageWidth;
}
if (ey > imageHeight) {
edy = -ey + imageHeight + h
ey = imageHeight
edy = -ey + imageHeight + h;
ey = imageHeight;
}
if (x < 1) {
edy = 2 - x
x = 1
edy = 2 - x;
x = 1;
}
if (y < 1) {
edy = 2 - y
y = 1
edy = 2 - y;
y = 1;
}
return { dy, edy, dx, edx, y, ey, x, ex, w, h }
return {
dy, edy, dx, edx, y, ey, x, ex, w, h,
};
}
public calibrate(region: Box) {
@ -169,7 +203,7 @@ export class Box<BoxType = any> implements IBoundingBox, IRect {
left: this.left + (region.left * this.width),
top: this.top + (region.top * this.height),
right: this.right + (region.right * this.width),
bottom: this.bottom + (region.bottom * this.height)
}).toSquare().round()
bottom: this.bottom + (region.bottom * this.height),
}).toSquare().round();
}
}

View File

@ -6,23 +6,24 @@ export interface IDimensions {
}
export class Dimensions implements IDimensions {
private _width: number
private _height: number
constructor(width: number, height: number) {
if (!isValidNumber(width) || !isValidNumber(height)) {
throw new Error(`Dimensions.constructor - expected width and height to be valid numbers, instead have ${JSON.stringify({ width, height })}`)
throw new Error(`Dimensions.constructor - expected width and height to be valid numbers, instead have ${JSON.stringify({ width, height })}`);
}
this._width = width
this._height = height
this._width = width;
this._height = height;
}
public get width(): number { return this._width }
public get height(): number { return this._height }
public get width(): number { return this._width; }
public get height(): number { return this._height; }
public reverse(): Dimensions {
return new Dimensions(1 / this.width, 1 / this.height)
return new Dimensions(1 / this.width, 1 / this.height);
}
}

View File

@ -12,13 +12,13 @@ export class FaceDetection extends ObjectDetection implements IFaceDetecion {
constructor(
score: number,
relativeBox: Rect,
imageDims: IDimensions
imageDims: IDimensions,
) {
super(score, score, '', relativeBox, imageDims)
super(score, score, '', relativeBox, imageDims);
}
public forSize(width: number, height: number): FaceDetection {
const { score, relativeBox, imageDims } = super.forSize(width, height)
return new FaceDetection(score, relativeBox, imageDims)
const { score, relativeBox, imageDims } = super.forSize(width, height);
return new FaceDetection(score, relativeBox, imageDims);
}
}

View File

@ -8,9 +8,9 @@ import { Point } from './Point';
import { IRect, Rect } from './Rect';
// face alignment constants
const relX = 0.5
const relY = 0.43
const relScale = 0.45
const relX = 0.5;
const relY = 0.43;
const relScale = 0.45;
export interface IFaceLandmarks {
positions: Point[]
@ -19,49 +19,55 @@ export interface IFaceLandmarks {
export class FaceLandmarks implements IFaceLandmarks {
protected _shift: Point
protected _positions: Point[]
protected _imgDims: Dimensions
constructor(
relativeFaceLandmarkPositions: Point[],
imgDims: IDimensions,
shift: Point = new Point(0, 0)
shift: Point = new Point(0, 0),
) {
const { width, height } = imgDims
this._imgDims = new Dimensions(width, height)
this._shift = shift
const { width, height } = imgDims;
this._imgDims = new Dimensions(width, height);
this._shift = shift;
this._positions = relativeFaceLandmarkPositions.map(
pt => pt.mul(new Point(width, height)).add(shift)
)
(pt) => pt.mul(new Point(width, height)).add(shift),
);
}
public get shift(): Point { return new Point(this._shift.x, this._shift.y) }
public get imageWidth(): number { return this._imgDims.width }
public get imageHeight(): number { return this._imgDims.height }
public get positions(): Point[] { return this._positions }
public get shift(): Point { return new Point(this._shift.x, this._shift.y); }
public get imageWidth(): number { return this._imgDims.width; }
public get imageHeight(): number { return this._imgDims.height; }
public get positions(): Point[] { return this._positions; }
public get relativePositions(): Point[] {
return this._positions.map(
pt => pt.sub(this._shift).div(new Point(this.imageWidth, this.imageHeight))
)
(pt) => pt.sub(this._shift).div(new Point(this.imageWidth, this.imageHeight)),
);
}
public forSize<T extends FaceLandmarks>(width: number, height: number): T {
return new (this.constructor as any)(
this.relativePositions,
{ width, height }
)
{ width, height },
);
}
public shiftBy<T extends FaceLandmarks>(x: number, y: number): T {
return new (this.constructor as any)(
this.relativePositions,
this._imgDims,
new Point(x, y)
)
new Point(x, y),
);
}
public shiftByPoint<T extends FaceLandmarks>(pt: Point): T {
return this.shiftBy(pt.x, pt.y)
return this.shiftBy(pt.x, pt.y);
}
/**
@ -77,49 +83,48 @@ export class FaceLandmarks implements IFaceLandmarks {
*/
public align(
detection?: FaceDetection | IRect | IBoundingBox | null,
options: { useDlibAlignment?: boolean, minBoxPadding?: number } = { }
options: { useDlibAlignment?: boolean, minBoxPadding?: number } = { },
): Box {
if (detection) {
const box = detection instanceof FaceDetection
? detection.box.floor()
: new Box(detection)
: new Box(detection);
return this.shiftBy(box.x, box.y).align(null, options)
return this.shiftBy(box.x, box.y).align(null, options);
}
const { useDlibAlignment, minBoxPadding } = Object.assign({}, { useDlibAlignment: false, minBoxPadding: 0.2 }, options)
const { useDlibAlignment, minBoxPadding } = { useDlibAlignment: false, minBoxPadding: 0.2, ...options };
if (useDlibAlignment) {
return this.alignDlib()
return this.alignDlib();
}
return this.alignMinBbox(minBoxPadding)
return this.alignMinBbox(minBoxPadding);
}
private alignDlib(): Box {
const centers = this.getRefPointsForAlignment();
const centers = this.getRefPointsForAlignment()
const [leftEyeCenter, rightEyeCenter, mouthCenter] = centers;
const distToMouth = (pt: Point) => mouthCenter.sub(pt).magnitude();
const eyeToMouthDist = (distToMouth(leftEyeCenter) + distToMouth(rightEyeCenter)) / 2;
const [leftEyeCenter, rightEyeCenter, mouthCenter] = centers
const distToMouth = (pt: Point) => mouthCenter.sub(pt).magnitude()
const eyeToMouthDist = (distToMouth(leftEyeCenter) + distToMouth(rightEyeCenter)) / 2
const size = Math.floor(eyeToMouthDist / relScale);
const size = Math.floor(eyeToMouthDist / relScale)
const refPoint = getCenterPoint(centers)
const refPoint = getCenterPoint(centers);
// TODO: pad in case rectangle is out of image bounds
const x = Math.floor(Math.max(0, refPoint.x - (relX * size)))
const y = Math.floor(Math.max(0, refPoint.y - (relY * size)))
const x = Math.floor(Math.max(0, refPoint.x - (relX * size)));
const y = Math.floor(Math.max(0, refPoint.y - (relY * size)));
return new Rect(x, y, Math.min(size, this.imageWidth + x), Math.min(size, this.imageHeight + y))
return new Rect(x, y, Math.min(size, this.imageWidth + x), Math.min(size, this.imageHeight + y));
}
private alignMinBbox(padding: number): Box {
const box = minBbox(this.positions)
return box.pad(box.width * padding, box.height * padding)
const box = minBbox(this.positions);
return box.pad(box.width * padding, box.height * padding);
}
protected getRefPointsForAlignment(): Point[] {
throw new Error('getRefPointsForAlignment not implemented by base class')
throw new Error('getRefPointsForAlignment not implemented by base class');
}
}

View File

@ -2,15 +2,13 @@ import { getCenterPoint } from '../utils/index';
import { FaceLandmarks } from './FaceLandmarks';
import { Point } from './Point';
export class FaceLandmarks5 extends FaceLandmarks {
protected getRefPointsForAlignment(): Point[] {
const pts = this.positions
const pts = this.positions;
return [
pts[0],
pts[1],
getCenterPoint([pts[3], pts[4]])
]
getCenterPoint([pts[3], pts[4]]),
];
}
}

View File

@ -4,38 +4,38 @@ import { Point } from './Point';
export class FaceLandmarks68 extends FaceLandmarks {
public getJawOutline(): Point[] {
return this.positions.slice(0, 17)
return this.positions.slice(0, 17);
}
public getLeftEyeBrow(): Point[] {
return this.positions.slice(17, 22)
return this.positions.slice(17, 22);
}
public getRightEyeBrow(): Point[] {
return this.positions.slice(22, 27)
return this.positions.slice(22, 27);
}
public getNose(): Point[] {
return this.positions.slice(27, 36)
return this.positions.slice(27, 36);
}
public getLeftEye(): Point[] {
return this.positions.slice(36, 42)
return this.positions.slice(36, 42);
}
public getRightEye(): Point[] {
return this.positions.slice(42, 48)
return this.positions.slice(42, 48);
}
public getMouth(): Point[] {
return this.positions.slice(48, 68)
return this.positions.slice(48, 68);
}
protected getRefPointsForAlignment(): Point[] {
return [
this.getLeftEye(),
this.getRightEye(),
this.getMouth()
].map(getCenterPoint)
this.getMouth(),
].map(getCenterPoint);
}
}

View File

@ -7,17 +7,19 @@ export interface IFaceMatch {
export class FaceMatch implements IFaceMatch {
private _label: string
private _distance: number
constructor(label: string, distance: number) {
this._label = label
this._distance = distance
this._label = label;
this._distance = distance;
}
public get label(): string { return this._label }
public get distance(): number { return this._distance }
public get label(): string { return this._label; }
public get distance(): number { return this._distance; }
public toString(withDistance: boolean = true): string {
return `${this.label}${withDistance ? ` (${round(this.distance)})` : ''}`
return `${this.label}${withDistance ? ` (${round(this.distance)})` : ''}`;
}
}

View File

@ -4,22 +4,20 @@ import { Box } from './Box';
import { IRect } from './Rect';
export class LabeledBox extends Box<LabeledBox> {
public static assertIsValidLabeledBox(box: any, callee: string) {
Box.assertIsValidBox(box, callee)
Box.assertIsValidBox(box, callee);
if (!isValidNumber(box.label)) {
throw new Error(`${callee} - expected property label (${box.label}) to be a number`)
throw new Error(`${callee} - expected property label (${box.label}) to be a number`);
}
}
private _label: number
constructor(box: IBoundingBox | IRect | any, label: number) {
super(box)
this._label = label
super(box);
this._label = label;
}
public get label(): number { return this._label }
public get label(): number { return this._label; }
}

View File

@ -1,35 +1,34 @@
export class LabeledFaceDescriptors {
private _label: string
private _descriptors: Float32Array[]
constructor(label: string, descriptors: Float32Array[]) {
if (!(typeof label === 'string')) {
throw new Error('LabeledFaceDescriptors - constructor expected label to be a string')
throw new Error('LabeledFaceDescriptors - constructor expected label to be a string');
}
if (!Array.isArray(descriptors) || descriptors.some(desc => !(desc instanceof Float32Array))) {
throw new Error('LabeledFaceDescriptors - constructor expected descriptors to be an array of Float32Array')
if (!Array.isArray(descriptors) || descriptors.some((desc) => !(desc instanceof Float32Array))) {
throw new Error('LabeledFaceDescriptors - constructor expected descriptors to be an array of Float32Array');
}
this._label = label
this._descriptors = descriptors
this._label = label;
this._descriptors = descriptors;
}
public get label(): string { return this._label }
public get descriptors(): Float32Array[] { return this._descriptors }
public get label(): string { return this._label; }
public get descriptors(): Float32Array[] { return this._descriptors; }
public toJSON(): any {
return {
label: this.label,
descriptors: this.descriptors.map((d) => Array.from(d))
descriptors: this.descriptors.map((d) => Array.from(d)),
};
}
public static fromJSON(json: any): LabeledFaceDescriptors {
const descriptors = json.descriptors.map((d: any) => {
return new Float32Array(d);
});
const descriptors = json.descriptors.map((d: any) => new Float32Array(d));
return new LabeledFaceDescriptors(json.label, descriptors);
}
}

View File

@ -4,9 +4,13 @@ import { IRect, Rect } from './Rect';
export class ObjectDetection {
private _score: number
private _classScore: number
private _className: string
private _box: Rect
private _imageDims: Dimensions
constructor(
@ -14,23 +18,30 @@ export class ObjectDetection {
classScore: number,
className: string,
relativeBox: IRect,
imageDims: IDimensions
imageDims: IDimensions,
) {
this._imageDims = new Dimensions(imageDims.width, imageDims.height)
this._score = score
this._classScore = classScore
this._className = className
this._box = new Box(relativeBox).rescale(this._imageDims)
this._imageDims = new Dimensions(imageDims.width, imageDims.height);
this._score = score;
this._classScore = classScore;
this._className = className;
this._box = new Box(relativeBox).rescale(this._imageDims);
}
public get score(): number { return this._score }
public get classScore(): number { return this._classScore }
public get className(): string { return this._className }
public get box(): Box { return this._box }
public get imageDims(): Dimensions { return this._imageDims }
public get imageWidth(): number { return this.imageDims.width }
public get imageHeight(): number { return this.imageDims.height }
public get relativeBox(): Box { return new Box(this._box).rescale(this.imageDims.reverse()) }
public get score(): number { return this._score; }
public get classScore(): number { return this._classScore; }
public get className(): string { return this._className; }
public get box(): Box { return this._box; }
public get imageDims(): Dimensions { return this._imageDims; }
public get imageWidth(): number { return this.imageDims.width; }
public get imageHeight(): number { return this.imageDims.height; }
public get relativeBox(): Box { return new Box(this._box).rescale(this.imageDims.reverse()); }
public forSize(width: number, height: number): ObjectDetection {
return new ObjectDetection(
@ -38,7 +49,7 @@ export class ObjectDetection {
this.classScore,
this.className,
this.relativeBox,
{ width, height}
)
{ width, height },
);
}
}

View File

@ -5,41 +5,43 @@ export interface IPoint {
export class Point implements IPoint {
private _x: number
private _y: number
constructor(x: number, y: number) {
this._x = x
this._y = y
this._x = x;
this._y = y;
}
get x(): number { return this._x }
get y(): number { return this._y }
get x(): number { return this._x; }
get y(): number { return this._y; }
public add(pt: IPoint): Point {
return new Point(this.x + pt.x, this.y + pt.y)
return new Point(this.x + pt.x, this.y + pt.y);
}
public sub(pt: IPoint): Point {
return new Point(this.x - pt.x, this.y - pt.y)
return new Point(this.x - pt.x, this.y - pt.y);
}
public mul(pt: IPoint): Point {
return new Point(this.x * pt.x, this.y * pt.y)
return new Point(this.x * pt.x, this.y * pt.y);
}
public div(pt: IPoint): Point {
return new Point(this.x / pt.x, this.y / pt.y)
return new Point(this.x / pt.x, this.y / pt.y);
}
public abs(): Point {
return new Point(Math.abs(this.x), Math.abs(this.y))
return new Point(Math.abs(this.x), Math.abs(this.y));
}
public magnitude(): number {
return Math.sqrt(Math.pow(this.x, 2) + Math.pow(this.y, 2))
return Math.sqrt((this.x ** 2) + (this.y ** 2));
}
public floor(): Point {
return new Point(Math.floor(this.x), Math.floor(this.y))
return new Point(Math.floor(this.x), Math.floor(this.y));
}
}

View File

@ -4,28 +4,28 @@ import { LabeledBox } from './LabeledBox';
import { IRect } from './Rect';
export class PredictedBox extends LabeledBox {
public static assertIsValidPredictedBox(box: any, callee: string) {
LabeledBox.assertIsValidLabeledBox(box, callee)
LabeledBox.assertIsValidLabeledBox(box, callee);
if (
!isValidProbablitiy(box.score)
|| !isValidProbablitiy(box.classScore)
) {
throw new Error(`${callee} - expected properties score (${box.score}) and (${box.classScore}) to be a number between [0, 1]`)
throw new Error(`${callee} - expected properties score (${box.score}) and (${box.classScore}) to be a number between [0, 1]`);
}
}
private _score: number
private _classScore: number
constructor(box: IBoundingBox | IRect | any, label: number, score: number, classScore: number) {
super(box, label)
this._score = score
this._classScore = classScore
super(box, label);
this._score = score;
this._classScore = classScore;
}
public get score(): number { return this._score }
public get classScore(): number { return this._classScore }
public get score(): number { return this._score; }
public get classScore(): number { return this._classScore; }
}

View File

@ -9,6 +9,8 @@ export interface IRect {
export class Rect extends Box<Rect> implements IRect {
constructor(x: number, y: number, width: number, height: number, allowNegativeDimensions: boolean = false) {
super({ x, y, width, height }, allowNegativeDimensions)
super({
x, y, width, height,
}, allowNegativeDimensions);
}
}

View File

@ -1,14 +1,14 @@
export * from './BoundingBox'
export * from './Box'
export * from './Dimensions'
export * from './BoundingBox';
export * from './Box';
export * from './Dimensions';
export * from './FaceDetection';
export * from './FaceLandmarks';
export * from './FaceLandmarks5';
export * from './FaceLandmarks68';
export * from './FaceMatch';
export * from './LabeledBox'
export * from './LabeledBox';
export * from './LabeledFaceDescriptors';
export * from './ObjectDetection'
export * from './Point'
export * from './PredictedBox'
export * from './Rect'
export * from './ObjectDetection';
export * from './Point';
export * from './PredictedBox';
export * from './Rect';

View File

@ -6,14 +6,14 @@ export function convLayer(
x: tf.Tensor4D,
params: ConvParams,
padding: 'valid' | 'same' = 'same',
withRelu: boolean = false
withRelu: boolean = false,
): tf.Tensor4D {
return tf.tidy(() => {
const out = tf.add(
tf.conv2d(x, params.filters, [1, 1], padding),
params.bias
) as tf.Tensor4D
params.bias,
) as tf.Tensor4D;
return withRelu ? tf.relu(out) : out
})
return withRelu ? tf.relu(out) : out;
});
}

View File

@ -5,11 +5,11 @@ import { SeparableConvParams } from './types';
export function depthwiseSeparableConv(
x: tf.Tensor4D,
params: SeparableConvParams,
stride: [number, number]
stride: [number, number],
): tf.Tensor4D {
return tf.tidy(() => {
let out = tf.separableConv2d(x, params.depthwise_filter, params.pointwise_filter, stride, 'same')
out = tf.add(out, params.bias)
return out
})
let out = tf.separableConv2d(x, params.depthwise_filter, params.pointwise_filter, stride, 'same');
out = tf.add(out, params.bias);
return out;
});
}

View File

@ -1,9 +1,9 @@
import { ParamMapping } from './types';
export function disposeUnusedWeightTensors(weightMap: any, paramMappings: ParamMapping[]) {
Object.keys(weightMap).forEach(path => {
if (!paramMappings.some(pm => pm.originalPath === path)) {
weightMap[path].dispose()
Object.keys(weightMap).forEach((path) => {
if (!paramMappings.some((pm) => pm.originalPath === path)) {
weightMap[path].dispose();
}
})
});
}

View File

@ -4,28 +4,25 @@ import { ConvParams, ExtractWeightsFunction, ParamMapping } from './types';
export function extractConvParamsFactory(
extractWeights: ExtractWeightsFunction,
paramMappings: ParamMapping[]
paramMappings: ParamMapping[],
) {
return function(
return (
channelsIn: number,
channelsOut: number,
filterSize: number,
mappedPrefix: string
): ConvParams {
mappedPrefix: string,
): ConvParams => {
const filters = tf.tensor4d(
extractWeights(channelsIn * channelsOut * filterSize * filterSize),
[filterSize, filterSize, channelsIn, channelsOut]
)
const bias = tf.tensor1d(extractWeights(channelsOut))
[filterSize, filterSize, channelsIn, channelsOut],
);
const bias = tf.tensor1d(extractWeights(channelsOut));
paramMappings.push(
{ paramPath: `${mappedPrefix}/filters` },
{ paramPath: `${mappedPrefix}/bias` }
)
return { filters, bias }
}
{ paramPath: `${mappedPrefix}/bias` },
);
return { filters, bias };
};
}

View File

@ -2,30 +2,26 @@ import * as tf from '../../dist/tfjs.esm.js';
import { ExtractWeightsFunction, FCParams, ParamMapping } from './types';
export function extractFCParamsFactory(
extractWeights: ExtractWeightsFunction,
paramMappings: ParamMapping[]
paramMappings: ParamMapping[],
) {
return function(
return (
channelsIn: number,
channelsOut: number,
mappedPrefix: string
): FCParams {
const fc_weights = tf.tensor2d(extractWeights(channelsIn * channelsOut), [channelsIn, channelsOut])
const fc_bias = tf.tensor1d(extractWeights(channelsOut))
mappedPrefix: string,
): FCParams => {
const fc_weights = tf.tensor2d(extractWeights(channelsIn * channelsOut), [channelsIn, channelsOut]);
const fc_bias = tf.tensor1d(extractWeights(channelsOut));
paramMappings.push(
{ paramPath: `${mappedPrefix}/weights` },
{ paramPath: `${mappedPrefix}/bias` }
)
{ paramPath: `${mappedPrefix}/bias` },
);
return {
weights: fc_weights,
bias: fc_bias
}
}
bias: fc_bias,
};
};
}

View File

@ -4,43 +4,40 @@ import { ExtractWeightsFunction, ParamMapping, SeparableConvParams } from './typ
export function extractSeparableConvParamsFactory(
extractWeights: ExtractWeightsFunction,
paramMappings: ParamMapping[]
paramMappings: ParamMapping[],
) {
return function(channelsIn: number, channelsOut: number, mappedPrefix: string): SeparableConvParams {
const depthwise_filter = tf.tensor4d(extractWeights(3 * 3 * channelsIn), [3, 3, channelsIn, 1])
const pointwise_filter = tf.tensor4d(extractWeights(channelsIn * channelsOut), [1, 1, channelsIn, channelsOut])
const bias = tf.tensor1d(extractWeights(channelsOut))
return (channelsIn: number, channelsOut: number, mappedPrefix: string): SeparableConvParams => {
const depthwise_filter = tf.tensor4d(extractWeights(3 * 3 * channelsIn), [3, 3, channelsIn, 1]);
const pointwise_filter = tf.tensor4d(extractWeights(channelsIn * channelsOut), [1, 1, channelsIn, channelsOut]);
const bias = tf.tensor1d(extractWeights(channelsOut));
paramMappings.push(
{ paramPath: `${mappedPrefix}/depthwise_filter` },
{ paramPath: `${mappedPrefix}/pointwise_filter` },
{ paramPath: `${mappedPrefix}/bias` }
)
{ paramPath: `${mappedPrefix}/bias` },
);
return new SeparableConvParams(
depthwise_filter,
pointwise_filter,
bias
)
}
bias,
);
};
}
export function loadSeparableConvParamsFactory(
extractWeightEntry: <T>(originalPath: string, paramRank: number) => T
// eslint-disable-next-line no-unused-vars
extractWeightEntry: <T>(originalPath: string, paramRank: number) => T,
) {
return function (prefix: string): SeparableConvParams {
const depthwise_filter = extractWeightEntry<tf.Tensor4D>(`${prefix}/depthwise_filter`, 4)
const pointwise_filter = extractWeightEntry<tf.Tensor4D>(`${prefix}/pointwise_filter`, 4)
const bias = extractWeightEntry<tf.Tensor1D>(`${prefix}/bias`, 1)
return (prefix: string): SeparableConvParams => {
const depthwise_filter = extractWeightEntry<tf.Tensor4D>(`${prefix}/depthwise_filter`, 4);
const pointwise_filter = extractWeightEntry<tf.Tensor4D>(`${prefix}/pointwise_filter`, 4);
const bias = extractWeightEntry<tf.Tensor1D>(`${prefix}/bias`, 1);
return new SeparableConvParams(
depthwise_filter,
pointwise_filter,
bias
)
}
bias,
);
};
}

View File

@ -2,19 +2,17 @@ import { isTensor } from '../utils/index';
import { ParamMapping } from './types';
export function extractWeightEntryFactory(weightMap: any, paramMappings: ParamMapping[]) {
return function<T> (originalPath: string, paramRank: number, mappedPath?: string): T {
const tensor = weightMap[originalPath]
return (originalPath: string, paramRank: number, mappedPath?: string) => {
const tensor = weightMap[originalPath];
if (!isTensor(tensor, paramRank)) {
throw new Error(`expected weightMap[${originalPath}] to be a Tensor${paramRank}D, instead have ${tensor}`)
throw new Error(`expected weightMap[${originalPath}] to be a Tensor${paramRank}D, instead have ${tensor}`);
}
paramMappings.push(
{ originalPath, paramPath: mappedPath || originalPath }
)
return tensor
}
{ originalPath, paramPath: mappedPath || originalPath },
);
return tensor;
};
}

View File

@ -1,18 +1,18 @@
export function extractWeightsFactory(weights: Float32Array) {
let remainingWeights = weights
let remainingWeights = weights;
function extractWeights(numWeights: number): Float32Array {
const ret = remainingWeights.slice(0, numWeights)
remainingWeights = remainingWeights.slice(numWeights)
return ret
const ret = remainingWeights.slice(0, numWeights);
remainingWeights = remainingWeights.slice(numWeights);
return ret;
}
function getRemainingWeights(): Float32Array {
return remainingWeights
return remainingWeights;
}
return {
extractWeights,
getRemainingWeights
}
getRemainingWeights,
};
}

View File

@ -4,12 +4,10 @@ import { FCParams } from './types';
export function fullyConnectedLayer(
x: tf.Tensor2D,
params: FCParams
params: FCParams,
): tf.Tensor2D {
return tf.tidy(() =>
tf.add(
return tf.tidy(() => tf.add(
tf.matMul(x, params.weights),
params.bias
)
)
params.bias,
));
}

View File

@ -1,33 +1,34 @@
export function getModelUris(uri: string | undefined, defaultModelName: string) {
const defaultManifestFilename = `${defaultModelName}-weights_manifest.json`
const defaultManifestFilename = `${defaultModelName}-weights_manifest.json`;
if (!uri) {
return {
modelBaseUri: '',
manifestUri: defaultManifestFilename
}
manifestUri: defaultManifestFilename,
};
}
if (uri === '/') {
return {
modelBaseUri: '/',
manifestUri: `/${defaultManifestFilename}`
}
manifestUri: `/${defaultManifestFilename}`,
};
}
// eslint-disable-next-line no-nested-ternary
const protocol = uri.startsWith('http://') ? 'http://' : uri.startsWith('https://') ? 'https://' : '';
uri = uri.replace(protocol, '');
const parts = uri.split('/').filter(s => s)
const parts = uri.split('/').filter((s) => s);
const manifestFile = uri.endsWith('.json')
? parts[parts.length - 1]
: defaultManifestFilename
: defaultManifestFilename;
let modelBaseUri = protocol + (uri.endsWith('.json') ? parts.slice(0, parts.length - 1) : parts).join('/')
modelBaseUri = uri.startsWith('/') ? `/${modelBaseUri}` : modelBaseUri
let modelBaseUri = protocol + (uri.endsWith('.json') ? parts.slice(0, parts.length - 1) : parts).join('/');
modelBaseUri = uri.startsWith('/') ? `/${modelBaseUri}` : modelBaseUri;
return {
modelBaseUri,
manifestUri: modelBaseUri === '/' ? `/${manifestFile}` : `${modelBaseUri}/${manifestFile}`
}
manifestUri: modelBaseUri === '/' ? `/${manifestFile}` : `${modelBaseUri}/${manifestFile}`,
};
}

View File

@ -1,10 +1,10 @@
export * from './convLayer'
export * from './depthwiseSeparableConv'
export * from './disposeUnusedWeightTensors'
export * from './extractConvParamsFactory'
export * from './extractFCParamsFactory'
export * from './extractSeparableConvParamsFactory'
export * from './extractWeightEntryFactory'
export * from './extractWeightsFactory'
export * from './getModelUris'
export * from './types'
export * from './convLayer';
export * from './depthwiseSeparableConv';
export * from './disposeUnusedWeightTensors';
export * from './extractConvParamsFactory';
export * from './extractFCParamsFactory';
export * from './extractSeparableConvParamsFactory';
export * from './extractWeightEntryFactory';
export * from './extractWeightsFactory';
export * from './getModelUris';
export * from './types';

View File

@ -2,11 +2,12 @@ import * as tf from '../../dist/tfjs.esm.js';
import { ConvParams } from './types';
// eslint-disable-next-line no-unused-vars
export function loadConvParamsFactory(extractWeightEntry: <T>(originalPath: string, paramRank: number) => T) {
return function(prefix: string): ConvParams {
const filters = extractWeightEntry<tf.Tensor4D>(`${prefix}/filters`, 4)
const bias = extractWeightEntry<tf.Tensor1D>(`${prefix}/bias`, 1)
return (prefix: string): ConvParams => {
const filters = extractWeightEntry<tf.Tensor4D>(`${prefix}/filters`, 4);
const bias = extractWeightEntry<tf.Tensor1D>(`${prefix}/bias`, 1);
return { filters, bias }
}
return { filters, bias };
};
}

View File

@ -1,5 +1,6 @@
import * as tf from '../../dist/tfjs.esm.js';
// eslint-disable-next-line no-unused-vars
export type ExtractWeightsFunction = (numWeights: number) => Float32Array
export type ParamMapping = {
@ -18,9 +19,14 @@ export type FCParams = {
}
export class SeparableConvParams {
// eslint-disable-next-line no-useless-constructor
constructor(
// eslint-disable-next-line no-unused-vars
public depthwise_filter: tf.Tensor4D,
// eslint-disable-next-line no-unused-vars
public pointwise_filter: tf.Tensor4D,
public bias: tf.Tensor1D
// eslint-disable-next-line no-unused-vars
public bias: tf.Tensor1D,
// eslint-disable-next-line no-empty-function
) {}
}

View File

@ -3,110 +3,115 @@ import * as tf from '../../dist/tfjs.esm.js';
import { Dimensions } from '../classes/Dimensions';
import { env } from '../env/index';
import { padToSquare } from '../ops/padToSquare';
import { computeReshapedDimensions, isTensor3D, isTensor4D, range } from '../utils/index';
import {
computeReshapedDimensions, isTensor3D, isTensor4D, range,
} from '../utils/index';
import { createCanvasFromMedia } from './createCanvas';
import { imageToSquare } from './imageToSquare';
import { TResolvedNetInput } from './types';
export class NetInput {
private _imageTensors: Array<tf.Tensor3D | tf.Tensor4D> = []
private _canvases: HTMLCanvasElement[] = []
private _batchSize: number
private _treatAsBatchInput: boolean = false
private _inputDimensions: number[][] = []
private _inputSize: number
constructor(
inputs: Array<TResolvedNetInput>,
treatAsBatchInput: boolean = false
treatAsBatchInput: boolean = false,
) {
if (!Array.isArray(inputs)) {
throw new Error(`NetInput.constructor - expected inputs to be an Array of TResolvedNetInput or to be instanceof tf.Tensor4D, instead have ${inputs}`)
throw new Error(`NetInput.constructor - expected inputs to be an Array of TResolvedNetInput or to be instanceof tf.Tensor4D, instead have ${inputs}`);
}
this._treatAsBatchInput = treatAsBatchInput
this._batchSize = inputs.length
this._treatAsBatchInput = treatAsBatchInput;
this._batchSize = inputs.length;
inputs.forEach((input, idx) => {
if (isTensor3D(input)) {
this._imageTensors[idx] = input
this._inputDimensions[idx] = input.shape
return
this._imageTensors[idx] = input;
this._inputDimensions[idx] = input.shape;
return;
}
if (isTensor4D(input)) {
const batchSize = (input as any).shape[0]
const batchSize = (input as any).shape[0];
if (batchSize !== 1) {
throw new Error(`NetInput - tf.Tensor4D with batchSize ${batchSize} passed, but not supported in input array`)
throw new Error(`NetInput - tf.Tensor4D with batchSize ${batchSize} passed, but not supported in input array`);
}
this._imageTensors[idx] = input
this._inputDimensions[idx] = (input as any).shape.slice(1)
return
this._imageTensors[idx] = input;
this._inputDimensions[idx] = (input as any).shape.slice(1);
return;
}
const canvas = (input as any) instanceof env.getEnv().Canvas ? input : createCanvasFromMedia(input)
this._canvases[idx] = canvas
this._inputDimensions[idx] = [canvas.height, canvas.width, 3]
})
const canvas = (input as any) instanceof env.getEnv().Canvas ? input : createCanvasFromMedia(input);
this._canvases[idx] = canvas;
this._inputDimensions[idx] = [canvas.height, canvas.width, 3];
});
}
public get imageTensors(): Array<tf.Tensor3D | tf.Tensor4D> {
return this._imageTensors
return this._imageTensors;
}
public get canvases(): HTMLCanvasElement[] {
return this._canvases
return this._canvases;
}
public get isBatchInput(): boolean {
return this.batchSize > 1 || this._treatAsBatchInput
return this.batchSize > 1 || this._treatAsBatchInput;
}
public get batchSize(): number {
return this._batchSize
return this._batchSize;
}
public get inputDimensions(): number[][] {
return this._inputDimensions
return this._inputDimensions;
}
public get inputSize(): number | undefined {
return this._inputSize
return this._inputSize;
}
public get reshapedInputDimensions(): Dimensions[] {
return range(this.batchSize, 0, 1).map(
(_, batchIdx) => this.getReshapedInputDimensions(batchIdx)
)
(_, batchIdx) => this.getReshapedInputDimensions(batchIdx),
);
}
public getInput(batchIdx: number): tf.Tensor3D | tf.Tensor4D | HTMLCanvasElement {
return this.canvases[batchIdx] || this.imageTensors[batchIdx]
return this.canvases[batchIdx] || this.imageTensors[batchIdx];
}
public getInputDimensions(batchIdx: number): number[] {
return this._inputDimensions[batchIdx]
return this._inputDimensions[batchIdx];
}
public getInputHeight(batchIdx: number): number {
return this._inputDimensions[batchIdx][0]
return this._inputDimensions[batchIdx][0];
}
public getInputWidth(batchIdx: number): number {
return this._inputDimensions[batchIdx][1]
return this._inputDimensions[batchIdx][1];
}
public getReshapedInputDimensions(batchIdx: number): Dimensions {
if (typeof this.inputSize !== 'number') {
throw new Error('getReshapedInputDimensions - inputSize not set, toBatchTensor has not been called yet')
throw new Error('getReshapedInputDimensions - inputSize not set, toBatchTensor has not been called yet');
}
const width = this.getInputWidth(batchIdx)
const height = this.getInputHeight(batchIdx)
return computeReshapedDimensions({ width, height }, this.inputSize)
const width = this.getInputWidth(batchIdx);
const height = this.getInputHeight(batchIdx);
return computeReshapedDimensions({ width, height }, this.inputSize);
}
/**
@ -119,39 +124,37 @@ export class NetInput {
* @returns The batch tensor.
*/
public toBatchTensor(inputSize: number, isCenterInputs: boolean = true): tf.Tensor4D {
this._inputSize = inputSize
this._inputSize = inputSize;
return tf.tidy(() => {
const inputTensors = range(this.batchSize, 0, 1).map(batchIdx => {
const input = this.getInput(batchIdx)
const inputTensors = range(this.batchSize, 0, 1).map((batchIdx) => {
const input = this.getInput(batchIdx);
if (input instanceof tf.Tensor) {
// @ts-ignore: error TS2344: Type 'Rank.R4' does not satisfy the constraint 'Tensor<Rank>'.
let imgTensor = isTensor4D(input) ? input : input.expandDims<tf.Rank.R4>()
let imgTensor = isTensor4D(input) ? input : input.expandDims<tf.Rank.R4>();
// @ts-ignore: error TS2344: Type 'Rank.R4' does not satisfy the constraint 'Tensor<Rank>'.
imgTensor = padToSquare(imgTensor, isCenterInputs)
imgTensor = padToSquare(imgTensor, isCenterInputs);
if (imgTensor.shape[1] !== inputSize || imgTensor.shape[2] !== inputSize) {
imgTensor = tf.image.resizeBilinear(imgTensor, [inputSize, inputSize])
imgTensor = tf.image.resizeBilinear(imgTensor, [inputSize, inputSize]);
}
return imgTensor.as3D(inputSize, inputSize, 3)
return imgTensor.as3D(inputSize, inputSize, 3);
}
if (input instanceof env.getEnv().Canvas) {
return tf.browser.fromPixels(imageToSquare(input, inputSize, isCenterInputs))
return tf.browser.fromPixels(imageToSquare(input, inputSize, isCenterInputs));
}
throw new Error(`toBatchTensor - at batchIdx ${batchIdx}, expected input to be instanceof tf.Tensor or instanceof HTMLCanvasElement, instead have ${input}`)
})
throw new Error(`toBatchTensor - at batchIdx ${batchIdx}, expected input to be instanceof tf.Tensor or instanceof HTMLCanvasElement, instead have ${input}`);
});
// const batchTensor = tf.stack(inputTensors.map(t => t.toFloat())).as4D(this.batchSize, inputSize, inputSize, 3)
const batchTensor = tf.stack(inputTensors.map(t => tf.cast(t, 'float32'))).as4D(this.batchSize, inputSize, inputSize, 3)
const batchTensor = tf.stack(inputTensors.map((t) => tf.cast(t, 'float32'))).as4D(this.batchSize, inputSize, inputSize, 3);
// const batchTensor = tf.stack(inputTensors.map(t => tf.Tensor.as4D(tf.cast(t, 'float32'))), this.batchSize, inputSize, inputSize, 3);
return batchTensor
})
return batchTensor;
});
}
}

View File

@ -2,27 +2,28 @@ import { env } from '../env/index';
import { isMediaLoaded } from './isMediaLoaded';
export function awaitMediaLoaded(media: HTMLImageElement | HTMLVideoElement | HTMLCanvasElement) {
// eslint-disable-next-line consistent-return
return new Promise((resolve, reject) => {
if (media instanceof env.getEnv().Canvas || isMediaLoaded(media)) {
return resolve(null)
}
function onLoad(e: Event) {
if (!e.currentTarget) return
e.currentTarget.removeEventListener('load', onLoad)
e.currentTarget.removeEventListener('error', onError)
resolve(e)
return resolve(null);
}
function onError(e: Event) {
if (!e.currentTarget) return
e.currentTarget.removeEventListener('load', onLoad)
e.currentTarget.removeEventListener('error', onError)
reject(e)
if (!e.currentTarget) return;
// eslint-disable-next-line no-use-before-define
e.currentTarget.removeEventListener('load', onLoad);
e.currentTarget.removeEventListener('error', onError);
reject(e);
}
media.addEventListener('load', onLoad)
media.addEventListener('error', onError)
})
function onLoad(e: Event) {
if (!e.currentTarget) return;
e.currentTarget.removeEventListener('load', onLoad);
e.currentTarget.removeEventListener('error', onError);
resolve(e);
}
media.addEventListener('load', onLoad);
media.addEventListener('error', onError);
});
}

View File

@ -2,22 +2,16 @@ import { env } from '../env/index';
export function bufferToImage(buf: Blob): Promise<HTMLImageElement> {
return new Promise((resolve, reject) => {
if (!(buf instanceof Blob)) {
return reject('bufferToImage - expected buf to be of type: Blob')
}
const reader = new FileReader()
if (!(buf instanceof Blob)) reject(new Error('bufferToImage - expected buf to be of type: Blob'));
const reader = new FileReader();
reader.onload = () => {
if (typeof reader.result !== 'string') {
return reject('bufferToImage - expected reader.result to be a string, in onload')
}
const img = env.getEnv().createImageElement()
img.onload = () => resolve(img)
img.onerror = reject
img.src = reader.result
}
reader.onerror = reject
reader.readAsDataURL(buf)
})
if (typeof reader.result !== 'string') reject(new Error('bufferToImage - expected reader.result to be a string, in onload'));
const img = env.getEnv().createImageElement();
img.onload = () => resolve(img);
img.onerror = reject;
img.src = reader.result as string;
};
reader.onerror = reject;
reader.readAsDataURL(buf);
});
}

View File

@ -5,29 +5,27 @@ import { getMediaDimensions } from './getMediaDimensions';
import { isMediaLoaded } from './isMediaLoaded';
export function createCanvas({ width, height }: IDimensions): HTMLCanvasElement {
const { createCanvasElement } = env.getEnv()
const canvas = createCanvasElement()
canvas.width = width
canvas.height = height
return canvas
const { createCanvasElement } = env.getEnv();
const canvas = createCanvasElement();
canvas.width = width;
canvas.height = height;
return canvas;
}
export function createCanvasFromMedia(media: HTMLImageElement | HTMLVideoElement | ImageData, dims?: IDimensions): HTMLCanvasElement {
const { ImageData } = env.getEnv()
const { ImageData } = env.getEnv();
if (!(media instanceof ImageData) && !isMediaLoaded(media)) {
throw new Error('createCanvasFromMedia - media has not finished loading yet')
throw new Error('createCanvasFromMedia - media has not finished loading yet');
}
const { width, height } = dims || getMediaDimensions(media)
const canvas = createCanvas({ width, height })
const { width, height } = dims || getMediaDimensions(media);
const canvas = createCanvas({ width, height });
if (media instanceof ImageData) {
getContext2dOrThrow(canvas).putImageData(media, 0, 0)
getContext2dOrThrow(canvas).putImageData(media, 0, 0);
} else {
getContext2dOrThrow(canvas).drawImage(media, 0, 0, width, height)
getContext2dOrThrow(canvas).drawImage(media, 0, 0, width, height);
}
return canvas
return canvas;
}

View File

@ -16,31 +16,30 @@ import { isTensor3D, isTensor4D } from '../utils/index';
*/
export async function extractFaceTensors(
imageTensor: tf.Tensor3D | tf.Tensor4D,
detections: Array<FaceDetection | Rect>
detections: Array<FaceDetection | Rect>,
): Promise<tf.Tensor3D[]> {
if (!isTensor3D(imageTensor) && !isTensor4D(imageTensor)) {
throw new Error('extractFaceTensors - expected image tensor to be 3D or 4D')
throw new Error('extractFaceTensors - expected image tensor to be 3D or 4D');
}
if (isTensor4D(imageTensor) && imageTensor.shape[0] > 1) {
throw new Error('extractFaceTensors - batchSize > 1 not supported')
throw new Error('extractFaceTensors - batchSize > 1 not supported');
}
return tf.tidy(() => {
const [imgHeight, imgWidth, numChannels] = imageTensor.shape.slice(isTensor4D(imageTensor) ? 1 : 0)
const [imgHeight, imgWidth, numChannels] = imageTensor.shape.slice(isTensor4D(imageTensor) ? 1 : 0);
const boxes = detections.map(
det => det instanceof FaceDetection
(det) => (det instanceof FaceDetection
? det.forSize(imgWidth, imgHeight).box
: det
: det),
)
.map(box => box.clipAtImageBorders(imgWidth, imgHeight))
.map((box) => box.clipAtImageBorders(imgWidth, imgHeight));
const faceTensors = boxes.map(({ x, y, width, height }) =>
tf.slice3d(imageTensor.as3D(imgHeight, imgWidth, numChannels), [y, x, 0], [height, width, numChannels])
)
const faceTensors = boxes.map(({
x, y, width, height,
}) => tf.slice3d(imageTensor.as3D(imgHeight, imgWidth, numChannels), [y, x, 0], [height, width, numChannels]));
return faceTensors
})
return faceTensors;
});
}

View File

@ -16,38 +16,39 @@ import { TNetInput } from './types';
*/
export async function extractFaces(
input: TNetInput,
detections: Array<FaceDetection | Rect>
detections: Array<FaceDetection | Rect>,
): Promise<HTMLCanvasElement[]> {
const { Canvas } = env.getEnv();
const { Canvas } = env.getEnv()
let canvas = input as HTMLCanvasElement
let canvas = input as HTMLCanvasElement;
if (!(input instanceof Canvas)) {
const netInput = await toNetInput(input)
const netInput = await toNetInput(input);
if (netInput.batchSize > 1) {
throw new Error('extractFaces - batchSize > 1 not supported')
throw new Error('extractFaces - batchSize > 1 not supported');
}
const tensorOrCanvas = netInput.getInput(0)
const tensorOrCanvas = netInput.getInput(0);
canvas = tensorOrCanvas instanceof Canvas
? tensorOrCanvas
: await imageTensorToCanvas(tensorOrCanvas)
: await imageTensorToCanvas(tensorOrCanvas);
}
const ctx = getContext2dOrThrow(canvas)
const ctx = getContext2dOrThrow(canvas);
const boxes = detections.map(
det => det instanceof FaceDetection
(det) => (det instanceof FaceDetection
? det.forSize(canvas.width, canvas.height).box.floor()
: det
: det),
)
.map(box => box.clipAtImageBorders(canvas.width, canvas.height))
.map((box) => box.clipAtImageBorders(canvas.width, canvas.height));
return boxes.map(({ x, y, width, height }) => {
const faceImg = createCanvas({ width, height })
return boxes.map(({
x, y, width, height,
}) => {
const faceImg = createCanvas({ width, height });
getContext2dOrThrow(faceImg)
.putImageData(ctx.getImageData(x, y, width, height), 0, 0)
return faceImg
})
.putImageData(ctx.getImageData(x, y, width, height), 0, 0);
return faceImg;
});
}

View File

@ -2,11 +2,11 @@ import { bufferToImage } from './bufferToImage';
import { fetchOrThrow } from './fetchOrThrow';
export async function fetchImage(uri: string): Promise<HTMLImageElement> {
const res = await fetchOrThrow(uri)
const blob = await (res).blob()
const res = await fetchOrThrow(uri);
const blob = await (res).blob();
if (!blob.type.startsWith('image/')) {
throw new Error(`fetchImage - expected blob type to be of type image/*, instead have: ${blob.type}, for url: ${res.url}`)
throw new Error(`fetchImage - expected blob type to be of type image/*, instead have: ${blob.type}, for url: ${res.url}`);
}
return bufferToImage(blob)
return bufferToImage(blob);
}

View File

@ -1,5 +1,5 @@
import { fetchOrThrow } from './fetchOrThrow';
export async function fetchJson<T>(uri: string): Promise<T> {
return (await fetchOrThrow(uri)).json()
return (await fetchOrThrow(uri)).json();
}

View File

@ -1,5 +1,5 @@
import { fetchOrThrow } from './fetchOrThrow';
export async function fetchNetWeights(uri: string): Promise<Float32Array> {
return new Float32Array(await (await fetchOrThrow(uri)).arrayBuffer())
return new Float32Array(await (await fetchOrThrow(uri)).arrayBuffer());
}

View File

@ -2,13 +2,13 @@ import { env } from '../env/index';
export async function fetchOrThrow(
url: string,
init?: RequestInit
// eslint-disable-next-line no-undef
init?: RequestInit,
): Promise<Response> {
const fetch = env.getEnv().fetch
const res = await fetch(url, init)
const { fetch } = env.getEnv();
const res = await fetch(url, init);
if (!(res.status < 400)) {
throw new Error(`failed to fetch: (${res.status}) ${res.statusText}, from url: ${res.url}`)
throw new Error(`failed to fetch: (${res.status}) ${res.statusText}, from url: ${res.url}`);
}
return res
return res;
}

View File

@ -2,23 +2,22 @@ import { env } from '../env/index';
import { resolveInput } from './resolveInput';
export function getContext2dOrThrow(canvasArg: string | HTMLCanvasElement | CanvasRenderingContext2D): CanvasRenderingContext2D {
const { Canvas, CanvasRenderingContext2D } = env.getEnv()
const { Canvas, CanvasRenderingContext2D } = env.getEnv();
if (canvasArg instanceof CanvasRenderingContext2D) {
return canvasArg
return canvasArg;
}
const canvas = resolveInput(canvasArg)
const canvas = resolveInput(canvasArg);
if (!(canvas instanceof Canvas)) {
throw new Error('resolveContext2d - expected canvas to be of instance of Canvas')
throw new Error('resolveContext2d - expected canvas to be of instance of Canvas');
}
const ctx = canvas.getContext('2d')
const ctx = canvas.getContext('2d');
if (!ctx) {
throw new Error('resolveContext2d - canvas 2d context is null')
throw new Error('resolveContext2d - canvas 2d context is null');
}
return ctx
return ctx;
}

View File

@ -2,14 +2,13 @@ import { Dimensions, IDimensions } from '../classes/Dimensions';
import { env } from '../env/index';
export function getMediaDimensions(input: HTMLImageElement | HTMLCanvasElement | HTMLVideoElement | IDimensions): Dimensions {
const { Image, Video } = env.getEnv()
const { Image, Video } = env.getEnv();
if (input instanceof Image) {
return new Dimensions(input.naturalWidth, input.naturalHeight)
return new Dimensions(input.naturalWidth, input.naturalHeight);
}
if (input instanceof Video) {
return new Dimensions(input.videoWidth, input.videoHeight)
return new Dimensions(input.videoWidth, input.videoHeight);
}
return new Dimensions(input.width, input.height)
return new Dimensions(input.width, input.height);
}

View File

@ -5,16 +5,15 @@ import { isTensor4D } from '../utils/index';
export async function imageTensorToCanvas(
imgTensor: tf.Tensor,
canvas?: HTMLCanvasElement
canvas?: HTMLCanvasElement,
): Promise<HTMLCanvasElement> {
const targetCanvas = canvas || env.getEnv().createCanvasElement();
const targetCanvas = canvas || env.getEnv().createCanvasElement()
const [height, width, numChannels] = imgTensor.shape.slice(isTensor4D(imgTensor) ? 1 : 0);
const imgTensor3D = tf.tidy(() => imgTensor.as3D(height, width, numChannels).toInt());
await tf.browser.toPixels(imgTensor3D, targetCanvas);
const [height, width, numChannels] = imgTensor.shape.slice(isTensor4D(imgTensor) ? 1 : 0)
const imgTensor3D = tf.tidy(() => imgTensor.as3D(height, width, numChannels).toInt())
await tf.browser.toPixels(imgTensor3D, targetCanvas)
imgTensor3D.dispose();
imgTensor3D.dispose()
return targetCanvas
return targetCanvas;
}

View File

@ -4,25 +4,24 @@ import { getContext2dOrThrow } from './getContext2dOrThrow';
import { getMediaDimensions } from './getMediaDimensions';
export function imageToSquare(input: HTMLImageElement | HTMLCanvasElement, inputSize: number, centerImage: boolean = false) {
const { Image, Canvas } = env.getEnv()
const { Image, Canvas } = env.getEnv();
if (!(input instanceof Image || input instanceof Canvas)) {
throw new Error('imageToSquare - expected arg0 to be HTMLImageElement | HTMLCanvasElement')
throw new Error('imageToSquare - expected arg0 to be HTMLImageElement | HTMLCanvasElement');
}
const dims = getMediaDimensions(input)
const scale = inputSize / Math.max(dims.height, dims.width)
const width = scale * dims.width
const height = scale * dims.height
const dims = getMediaDimensions(input);
const scale = inputSize / Math.max(dims.height, dims.width);
const width = scale * dims.width;
const height = scale * dims.height;
const targetCanvas = createCanvas({ width: inputSize, height: inputSize })
const inputCanvas = input instanceof Canvas ? input : createCanvasFromMedia(input)
const targetCanvas = createCanvas({ width: inputSize, height: inputSize });
const inputCanvas = input instanceof Canvas ? input : createCanvasFromMedia(input);
const offset = Math.abs(width - height) / 2
const dx = centerImage && width < height ? offset : 0
const dy = centerImage && height < width ? offset : 0
getContext2dOrThrow(targetCanvas).drawImage(inputCanvas, dx, dy, width, height)
const offset = Math.abs(width - height) / 2;
const dx = centerImage && width < height ? offset : 0;
const dy = centerImage && height < width ? offset : 0;
getContext2dOrThrow(targetCanvas).drawImage(inputCanvas, dx, dy, width, height);
return targetCanvas
return targetCanvas;
}

View File

@ -1,21 +1,21 @@
export * from './awaitMediaLoaded'
export * from './bufferToImage'
export * from './createCanvas'
export * from './extractFaces'
export * from './extractFaceTensors'
export * from './fetchImage'
export * from './fetchJson'
export * from './fetchNetWeights'
export * from './fetchOrThrow'
export * from './getContext2dOrThrow'
export * from './getMediaDimensions'
export * from './imageTensorToCanvas'
export * from './imageToSquare'
export * from './isMediaElement'
export * from './isMediaLoaded'
export * from './loadWeightMap'
export * from './matchDimensions'
export * from './NetInput'
export * from './resolveInput'
export * from './toNetInput'
export * from './types'
export * from './awaitMediaLoaded';
export * from './bufferToImage';
export * from './createCanvas';
export * from './extractFaces';
export * from './extractFaceTensors';
export * from './fetchImage';
export * from './fetchJson';
export * from './fetchNetWeights';
export * from './fetchOrThrow';
export * from './getContext2dOrThrow';
export * from './getMediaDimensions';
export * from './imageTensorToCanvas';
export * from './imageToSquare';
export * from './isMediaElement';
export * from './isMediaLoaded';
export * from './loadWeightMap';
export * from './matchDimensions';
export * from './NetInput';
export * from './resolveInput';
export * from './toNetInput';
export * from './types';

View File

@ -1,10 +1,9 @@
import { env } from '../env/index';
export function isMediaElement(input: any) {
const { Image, Canvas, Video } = env.getEnv()
const { Image, Canvas, Video } = env.getEnv();
return input instanceof Image
|| input instanceof Canvas
|| input instanceof Video
|| input instanceof Video;
}

View File

@ -1,9 +1,8 @@
import { env } from '../env/index';
export function isMediaLoaded(media: HTMLImageElement | HTMLVideoElement) : boolean {
const { Image, Video } = env.getEnv()
const { Image, Video } = env.getEnv();
return (media instanceof Image && media.complete)
|| (media instanceof Video && media.readyState >= 3)
|| (media instanceof Video && media.readyState >= 3);
}

View File

@ -7,8 +7,8 @@ export async function loadWeightMap(
uri: string | undefined,
defaultModelName: string,
): Promise<tf.NamedTensorMap> {
const { manifestUri, modelBaseUri } = getModelUris(uri, defaultModelName)
let manifest = await fetchJson<tf.io.WeightsManifestConfig>(manifestUri)
const { manifestUri, modelBaseUri } = getModelUris(uri, defaultModelName);
const manifest = await fetchJson<tf.io.WeightsManifestConfig>(manifestUri);
// if (manifest['weightsManifest']) manifest = manifest['weightsManifest'];
return tf.io.loadWeights(manifest, modelBaseUri)
return tf.io.loadWeights(manifest, modelBaseUri);
}

View File

@ -4,8 +4,8 @@ import { getMediaDimensions } from './getMediaDimensions';
export function matchDimensions(input: IDimensions, reference: IDimensions, useMediaDimensions: boolean = false) {
const { width, height } = useMediaDimensions
? getMediaDimensions(reference)
: reference
input.width = width
input.height = height
return { width, height }
: reference;
input.width = width;
input.height = height;
return { width, height };
}

View File

@ -2,7 +2,7 @@ import { env } from '../env/index';
export function resolveInput(arg: string | any) {
if (!env.isNodejs() && typeof arg === 'string') {
return document.getElementById(arg)
return document.getElementById(arg);
}
return arg
return arg;
}

View File

@ -14,44 +14,43 @@ import { TNetInput } from './types';
*/
export async function toNetInput(inputs: TNetInput): Promise<NetInput> {
if (inputs instanceof NetInput) {
return inputs
return inputs;
}
let inputArgArray = Array.isArray(inputs)
const inputArgArray = Array.isArray(inputs)
? inputs
: [inputs]
: [inputs];
if (!inputArgArray.length) {
throw new Error('toNetInput - empty array passed as input')
throw new Error('toNetInput - empty array passed as input');
}
const getIdxHint = (idx: number) => Array.isArray(inputs) ? ` at input index ${idx}:` : ''
const getIdxHint = (idx: number) => (Array.isArray(inputs) ? ` at input index ${idx}:` : '');
const inputArray = inputArgArray.map(resolveInput)
const inputArray = inputArgArray.map(resolveInput);
inputArray.forEach((input, i) => {
if (!isMediaElement(input) && !isTensor3D(input) && !isTensor4D(input)) {
if (typeof inputArgArray[i] === 'string') {
throw new Error(`toNetInput -${getIdxHint(i)} string passed, but could not resolve HTMLElement for element id ${inputArgArray[i]}`)
throw new Error(`toNetInput -${getIdxHint(i)} string passed, but could not resolve HTMLElement for element id ${inputArgArray[i]}`);
}
throw new Error(`toNetInput -${getIdxHint(i)} expected media to be of type HTMLImageElement | HTMLVideoElement | HTMLCanvasElement | tf.Tensor3D, or to be an element id`)
throw new Error(`toNetInput -${getIdxHint(i)} expected media to be of type HTMLImageElement | HTMLVideoElement | HTMLCanvasElement | tf.Tensor3D, or to be an element id`);
}
if (isTensor4D(input)) {
// if tf.Tensor4D is passed in the input array, the batch size has to be 1
const batchSize = input.shape[0]
const batchSize = input.shape[0];
if (batchSize !== 1) {
throw new Error(`toNetInput -${getIdxHint(i)} tf.Tensor4D with batchSize ${batchSize} passed, but not supported in input array`)
throw new Error(`toNetInput -${getIdxHint(i)} tf.Tensor4D with batchSize ${batchSize} passed, but not supported in input array`);
}
}
})
});
// wait for all media elements being loaded
await Promise.all(
inputArray.map(input => isMediaElement(input) && awaitMediaLoaded(input))
)
inputArray.map((input) => isMediaElement(input) && awaitMediaLoaded(input)),
);
return new NetInput(inputArray, Array.isArray(inputs))
return new NetInput(inputArray, Array.isArray(inputs));
}

View File

@ -1,6 +1,9 @@
/* eslint-disable max-classes-per-file */
import { Box, IBoundingBox, IRect } from '../classes/index';
import { getContext2dOrThrow } from '../dom/getContext2dOrThrow';
import { AnchorPosition, DrawTextField, DrawTextFieldOptions, IDrawTextFieldOptions } from './DrawTextField';
import {
AnchorPosition, DrawTextField, DrawTextFieldOptions, IDrawTextFieldOptions,
} from './DrawTextField';
export interface IDrawBoxOptions {
boxColor?: string
@ -11,49 +14,57 @@ export interface IDrawBoxOptions {
export class DrawBoxOptions {
public boxColor: string
public lineWidth: number
public drawLabelOptions: DrawTextFieldOptions
public label?: string
constructor(options: IDrawBoxOptions = {}) {
const { boxColor, lineWidth, label, drawLabelOptions } = options
this.boxColor = boxColor || 'rgba(0, 0, 255, 1)'
this.lineWidth = lineWidth || 2
this.label = label
const {
boxColor, lineWidth, label, drawLabelOptions,
} = options;
this.boxColor = boxColor || 'rgba(0, 0, 255, 1)';
this.lineWidth = lineWidth || 2;
this.label = label;
const defaultDrawLabelOptions = {
anchorPosition: AnchorPosition.BOTTOM_LEFT,
backgroundColor: this.boxColor
}
this.drawLabelOptions = new DrawTextFieldOptions(Object.assign({}, defaultDrawLabelOptions, drawLabelOptions))
backgroundColor: this.boxColor,
};
this.drawLabelOptions = new DrawTextFieldOptions({ ...defaultDrawLabelOptions, ...drawLabelOptions });
}
}
export class DrawBox {
public box: Box
public options: DrawBoxOptions
constructor(
box: IBoundingBox | IRect,
options: IDrawBoxOptions = {}
options: IDrawBoxOptions = {},
) {
this.box = new Box(box)
this.options = new DrawBoxOptions(options)
this.box = new Box(box);
this.options = new DrawBoxOptions(options);
}
draw(canvasArg: string | HTMLCanvasElement | CanvasRenderingContext2D) {
const ctx = getContext2dOrThrow(canvasArg)
const ctx = getContext2dOrThrow(canvasArg);
const { boxColor, lineWidth } = this.options
const { boxColor, lineWidth } = this.options;
const { x, y, width, height } = this.box
ctx.strokeStyle = boxColor
ctx.lineWidth = lineWidth
ctx.strokeRect(x, y, width, height)
const {
x, y, width, height,
} = this.box;
ctx.strokeStyle = boxColor;
ctx.lineWidth = lineWidth;
ctx.strokeRect(x, y, width, height);
const { label } = this.options
const { label } = this.options;
if (label) {
new DrawTextField([label], { x: x - (lineWidth / 2), y }, this.options.drawLabelOptions).draw(canvasArg)
new DrawTextField([label], { x: x - (lineWidth / 2), y }, this.options.drawLabelOptions).draw(canvasArg);
}
}
}

View File

@ -1,3 +1,4 @@
/* eslint-disable max-classes-per-file */
import { IPoint } from '../classes/index';
import { FaceLandmarks } from '../classes/FaceLandmarks';
import { FaceLandmarks68 } from '../classes/FaceLandmarks68';
@ -17,62 +18,72 @@ export interface IDrawFaceLandmarksOptions {
export class DrawFaceLandmarksOptions {
public drawLines: boolean
public drawPoints: boolean
public lineWidth: number
public pointSize: number
public lineColor: string
public pointColor: string
constructor(options: IDrawFaceLandmarksOptions = {}) {
const { drawLines = true, drawPoints = true, lineWidth, lineColor, pointSize, pointColor } = options
this.drawLines = drawLines
this.drawPoints = drawPoints
this.lineWidth = lineWidth || 1
this.pointSize = pointSize || 2
this.lineColor = lineColor || 'rgba(0, 255, 255, 1)'
this.pointColor = pointColor || 'rgba(255, 0, 255, 1)'
const {
drawLines = true, drawPoints = true, lineWidth, lineColor, pointSize, pointColor,
} = options;
this.drawLines = drawLines;
this.drawPoints = drawPoints;
this.lineWidth = lineWidth || 1;
this.pointSize = pointSize || 2;
this.lineColor = lineColor || 'rgba(0, 255, 255, 1)';
this.pointColor = pointColor || 'rgba(255, 0, 255, 1)';
}
}
export class DrawFaceLandmarks {
public faceLandmarks: FaceLandmarks
public options: DrawFaceLandmarksOptions
constructor(
faceLandmarks: FaceLandmarks,
options: IDrawFaceLandmarksOptions = {}
options: IDrawFaceLandmarksOptions = {},
) {
this.faceLandmarks = faceLandmarks
this.options = new DrawFaceLandmarksOptions(options)
this.faceLandmarks = faceLandmarks;
this.options = new DrawFaceLandmarksOptions(options);
}
draw(canvasArg: string | HTMLCanvasElement | CanvasRenderingContext2D) {
const ctx = getContext2dOrThrow(canvasArg)
const ctx = getContext2dOrThrow(canvasArg);
const { drawLines, drawPoints, lineWidth, lineColor, pointSize, pointColor } = this.options
const {
drawLines, drawPoints, lineWidth, lineColor, pointSize, pointColor,
} = this.options;
if (drawLines && this.faceLandmarks instanceof FaceLandmarks68) {
ctx.strokeStyle = lineColor
ctx.lineWidth = lineWidth
drawContour(ctx, this.faceLandmarks.getJawOutline())
drawContour(ctx, this.faceLandmarks.getLeftEyeBrow())
drawContour(ctx, this.faceLandmarks.getRightEyeBrow())
drawContour(ctx, this.faceLandmarks.getNose())
drawContour(ctx, this.faceLandmarks.getLeftEye(), true)
drawContour(ctx, this.faceLandmarks.getRightEye(), true)
drawContour(ctx, this.faceLandmarks.getMouth(), true)
ctx.strokeStyle = lineColor;
ctx.lineWidth = lineWidth;
drawContour(ctx, this.faceLandmarks.getJawOutline());
drawContour(ctx, this.faceLandmarks.getLeftEyeBrow());
drawContour(ctx, this.faceLandmarks.getRightEyeBrow());
drawContour(ctx, this.faceLandmarks.getNose());
drawContour(ctx, this.faceLandmarks.getLeftEye(), true);
drawContour(ctx, this.faceLandmarks.getRightEye(), true);
drawContour(ctx, this.faceLandmarks.getMouth(), true);
}
if (drawPoints) {
ctx.strokeStyle = pointColor
ctx.fillStyle = pointColor
ctx.strokeStyle = pointColor;
ctx.fillStyle = pointColor;
const drawPoint = (pt: IPoint) => {
ctx.beginPath()
ctx.arc(pt.x, pt.y, pointSize, 0, 2 * Math.PI)
ctx.fill()
}
this.faceLandmarks.positions.forEach(drawPoint)
ctx.beginPath();
ctx.arc(pt.x, pt.y, pointSize, 0, 2 * Math.PI);
ctx.fill();
};
this.faceLandmarks.positions.forEach(drawPoint);
}
}
}
@ -81,17 +92,18 @@ export type DrawFaceLandmarksInput = FaceLandmarks | WithFaceLandmarks<WithFaceD
export function drawFaceLandmarks(
canvasArg: string | HTMLCanvasElement,
faceLandmarks: DrawFaceLandmarksInput | Array<DrawFaceLandmarksInput>
faceLandmarks: DrawFaceLandmarksInput | Array<DrawFaceLandmarksInput>,
) {
const faceLandmarksArray = Array.isArray(faceLandmarks) ? faceLandmarks : [faceLandmarks]
faceLandmarksArray.forEach(f => {
const faceLandmarksArray = Array.isArray(faceLandmarks) ? faceLandmarks : [faceLandmarks];
faceLandmarksArray.forEach((f) => {
// eslint-disable-next-line no-nested-ternary
const landmarks = f instanceof FaceLandmarks
? f
: (isWithFaceLandmarks(f) ? f.landmarks : undefined)
: (isWithFaceLandmarks(f) ? f.landmarks : undefined);
if (!landmarks) {
throw new Error('drawFaceLandmarks - expected faceExpressions to be FaceLandmarks | WithFaceLandmarks<WithFaceDetection<{}>> or array thereof')
throw new Error('drawFaceLandmarks - expected faceExpressions to be FaceLandmarks | WithFaceLandmarks<WithFaceDetection<{}>> or array thereof');
}
new DrawFaceLandmarks(landmarks).draw(canvasArg)
})
new DrawFaceLandmarks(landmarks).draw(canvasArg);
});
}

View File

@ -1,11 +1,17 @@
/* eslint-disable max-classes-per-file */
import { IDimensions, IPoint } from '../classes/index';
import { getContext2dOrThrow } from '../dom/getContext2dOrThrow';
import { resolveInput } from '../dom/resolveInput';
// eslint-disable-next-line no-shadow
export enum AnchorPosition {
// eslint-disable-next-line no-unused-vars
TOP_LEFT = 'TOP_LEFT',
// eslint-disable-next-line no-unused-vars
TOP_RIGHT = 'TOP_RIGHT',
// eslint-disable-next-line no-unused-vars
BOTTOM_LEFT = 'BOTTOM_LEFT',
// eslint-disable-next-line no-unused-vars
BOTTOM_RIGHT = 'BOTTOM_RIGHT'
}
@ -20,89 +26,101 @@ export interface IDrawTextFieldOptions {
export class DrawTextFieldOptions implements IDrawTextFieldOptions {
public anchorPosition: AnchorPosition
public backgroundColor: string
public fontColor: string
public fontSize: number
public fontStyle: string
public padding: number
constructor(options: IDrawTextFieldOptions = {}) {
const { anchorPosition, backgroundColor, fontColor, fontSize, fontStyle, padding } = options
this.anchorPosition = anchorPosition || AnchorPosition.TOP_LEFT
this.backgroundColor = backgroundColor || 'rgba(0, 0, 0, 0.5)'
this.fontColor = fontColor || 'rgba(255, 255, 255, 1)'
this.fontSize = fontSize || 14
this.fontStyle = fontStyle || 'Georgia'
this.padding = padding || 4
const {
anchorPosition, backgroundColor, fontColor, fontSize, fontStyle, padding,
} = options;
this.anchorPosition = anchorPosition || AnchorPosition.TOP_LEFT;
this.backgroundColor = backgroundColor || 'rgba(0, 0, 0, 0.5)';
this.fontColor = fontColor || 'rgba(255, 255, 255, 1)';
this.fontSize = fontSize || 14;
this.fontStyle = fontStyle || 'Georgia';
this.padding = padding || 4;
}
}
export class DrawTextField {
public text: string[]
public anchor : IPoint
public options: DrawTextFieldOptions
constructor(
text: string | string[] | DrawTextField,
anchor: IPoint,
options: IDrawTextFieldOptions = {}
options: IDrawTextFieldOptions = {},
) {
// eslint-disable-next-line no-nested-ternary
this.text = typeof text === 'string'
? [text]
: (text instanceof DrawTextField ? text.text : text)
this.anchor = anchor
this.options = new DrawTextFieldOptions(options)
: (text instanceof DrawTextField ? text.text : text);
this.anchor = anchor;
this.options = new DrawTextFieldOptions(options);
}
measureWidth(ctx: CanvasRenderingContext2D): number {
const { padding } = this.options
return this.text.map(l => ctx.measureText(l).width).reduce((w0, w1) => w0 < w1 ? w1 : w0, 0) + (2 * padding)
const { padding } = this.options;
return this.text.map((l) => ctx.measureText(l).width).reduce((w0, w1) => (w0 < w1 ? w1 : w0), 0) + (2 * padding);
}
measureHeight(): number {
const { fontSize, padding } = this.options
return this.text.length * fontSize + (2 * padding)
const { fontSize, padding } = this.options;
return this.text.length * fontSize + (2 * padding);
}
getUpperLeft(ctx: CanvasRenderingContext2D, canvasDims?: IDimensions): IPoint {
const { anchorPosition } = this.options
const isShiftLeft = anchorPosition === AnchorPosition.BOTTOM_RIGHT || anchorPosition === AnchorPosition.TOP_RIGHT
const isShiftTop = anchorPosition === AnchorPosition.BOTTOM_LEFT || anchorPosition === AnchorPosition.BOTTOM_RIGHT
const { anchorPosition } = this.options;
const isShiftLeft = anchorPosition === AnchorPosition.BOTTOM_RIGHT || anchorPosition === AnchorPosition.TOP_RIGHT;
const isShiftTop = anchorPosition === AnchorPosition.BOTTOM_LEFT || anchorPosition === AnchorPosition.BOTTOM_RIGHT;
const textFieldWidth = this.measureWidth(ctx)
const textFieldHeight = this.measureHeight()
const x = (isShiftLeft ? this.anchor.x - textFieldWidth : this.anchor.x)
const y = isShiftTop ? this.anchor.y - textFieldHeight : this.anchor.y
const textFieldWidth = this.measureWidth(ctx);
const textFieldHeight = this.measureHeight();
const x = (isShiftLeft ? this.anchor.x - textFieldWidth : this.anchor.x);
const y = isShiftTop ? this.anchor.y - textFieldHeight : this.anchor.y;
// adjust anchor if text box exceeds canvas borders
if (canvasDims) {
const { width, height } = canvasDims
const newX = Math.max(Math.min(x, width - textFieldWidth), 0)
const newY = Math.max(Math.min(y, height - textFieldHeight), 0)
return { x: newX, y: newY }
const { width, height } = canvasDims;
const newX = Math.max(Math.min(x, width - textFieldWidth), 0);
const newY = Math.max(Math.min(y, height - textFieldHeight), 0);
return { x: newX, y: newY };
}
return { x, y }
return { x, y };
}
draw(canvasArg: string | HTMLCanvasElement | CanvasRenderingContext2D) {
const canvas = resolveInput(canvasArg)
const ctx = getContext2dOrThrow(canvas)
const canvas = resolveInput(canvasArg);
const ctx = getContext2dOrThrow(canvas);
const { backgroundColor, fontColor, fontSize, fontStyle, padding } = this.options
const {
backgroundColor, fontColor, fontSize, fontStyle, padding,
} = this.options;
ctx.font = `${fontSize}px ${fontStyle}`
const maxTextWidth = this.measureWidth(ctx)
const textHeight = this.measureHeight()
ctx.font = `${fontSize}px ${fontStyle}`;
const maxTextWidth = this.measureWidth(ctx);
const textHeight = this.measureHeight();
ctx.fillStyle = backgroundColor
const upperLeft = this.getUpperLeft(ctx, canvas)
ctx.fillRect(upperLeft.x, upperLeft.y, maxTextWidth, textHeight)
ctx.fillStyle = backgroundColor;
const upperLeft = this.getUpperLeft(ctx, canvas);
ctx.fillRect(upperLeft.x, upperLeft.y, maxTextWidth, textHeight);
ctx.fillStyle = fontColor;
this.text.forEach((textLine, i) => {
const x = padding + upperLeft.x
const y = padding + upperLeft.y + ((i + 1) * fontSize)
ctx.fillText(textLine, x, y)
})
const x = padding + upperLeft.x;
const y = padding + upperLeft.y + ((i + 1) * fontSize);
ctx.fillText(textLine, x, y);
});
}
}

View File

@ -3,26 +3,26 @@ import { Point } from '../classes/index';
export function drawContour(
ctx: CanvasRenderingContext2D,
points: Point[],
isClosed: boolean = false
isClosed: boolean = false,
) {
ctx.beginPath()
ctx.beginPath();
points.slice(1).forEach(({ x, y }, prevIdx) => {
const from = points[prevIdx]
ctx.moveTo(from.x, from.y)
ctx.lineTo(x, y)
})
const from = points[prevIdx];
ctx.moveTo(from.x, from.y);
ctx.lineTo(x, y);
});
if (isClosed) {
const from = points[points.length - 1]
const to = points[0]
const from = points[points.length - 1];
const to = points[0];
if (!from || !to) {
return
return;
}
ctx.moveTo(from.x, from.y)
ctx.lineTo(to.x, to.y)
ctx.moveTo(from.x, from.y);
ctx.lineTo(to.x, to.y);
}
ctx.stroke()
ctx.stroke();
}

View File

@ -8,20 +8,22 @@ export type TDrawDetectionsInput = IRect | IBoundingBox | FaceDetection | WithFa
export function drawDetections(
canvasArg: string | HTMLCanvasElement,
detections: TDrawDetectionsInput | Array<TDrawDetectionsInput>
detections: TDrawDetectionsInput | Array<TDrawDetectionsInput>,
) {
const detectionsArray = Array.isArray(detections) ? detections : [detections]
const detectionsArray = Array.isArray(detections) ? detections : [detections];
detectionsArray.forEach(det => {
detectionsArray.forEach((det) => {
// eslint-disable-next-line no-nested-ternary
const score = det instanceof FaceDetection
? det.score
: (isWithFaceDetection(det) ? det.detection.score : undefined)
: (isWithFaceDetection(det) ? det.detection.score : undefined);
// eslint-disable-next-line no-nested-ternary
const box = det instanceof FaceDetection
? det.box
: (isWithFaceDetection(det) ? det.detection.box : new Box(det))
: (isWithFaceDetection(det) ? det.detection.box : new Box(det));
const label = score ? `${round(score)}` : undefined
new DrawBox(box, { label }).draw(canvasArg)
})
const label = score ? `${round(score)}` : undefined;
new DrawBox(box, { label }).draw(canvasArg);
});
}

View File

@ -11,29 +11,30 @@ export function drawFaceExpressions(
canvasArg: string | HTMLCanvasElement,
faceExpressions: DrawFaceExpressionsInput | Array<DrawFaceExpressionsInput>,
minConfidence = 0.1,
textFieldAnchor?: IPoint
textFieldAnchor?: IPoint,
) {
const faceExpressionsArray = Array.isArray(faceExpressions) ? faceExpressions : [faceExpressions]
const faceExpressionsArray = Array.isArray(faceExpressions) ? faceExpressions : [faceExpressions];
faceExpressionsArray.forEach(e => {
faceExpressionsArray.forEach((e) => {
// eslint-disable-next-line no-nested-ternary
const expr = e instanceof FaceExpressions
? e
: (isWithFaceExpressions(e) ? e.expressions : undefined)
: (isWithFaceExpressions(e) ? e.expressions : undefined);
if (!expr) {
throw new Error('drawFaceExpressions - expected faceExpressions to be FaceExpressions | WithFaceExpressions<{}> or array thereof')
throw new Error('drawFaceExpressions - expected faceExpressions to be FaceExpressions | WithFaceExpressions<{}> or array thereof');
}
const sorted = expr.asSortedArray()
const resultsToDisplay = sorted.filter(expr => expr.probability > minConfidence)
const sorted = expr.asSortedArray();
const resultsToDisplay = sorted.filter((exprLocal) => exprLocal.probability > minConfidence);
const anchor = isWithFaceDetection(e)
? e.detection.box.bottomLeft
: (textFieldAnchor || new Point(0, 0))
: (textFieldAnchor || new Point(0, 0));
const drawTextField = new DrawTextField(
resultsToDisplay.map(expr => `${expr.expression} (${round(expr.probability)})`),
anchor
)
drawTextField.draw(canvasArg)
})
resultsToDisplay.map((exprLocal) => `${exprLocal.expression} (${round(exprLocal.probability)})`),
anchor,
);
drawTextField.draw(canvasArg);
});
}

View File

@ -1,6 +1,6 @@
export * from './drawContour'
export * from './drawDetections'
export * from './drawFaceExpressions'
export * from './DrawBox'
export * from './DrawFaceLandmarks'
export * from './DrawTextField'
export * from './drawContour';
export * from './drawDetections';
export * from './drawFaceExpressions';
export * from './DrawBox';
export * from './DrawFaceLandmarks';
export * from './DrawTextField';

View File

@ -1,24 +1,22 @@
import { Environment } from './types';
export function createBrowserEnv(): Environment {
const fetch = window.fetch;
if (!fetch) throw new Error('fetch - missing fetch implementation for browser environment');
const fetch = window['fetch'] || function() {
throw new Error('fetch - missing fetch implementation for browser environment')
}
const readFile = function() {
throw new Error('readFile - filesystem not available for browser environment')
}
const readFile = () => {
throw new Error('readFile - filesystem not available for browser environment');
};
return {
Canvas: HTMLCanvasElement,
CanvasRenderingContext2D: CanvasRenderingContext2D,
CanvasRenderingContext2D,
Image: HTMLImageElement,
ImageData: ImageData,
ImageData,
Video: HTMLVideoElement,
createCanvasElement: () => document.createElement('canvas'),
createImageElement: () => document.createElement('img'),
fetch,
readFile
}
readFile,
};
}

View File

@ -1,30 +1,26 @@
import { FileSystem } from './types';
export function createFileSystem(fs?: any): FileSystem {
let requireFsError = ''
let requireFsError = '';
if (!fs) {
try {
fs = require('fs')
// eslint-disable-next-line global-require
fs = require('fs');
} catch (err) {
requireFsError = err.toString()
requireFsError = err.toString();
}
}
const readFile = fs
? function(filePath: string) {
return new Promise<Buffer>((res, rej) => {
fs.readFile(filePath, function(err: any, buffer: Buffer) {
return err ? rej(err) : res(buffer)
? (filePath: string) => new Promise<Buffer>((resolve, reject) => {
fs.readFile(filePath, (err: any, buffer: Buffer) => (err ? reject(err) : resolve(buffer)));
})
})
}
: function() {
throw new Error(`readFile - failed to require fs in nodejs environment with error: ${requireFsError}`)
}
: () => {
throw new Error(`readFile - failed to require fs in nodejs environment with error: ${requireFsError}`);
};
return {
readFile
}
readFile,
};
}

View File

@ -1,40 +1,36 @@
/* eslint-disable max-classes-per-file */
import { createFileSystem } from './createFileSystem';
import { Environment } from './types';
export function createNodejsEnv(): Environment {
// eslint-disable-next-line dot-notation
const Canvas = global['Canvas'] || global.HTMLCanvasElement;
const Image = global.Image || global.HTMLImageElement;
const Canvas = global['Canvas'] || global['HTMLCanvasElement']
const Image = global['Image'] || global['HTMLImageElement']
const createCanvasElement = () => {
if (Canvas) return new Canvas();
throw new Error('createCanvasElement - missing Canvas implementation for nodejs environment');
};
const createCanvasElement = function() {
if (Canvas) {
return new Canvas()
}
throw new Error('createCanvasElement - missing Canvas implementation for nodejs environment')
}
const createImageElement = () => {
if (Image) return new Image();
throw new Error('createImageElement - missing Image implementation for nodejs environment');
};
const createImageElement = function() {
if (Image) {
return new Image()
}
throw new Error('createImageElement - missing Image implementation for nodejs environment')
}
const fetch = global.fetch;
// if (!fetch) throw new Error('fetch - missing fetch implementation for nodejs environment');
const fetch = global['fetch'] || function() {
throw new Error('fetch - missing fetch implementation for nodejs environment')
}
const fileSystem = createFileSystem()
const fileSystem = createFileSystem();
return {
Canvas: Canvas || class {},
CanvasRenderingContext2D: global['CanvasRenderingContext2D'] || class {},
CanvasRenderingContext2D: global.CanvasRenderingContext2D || class {},
Image: Image || class {},
ImageData: global['ImageData'] || class {},
Video: global['HTMLVideoElement'] || class {},
ImageData: global.ImageData || class {},
Video: global.HTMLVideoElement || class {},
createCanvasElement,
createImageElement,
fetch,
...fileSystem
}
...fileSystem,
};
}

47
src/env/index.ts vendored
View File

@ -5,49 +5,46 @@ import { isBrowser } from './isBrowser';
import { isNodejs } from './isNodejs';
import { Environment } from './types';
let environment: Environment | null
let environment: Environment | null;
function getEnv(): Environment {
if (!environment) {
throw new Error('getEnv - environment is not defined, check isNodejs() and isBrowser()')
throw new Error('getEnv - environment is not defined, check isNodejs() and isBrowser()');
}
return environment
return environment;
}
function setEnv(env: Environment) {
environment = env
environment = env;
}
function initialize() {
// check for isBrowser() first to prevent electron renderer process
// to be initialized with wrong environment due to isNodejs() returning true
if (isBrowser()) {
return setEnv(createBrowserEnv())
}
if (isNodejs()) {
return setEnv(createNodejsEnv())
}
if (isBrowser()) return setEnv(createBrowserEnv());
if (isNodejs()) return setEnv(createNodejsEnv());
return null;
}
function monkeyPatch(env: Partial<Environment>) {
if (!environment) {
initialize()
initialize();
}
if (!environment) {
throw new Error('monkeyPatch - environment is not defined, check isNodejs() and isBrowser()')
throw new Error('monkeyPatch - environment is not defined, check isNodejs() and isBrowser()');
}
const { Canvas = environment.Canvas, Image = environment.Image } = env
environment.Canvas = Canvas
environment.Image = Image
environment.createCanvasElement = env.createCanvasElement || (() => new Canvas())
environment.createImageElement = env.createImageElement || (() => new Image())
const { Canvas = environment.Canvas, Image = environment.Image } = env;
environment.Canvas = Canvas;
environment.Image = Image;
environment.createCanvasElement = env.createCanvasElement || (() => new Canvas());
environment.createImageElement = env.createImageElement || (() => new Image());
environment.ImageData = env.ImageData || environment.ImageData
environment.Video = env.Video || environment.Video
environment.fetch = env.fetch || environment.fetch
environment.readFile = env.readFile || environment.readFile
environment.ImageData = env.ImageData || environment.ImageData;
environment.Video = env.Video || environment.Video;
environment.fetch = env.fetch || environment.fetch;
environment.readFile = env.readFile || environment.readFile;
}
export const env = {
@ -59,9 +56,9 @@ export const env = {
createNodejsEnv,
monkeyPatch,
isBrowser,
isNodejs
}
isNodejs,
};
initialize()
initialize();
export * from './types'
export * from './types';

View File

@ -5,5 +5,5 @@ export function isBrowser(): boolean {
&& typeof HTMLCanvasElement !== 'undefined'
&& typeof HTMLVideoElement !== 'undefined'
&& typeof ImageData !== 'undefined'
&& typeof CanvasRenderingContext2D !== 'undefined'
&& typeof CanvasRenderingContext2D !== 'undefined';
}

2
src/env/isNodejs.ts vendored
View File

@ -4,5 +4,5 @@ export function isNodejs(): boolean {
&& typeof module !== 'undefined'
// issues with gatsby.js: module.exports is undefined
// && !!module.exports
&& typeof process !== 'undefined' && !!process.version
&& typeof process !== 'undefined' && !!process.version;
}

2
src/env/types.ts vendored
View File

@ -1,4 +1,5 @@
export type FileSystem = {
// eslint-disable-next-line no-unused-vars
readFile: (filePath: string) => Promise<Buffer>
}
@ -10,5 +11,6 @@ export type Environment = FileSystem & {
Video: typeof HTMLVideoElement
createCanvasElement: () => HTMLCanvasElement
createImageElement: () => HTMLImageElement
// eslint-disable-next-line no-undef, no-unused-vars
fetch: (url: string, init?: RequestInit) => Promise<Response>
}

View File

@ -7,46 +7,45 @@ import { FaceProcessor } from '../faceProcessor/FaceProcessor';
import { FaceExpressions } from './FaceExpressions';
export class FaceExpressionNet extends FaceProcessor<FaceFeatureExtractorParams> {
constructor(faceFeatureExtractor: FaceFeatureExtractor = new FaceFeatureExtractor()) {
super('FaceExpressionNet', faceFeatureExtractor)
super('FaceExpressionNet', faceFeatureExtractor);
}
public forwardInput(input: NetInput | tf.Tensor4D): tf.Tensor2D {
return tf.tidy(() => tf.softmax(this.runNet(input)))
return tf.tidy(() => tf.softmax(this.runNet(input)));
}
public async forward(input: TNetInput): Promise<tf.Tensor2D> {
return this.forwardInput(await toNetInput(input))
return this.forwardInput(await toNetInput(input));
}
public async predictExpressions(input: TNetInput) {
const netInput = await toNetInput(input)
const out = await this.forwardInput(netInput)
const probabilitesByBatch = await Promise.all(tf.unstack(out).map(async t => {
const data = await t.data()
t.dispose()
return data
}))
out.dispose()
const netInput = await toNetInput(input);
const out = await this.forwardInput(netInput);
const probabilitesByBatch = await Promise.all(tf.unstack(out).map(async (t) => {
const data = await t.data();
t.dispose();
return data;
}));
out.dispose();
const predictionsByBatch = probabilitesByBatch
.map(probabilites => new FaceExpressions(probabilites as Float32Array))
.map((probabilites) => new FaceExpressions(probabilites as Float32Array));
return netInput.isBatchInput
? predictionsByBatch
: predictionsByBatch[0]
: predictionsByBatch[0];
}
protected getDefaultModelName(): string {
return 'face_expression_model'
return 'face_expression_model';
}
protected getClassifierChannelsIn(): number {
return 256
return 256;
}
protected getClassifierChannelsOut(): number {
return 7
return 7;
}
}

View File

@ -1,27 +1,33 @@
export const FACE_EXPRESSION_LABELS = ['neutral', 'happy', 'sad', 'angry', 'fearful', 'disgusted', 'surprised']
export const FACE_EXPRESSION_LABELS = ['neutral', 'happy', 'sad', 'angry', 'fearful', 'disgusted', 'surprised'];
export class FaceExpressions {
public neutral: number
public happy: number
public sad: number
public angry: number
public fearful: number
public disgusted: number
public surprised: number
constructor(probabilities: number[] | Float32Array) {
if (probabilities.length !== 7) {
throw new Error(`FaceExpressions.constructor - expected probabilities.length to be 7, have: ${probabilities.length}`)
throw new Error(`FaceExpressions.constructor - expected probabilities.length to be 7, have: ${probabilities.length}`);
}
FACE_EXPRESSION_LABELS.forEach((expression, idx) => {
this[expression] = probabilities[idx]
})
this[expression] = probabilities[idx];
});
}
asSortedArray() {
return FACE_EXPRESSION_LABELS
.map(expression => ({ expression, probability: this[expression] as number }))
.sort((e0, e1) => e1.probability - e0.probability)
.map((expression) => ({ expression, probability: this[expression] as number }))
.sort((e0, e1) => e1.probability - e0.probability);
}
}

View File

@ -9,47 +9,45 @@ import { extractParamsFromWeigthMap } from './extractParamsFromWeigthMap';
import { FaceFeatureExtractorParams, IFaceFeatureExtractor } from './types';
export class FaceFeatureExtractor extends NeuralNetwork<FaceFeatureExtractorParams> implements IFaceFeatureExtractor<FaceFeatureExtractorParams> {
constructor() {
super('FaceFeatureExtractor')
super('FaceFeatureExtractor');
}
public forwardInput(input: NetInput): tf.Tensor4D {
const { params } = this
const { params } = this;
if (!params) {
throw new Error('FaceFeatureExtractor - load model before inference')
throw new Error('FaceFeatureExtractor - load model before inference');
}
return tf.tidy(() => {
const batchTensor = tf.cast(input.toBatchTensor(112, true), 'float32');
const meanRgb = [122.782, 117.001, 104.298]
const normalized = normalize(batchTensor, meanRgb).div(tf.scalar(255)) as tf.Tensor4D
const meanRgb = [122.782, 117.001, 104.298];
const normalized = normalize(batchTensor, meanRgb).div(tf.scalar(255)) as tf.Tensor4D;
let out = denseBlock4(normalized, params.dense0, true)
out = denseBlock4(out, params.dense1)
out = denseBlock4(out, params.dense2)
out = denseBlock4(out, params.dense3)
out = tf.avgPool(out, [7, 7], [2, 2], 'valid')
let out = denseBlock4(normalized, params.dense0, true);
out = denseBlock4(out, params.dense1);
out = denseBlock4(out, params.dense2);
out = denseBlock4(out, params.dense3);
out = tf.avgPool(out, [7, 7], [2, 2], 'valid');
return out
})
return out;
});
}
public async forward(input: TNetInput): Promise<tf.Tensor4D> {
return this.forwardInput(await toNetInput(input))
return this.forwardInput(await toNetInput(input));
}
protected getDefaultModelName(): string {
return 'face_feature_extractor_model'
return 'face_feature_extractor_model';
}
protected extractParamsFromWeigthMap(weightMap: tf.NamedTensorMap) {
return extractParamsFromWeigthMap(weightMap)
return extractParamsFromWeigthMap(weightMap);
}
protected extractParams(weights: Float32Array) {
return extractParams(weights)
return extractParams(weights);
}
}

View File

@ -9,46 +9,44 @@ import { extractParamsTiny } from './extractParamsTiny';
import { IFaceFeatureExtractor, TinyFaceFeatureExtractorParams } from './types';
export class TinyFaceFeatureExtractor extends NeuralNetwork<TinyFaceFeatureExtractorParams> implements IFaceFeatureExtractor<TinyFaceFeatureExtractorParams> {
constructor() {
super('TinyFaceFeatureExtractor')
super('TinyFaceFeatureExtractor');
}
public forwardInput(input: NetInput): tf.Tensor4D {
const { params } = this
const { params } = this;
if (!params) {
throw new Error('TinyFaceFeatureExtractor - load model before inference')
throw new Error('TinyFaceFeatureExtractor - load model before inference');
}
return tf.tidy(() => {
const batchTensor = tf.cast(input.toBatchTensor(112, true), 'float32');
const meanRgb = [122.782, 117.001, 104.298]
const normalized = normalize(batchTensor, meanRgb).div(tf.scalar(255)) as tf.Tensor4D
const meanRgb = [122.782, 117.001, 104.298];
const normalized = normalize(batchTensor, meanRgb).div(tf.scalar(255)) as tf.Tensor4D;
let out = denseBlock3(normalized, params.dense0, true)
out = denseBlock3(out, params.dense1)
out = denseBlock3(out, params.dense2)
out = tf.avgPool(out, [14, 14], [2, 2], 'valid')
let out = denseBlock3(normalized, params.dense0, true);
out = denseBlock3(out, params.dense1);
out = denseBlock3(out, params.dense2);
out = tf.avgPool(out, [14, 14], [2, 2], 'valid');
return out
})
return out;
});
}
public async forward(input: TNetInput): Promise<tf.Tensor4D> {
return this.forwardInput(await toNetInput(input))
return this.forwardInput(await toNetInput(input));
}
protected getDefaultModelName(): string {
return 'face_feature_extractor_tiny_model'
return 'face_feature_extractor_tiny_model';
}
protected extractParamsFromWeigthMap(weightMap: tf.NamedTensorMap) {
return extractParamsFromWeigthMapTiny(weightMap)
return extractParamsFromWeigthMapTiny(weightMap);
}
protected extractParams(weights: Float32Array) {
return extractParamsTiny(weights)
return extractParamsTiny(weights);
}
}

View File

@ -7,49 +7,49 @@ import { DenseBlock3Params, DenseBlock4Params } from './types';
export function denseBlock3(
x: tf.Tensor4D,
denseBlockParams: DenseBlock3Params,
isFirstLayer: boolean = false
isFirstLayer: boolean = false,
): tf.Tensor4D {
return tf.tidy(() => {
const out1 = tf.relu(
isFirstLayer
? tf.add(
tf.conv2d(x, (denseBlockParams.conv0 as ConvParams).filters, [2, 2], 'same'),
denseBlockParams.conv0.bias
denseBlockParams.conv0.bias,
)
: depthwiseSeparableConv(x, denseBlockParams.conv0 as SeparableConvParams, [2, 2])
) as tf.Tensor4D
const out2 = depthwiseSeparableConv(out1, denseBlockParams.conv1, [1, 1])
: depthwiseSeparableConv(x, denseBlockParams.conv0 as SeparableConvParams, [2, 2]),
) as tf.Tensor4D;
const out2 = depthwiseSeparableConv(out1, denseBlockParams.conv1, [1, 1]);
const in3 = tf.relu(tf.add(out1, out2)) as tf.Tensor4D
const out3 = depthwiseSeparableConv(in3, denseBlockParams.conv2, [1, 1])
const in3 = tf.relu(tf.add(out1, out2)) as tf.Tensor4D;
const out3 = depthwiseSeparableConv(in3, denseBlockParams.conv2, [1, 1]);
return tf.relu(tf.add(out1, tf.add(out2, out3))) as tf.Tensor4D
})
return tf.relu(tf.add(out1, tf.add(out2, out3))) as tf.Tensor4D;
});
}
export function denseBlock4(
x: tf.Tensor4D,
denseBlockParams: DenseBlock4Params,
isFirstLayer: boolean = false,
isScaleDown: boolean = true
isScaleDown: boolean = true,
): tf.Tensor4D {
return tf.tidy(() => {
const out1 = tf.relu(
isFirstLayer
? tf.add(
tf.conv2d(x, (denseBlockParams.conv0 as ConvParams).filters, isScaleDown ? [2, 2] : [1, 1], 'same'),
denseBlockParams.conv0.bias
denseBlockParams.conv0.bias,
)
: depthwiseSeparableConv(x, denseBlockParams.conv0 as SeparableConvParams, isScaleDown ? [2, 2] : [1, 1])
) as tf.Tensor4D
const out2 = depthwiseSeparableConv(out1, denseBlockParams.conv1, [1, 1])
: depthwiseSeparableConv(x, denseBlockParams.conv0 as SeparableConvParams, isScaleDown ? [2, 2] : [1, 1]),
) as tf.Tensor4D;
const out2 = depthwiseSeparableConv(out1, denseBlockParams.conv1, [1, 1]);
const in3 = tf.relu(tf.add(out1, out2)) as tf.Tensor4D
const out3 = depthwiseSeparableConv(in3, denseBlockParams.conv2, [1, 1])
const in3 = tf.relu(tf.add(out1, out2)) as tf.Tensor4D;
const out3 = depthwiseSeparableConv(in3, denseBlockParams.conv2, [1, 1]);
const in4 = tf.relu(tf.add(out1, tf.add(out2, out3))) as tf.Tensor4D
const out4 = depthwiseSeparableConv(in4, denseBlockParams.conv3, [1, 1])
const in4 = tf.relu(tf.add(out1, tf.add(out2, out3))) as tf.Tensor4D;
const out4 = depthwiseSeparableConv(in4, denseBlockParams.conv3, [1, 1]);
return tf.relu(tf.add(out1, tf.add(out2, tf.add(out3, out4)))) as tf.Tensor4D
})
return tf.relu(tf.add(out1, tf.add(out2, tf.add(out3, out4)))) as tf.Tensor4D;
});
}

View File

@ -2,31 +2,31 @@ import { extractWeightsFactory, ParamMapping } from '../common/index';
import { extractorsFactory } from './extractorsFactory';
import { FaceFeatureExtractorParams } from './types';
export function extractParams(weights: Float32Array): { params: FaceFeatureExtractorParams, paramMappings: ParamMapping[] } {
const paramMappings: ParamMapping[] = []
const paramMappings: ParamMapping[] = [];
const {
extractWeights,
getRemainingWeights
} = extractWeightsFactory(weights)
getRemainingWeights,
} = extractWeightsFactory(weights);
const {
extractDenseBlock4Params
} = extractorsFactory(extractWeights, paramMappings)
extractDenseBlock4Params,
} = extractorsFactory(extractWeights, paramMappings);
const dense0 = extractDenseBlock4Params(3, 32, 'dense0', true)
const dense1 = extractDenseBlock4Params(32, 64, 'dense1')
const dense2 = extractDenseBlock4Params(64, 128, 'dense2')
const dense3 = extractDenseBlock4Params(128, 256, 'dense3')
const dense0 = extractDenseBlock4Params(3, 32, 'dense0', true);
const dense1 = extractDenseBlock4Params(32, 64, 'dense1');
const dense2 = extractDenseBlock4Params(64, 128, 'dense2');
const dense3 = extractDenseBlock4Params(128, 256, 'dense3');
if (getRemainingWeights().length !== 0) {
throw new Error(`weights remaing after extract: ${getRemainingWeights().length}`)
throw new Error(`weights remaing after extract: ${getRemainingWeights().length}`);
}
return {
paramMappings,
params: { dense0, dense1, dense2, dense3 }
}
params: {
dense0, dense1, dense2, dense3,
},
};
}

View File

@ -5,23 +5,22 @@ import { loadParamsFactory } from './loadParamsFactory';
import { FaceFeatureExtractorParams } from './types';
export function extractParamsFromWeigthMap(
weightMap: tf.NamedTensorMap
weightMap: tf.NamedTensorMap,
): { params: FaceFeatureExtractorParams, paramMappings: ParamMapping[] } {
const paramMappings: ParamMapping[] = []
const paramMappings: ParamMapping[] = [];
const {
extractDenseBlock4Params
} = loadParamsFactory(weightMap, paramMappings)
extractDenseBlock4Params,
} = loadParamsFactory(weightMap, paramMappings);
const params = {
dense0: extractDenseBlock4Params('dense0', true),
dense1: extractDenseBlock4Params('dense1'),
dense2: extractDenseBlock4Params('dense2'),
dense3: extractDenseBlock4Params('dense3')
}
dense3: extractDenseBlock4Params('dense3'),
};
disposeUnusedWeightTensors(weightMap, paramMappings)
disposeUnusedWeightTensors(weightMap, paramMappings);
return { params, paramMappings }
return { params, paramMappings };
}

View File

@ -5,22 +5,21 @@ import { loadParamsFactory } from './loadParamsFactory';
import { TinyFaceFeatureExtractorParams } from './types';
export function extractParamsFromWeigthMapTiny(
weightMap: tf.NamedTensorMap
weightMap: tf.NamedTensorMap,
): { params: TinyFaceFeatureExtractorParams, paramMappings: ParamMapping[] } {
const paramMappings: ParamMapping[] = []
const paramMappings: ParamMapping[] = [];
const {
extractDenseBlock3Params
} = loadParamsFactory(weightMap, paramMappings)
extractDenseBlock3Params,
} = loadParamsFactory(weightMap, paramMappings);
const params = {
dense0: extractDenseBlock3Params('dense0', true),
dense1: extractDenseBlock3Params('dense1'),
dense2: extractDenseBlock3Params('dense2')
}
dense2: extractDenseBlock3Params('dense2'),
};
disposeUnusedWeightTensors(weightMap, paramMappings)
disposeUnusedWeightTensors(weightMap, paramMappings);
return { params, paramMappings }
return { params, paramMappings };
}

View File

@ -2,31 +2,28 @@ import { extractWeightsFactory, ParamMapping } from '../common/index';
import { extractorsFactory } from './extractorsFactory';
import { TinyFaceFeatureExtractorParams } from './types';
export function extractParamsTiny(weights: Float32Array): { params: TinyFaceFeatureExtractorParams, paramMappings: ParamMapping[] } {
const paramMappings: ParamMapping[] = []
const paramMappings: ParamMapping[] = [];
const {
extractWeights,
getRemainingWeights
} = extractWeightsFactory(weights)
getRemainingWeights,
} = extractWeightsFactory(weights);
const {
extractDenseBlock3Params
} = extractorsFactory(extractWeights, paramMappings)
extractDenseBlock3Params,
} = extractorsFactory(extractWeights, paramMappings);
const dense0 = extractDenseBlock3Params(3, 32, 'dense0', true)
const dense1 = extractDenseBlock3Params(32, 64, 'dense1')
const dense2 = extractDenseBlock3Params(64, 128, 'dense2')
const dense0 = extractDenseBlock3Params(3, 32, 'dense0', true);
const dense1 = extractDenseBlock3Params(32, 64, 'dense1');
const dense2 = extractDenseBlock3Params(64, 128, 'dense2');
if (getRemainingWeights().length !== 0) {
throw new Error(`weights remaing after extract: ${getRemainingWeights().length}`)
throw new Error(`weights remaing after extract: ${getRemainingWeights().length}`);
}
return {
paramMappings,
params: { dense0, dense1, dense2 }
}
params: { dense0, dense1, dense2 },
};
}

View File

@ -7,32 +7,30 @@ import {
import { DenseBlock3Params, DenseBlock4Params } from './types';
export function extractorsFactory(extractWeights: ExtractWeightsFunction, paramMappings: ParamMapping[]) {
const extractConvParams = extractConvParamsFactory(extractWeights, paramMappings)
const extractSeparableConvParams = extractSeparableConvParamsFactory(extractWeights, paramMappings)
const extractConvParams = extractConvParamsFactory(extractWeights, paramMappings);
const extractSeparableConvParams = extractSeparableConvParamsFactory(extractWeights, paramMappings);
function extractDenseBlock3Params(channelsIn: number, channelsOut: number, mappedPrefix: string, isFirstLayer: boolean = false): DenseBlock3Params {
const conv0 = isFirstLayer
? extractConvParams(channelsIn, channelsOut, 3, `${mappedPrefix}/conv0`)
: extractSeparableConvParams(channelsIn, channelsOut, `${mappedPrefix}/conv0`)
const conv1 = extractSeparableConvParams(channelsOut, channelsOut, `${mappedPrefix}/conv1`)
const conv2 = extractSeparableConvParams(channelsOut, channelsOut, `${mappedPrefix}/conv2`)
: extractSeparableConvParams(channelsIn, channelsOut, `${mappedPrefix}/conv0`);
const conv1 = extractSeparableConvParams(channelsOut, channelsOut, `${mappedPrefix}/conv1`);
const conv2 = extractSeparableConvParams(channelsOut, channelsOut, `${mappedPrefix}/conv2`);
return { conv0, conv1, conv2 }
return { conv0, conv1, conv2 };
}
function extractDenseBlock4Params(channelsIn: number, channelsOut: number, mappedPrefix: string, isFirstLayer: boolean = false): DenseBlock4Params {
const { conv0, conv1, conv2 } = extractDenseBlock3Params(channelsIn, channelsOut, mappedPrefix, isFirstLayer);
const conv3 = extractSeparableConvParams(channelsOut, channelsOut, `${mappedPrefix}/conv3`);
const { conv0, conv1, conv2 } = extractDenseBlock3Params(channelsIn, channelsOut, mappedPrefix, isFirstLayer)
const conv3 = extractSeparableConvParams(channelsOut, channelsOut, `${mappedPrefix}/conv3`)
return { conv0, conv1, conv2, conv3 }
return {
conv0, conv1, conv2, conv3,
};
}
return {
extractDenseBlock3Params,
extractDenseBlock4Params
}
extractDenseBlock4Params,
};
}

Some files were not shown because too many files have changed in this diff Show More