ts linting

pull/34/head
Vladimir Mandic 2020-12-26 18:35:17 -05:00
parent 9c7a8af603
commit 95d3e96b79
8 changed files with 91 additions and 79 deletions

View File

@ -1,3 +1,4 @@
// @ts-nocheck
import * as faceapi from '../dist/face-api.esm.js'; import * as faceapi from '../dist/face-api.esm.js';
// configuration options // configuration options

View File

@ -1,3 +1,5 @@
// @ts-nocheck
const fs = require('fs'); const fs = require('fs');
const path = require('path'); const path = require('path');
const log = require('@vladmandic/pilogger'); const log = require('@vladmandic/pilogger');

View File

@ -1,3 +1,5 @@
// @ts-nocheck
const fs = require('fs'); const fs = require('fs');
const path = require('path'); const path = require('path');
const log = require('@vladmandic/pilogger'); // this is my simple logger with few extra features const log = require('@vladmandic/pilogger'); // this is my simple logger with few extra features

View File

@ -1,3 +1,5 @@
// @ts-nocheck
const fs = require('fs'); const fs = require('fs');
const path = require('path'); const path = require('path');
const log = require('@vladmandic/pilogger'); const log = require('@vladmandic/pilogger');

View File

@ -6,154 +6,160 @@ import { loadWeightMap } from './dom/index';
import { env } from './env/index'; import { env } from './env/index';
export abstract class NeuralNetwork<TNetParams> { export abstract class NeuralNetwork<TNetParams> {
protected _params: TNetParams | undefined = undefined protected _params: TNetParams | undefined = undefined
protected _paramMappings: ParamMapping[] = [] protected _paramMappings: ParamMapping[] = []
constructor(protected _name: string) { private _name: any;
}
public get params(): TNetParams | undefined { return this._params } public get params(): TNetParams | undefined { return this._params; }
public get paramMappings(): ParamMapping[] { return this._paramMappings }
public get isLoaded(): boolean { return !!this.params } public get paramMappings(): ParamMapping[] { return this._paramMappings; }
public get isLoaded(): boolean { return !!this.params; }
public getParamFromPath(paramPath: string): tf.Tensor { public getParamFromPath(paramPath: string): tf.Tensor {
const { obj, objProp } = this.traversePropertyPath(paramPath) const { obj, objProp } = this.traversePropertyPath(paramPath);
return obj[objProp] return obj[objProp];
} }
public reassignParamFromPath(paramPath: string, tensor: tf.Tensor) { public reassignParamFromPath(paramPath: string, tensor: tf.Tensor) {
const { obj, objProp } = this.traversePropertyPath(paramPath) const { obj, objProp } = this.traversePropertyPath(paramPath);
obj[objProp].dispose() obj[objProp].dispose();
obj[objProp] = tensor obj[objProp] = tensor;
} }
public getParamList() { public getParamList() {
return this._paramMappings.map(({ paramPath }) => ({ return this._paramMappings.map(({ paramPath }) => ({
path: paramPath, path: paramPath,
tensor: this.getParamFromPath(paramPath) tensor: this.getParamFromPath(paramPath),
})) }));
} }
public getTrainableParams() { public getTrainableParams() {
return this.getParamList().filter(param => param.tensor instanceof tf.Variable) return this.getParamList().filter((param) => param.tensor instanceof tf.Variable);
} }
public getFrozenParams() { public getFrozenParams() {
return this.getParamList().filter(param => !(param.tensor instanceof tf.Variable)) return this.getParamList().filter((param) => !(param.tensor instanceof tf.Variable));
} }
public variable() { public variable() {
this.getFrozenParams().forEach(({ path, tensor }) => { this.getFrozenParams().forEach(({ path, tensor }) => {
this.reassignParamFromPath(path, tensor.variable()) this.reassignParamFromPath(path, tensor.variable());
}) });
} }
public freeze() { public freeze() {
this.getTrainableParams().forEach(({ path, tensor: variable }) => { this.getTrainableParams().forEach(({ path, tensor: variable }) => {
const tensor = tf.tensor(variable.dataSync()) const tensor = tf.tensor(variable.dataSync());
variable.dispose() variable.dispose();
this.reassignParamFromPath(path, tensor) this.reassignParamFromPath(path, tensor);
}) });
} }
public dispose(throwOnRedispose: boolean = true) { public dispose(throwOnRedispose: boolean = true) {
this.getParamList().forEach(param => { this.getParamList().forEach((param) => {
if (throwOnRedispose && param.tensor.isDisposed) { if (throwOnRedispose && param.tensor.isDisposed) {
throw new Error(`param tensor has already been disposed for path ${param.path}`) throw new Error(`param tensor has already been disposed for path ${param.path}`);
} }
param.tensor.dispose() param.tensor.dispose();
}) });
this._params = undefined this._params = undefined;
} }
public serializeParams(): Float32Array { public serializeParams(): Float32Array {
return new Float32Array( return new Float32Array(
this.getParamList() this.getParamList()
.map(({ tensor }) => Array.from(tensor.dataSync()) as number[]) .map(({ tensor }) => Array.from(tensor.dataSync()) as number[])
.reduce((flat, arr) => flat.concat(arr)) .reduce((flat, arr) => flat.concat(arr)),
) );
} }
public async load(weightsOrUrl: Float32Array | string | undefined): Promise<void> { public async load(weightsOrUrl: Float32Array | string | undefined): Promise<void> {
if (weightsOrUrl instanceof Float32Array) { if (weightsOrUrl instanceof Float32Array) {
this.extractWeights(weightsOrUrl) this.extractWeights(weightsOrUrl);
return return;
} }
await this.loadFromUri(weightsOrUrl) await this.loadFromUri(weightsOrUrl);
} }
public async loadFromUri(uri: string | undefined) { public async loadFromUri(uri: string | undefined) {
if (uri && typeof uri !== 'string') { if (uri && typeof uri !== 'string') {
throw new Error(`${this._name}.loadFromUri - expected model uri`) throw new Error(`${this._name}.loadFromUri - expected model uri`);
} }
const weightMap = await loadWeightMap(uri, this.getDefaultModelName()) const weightMap = await loadWeightMap(uri, this.getDefaultModelName());
this.loadFromWeightMap(weightMap) this.loadFromWeightMap(weightMap);
} }
public async loadFromDisk(filePath: string | undefined) { public async loadFromDisk(filePath: string | undefined) {
if (filePath && typeof filePath !== 'string') { if (filePath && typeof filePath !== 'string') {
throw new Error(`${this._name}.loadFromDisk - expected model file path`) throw new Error(`${this._name}.loadFromDisk - expected model file path`);
} }
const { readFile } = env.getEnv() const { readFile } = env.getEnv();
const { manifestUri, modelBaseUri } = getModelUris(filePath, this.getDefaultModelName()) const { manifestUri, modelBaseUri } = getModelUris(filePath, this.getDefaultModelName());
const fetchWeightsFromDisk = (filePaths: string[]) => Promise.all( const fetchWeightsFromDisk = (filePaths: string[]) => Promise.all(
filePaths.map(filePath => readFile(filePath).then(buf => buf.buffer)) filePaths.map((fp) => readFile(fp).then((buf) => buf.buffer)),
) );
const loadWeights = tf.io.weightsLoaderFactory(fetchWeightsFromDisk) const loadWeights = tf.io.weightsLoaderFactory(fetchWeightsFromDisk);
const manifest = JSON.parse((await readFile(manifestUri)).toString()) const manifest = JSON.parse((await readFile(manifestUri)).toString());
const weightMap = await loadWeights(manifest, modelBaseUri) const weightMap = await loadWeights(manifest, modelBaseUri);
this.loadFromWeightMap(weightMap) this.loadFromWeightMap(weightMap);
} }
public loadFromWeightMap(weightMap: tf.NamedTensorMap) { public loadFromWeightMap(weightMap: tf.NamedTensorMap) {
const { const {
paramMappings, paramMappings,
params params,
} = this.extractParamsFromWeigthMap(weightMap) } = this.extractParamsFromWeigthMap(weightMap);
this._paramMappings = paramMappings this._paramMappings = paramMappings;
this._params = params this._params = params;
} }
public extractWeights(weights: Float32Array) { public extractWeights(weights: Float32Array) {
const { const {
paramMappings, paramMappings,
params params,
} = this.extractParams(weights) } = this.extractParams(weights);
this._paramMappings = paramMappings this._paramMappings = paramMappings;
this._params = params this._params = params;
} }
private traversePropertyPath(paramPath: string) { private traversePropertyPath(paramPath: string) {
if (!this.params) { if (!this.params) {
throw new Error(`traversePropertyPath - model has no loaded params`) throw new Error('traversePropertyPath - model has no loaded params');
} }
const result = paramPath.split('/').reduce((res: { nextObj: any, obj?: any, objProp?: string }, objProp) => { const result = paramPath.split('/').reduce((res: { nextObj: any, obj?: any, objProp?: string }, objProp) => {
// eslint-disable-next-line no-prototype-builtins
if (!res.nextObj.hasOwnProperty(objProp)) { if (!res.nextObj.hasOwnProperty(objProp)) {
throw new Error(`traversePropertyPath - object does not have property ${objProp}, for path ${paramPath}`) throw new Error(`traversePropertyPath - object does not have property ${objProp}, for path ${paramPath}`);
} }
return { obj: res.nextObj, objProp, nextObj: res.nextObj[objProp] } return { obj: res.nextObj, objProp, nextObj: res.nextObj[objProp] };
}, { nextObj: this.params }) }, { nextObj: this.params });
const { obj, objProp } = result const { obj, objProp } = result;
if (!obj || !objProp || !(obj[objProp] instanceof tf.Tensor)) { if (!obj || !objProp || !(obj[objProp] instanceof tf.Tensor)) {
throw new Error(`traversePropertyPath - parameter is not a tensor, for path ${paramPath}`) throw new Error(`traversePropertyPath - parameter is not a tensor, for path ${paramPath}`);
} }
return { obj, objProp } return { obj, objProp };
} }
protected abstract getDefaultModelName(): string protected abstract getDefaultModelName(): string
// eslint-disable-next-line no-unused-vars
protected abstract extractParamsFromWeigthMap(weightMap: tf.NamedTensorMap): { params: TNetParams, paramMappings: ParamMapping[] } protected abstract extractParamsFromWeigthMap(weightMap: tf.NamedTensorMap): { params: TNetParams, paramMappings: ParamMapping[] }
// eslint-disable-next-line no-unused-vars
protected abstract extractParams(weights: Float32Array): { params: TNetParams, paramMappings: ParamMapping[] } protected abstract extractParams(weights: Float32Array): { params: TNetParams, paramMappings: ParamMapping[] }
} }

View File

@ -1,6 +1,7 @@
export class PlatformBrowser { export class PlatformBrowser {
private textEncoder: TextEncoder; private textEncoder: TextEncoder;
// eslint-disable-next-line no-undef
fetch(path: string, init?: RequestInit): Promise<Response> { fetch(path: string, init?: RequestInit): Promise<Response> {
return fetch(path, init); return fetch(path, init);
} }
@ -11,14 +12,14 @@ export class PlatformBrowser {
encode(text: string, encoding: string): Uint8Array { encode(text: string, encoding: string): Uint8Array {
if (encoding !== 'utf-8' && encoding !== 'utf8') { if (encoding !== 'utf-8' && encoding !== 'utf8') {
throw new Error( throw new Error(`Browser's encoder only supports utf-8, but got ${encoding}`);
`Browser's encoder only supports utf-8, but got ${encoding}`);
} }
if (this.textEncoder == null) { if (this.textEncoder == null) {
this.textEncoder = new TextEncoder(); this.textEncoder = new TextEncoder();
} }
return this.textEncoder.encode(text); return this.textEncoder.encode(text);
} }
decode(bytes: Uint8Array, encoding: string): string { decode(bytes: Uint8Array, encoding: string): string {
return new TextDecoder(encoding).decode(bytes); return new TextDecoder(encoding).decode(bytes);
} }

View File

@ -1,13 +1,12 @@
export function euclideanDistance(arr1: number[] | Float32Array, arr2: number[] | Float32Array) { export function euclideanDistance(arr1: number[] | Float32Array, arr2: number[] | Float32Array) {
if (arr1.length !== arr2.length) if (arr1.length !== arr2.length) throw new Error('euclideanDistance: arr1.length !== arr2.length');
throw new Error('euclideanDistance: arr1.length !== arr2.length')
const desc1 = Array.from(arr1) const desc1 = Array.from(arr1);
const desc2 = Array.from(arr2) const desc2 = Array.from(arr2);
return Math.sqrt( return Math.sqrt(
desc1 desc1
.map((val, i) => val - desc2[i]) .map((val, i) => val - desc2[i])
.reduce((res, diff) => res + Math.pow(diff, 2), 0) .reduce((res, diff) => res + (diff ** 2), 0),
) );
} }

View File

@ -5,31 +5,30 @@ import { extendWithFaceDetection, isWithFaceDetection } from './factories/WithFa
import { extendWithFaceLandmarks, isWithFaceLandmarks } from './factories/WithFaceLandmarks'; import { extendWithFaceLandmarks, isWithFaceLandmarks } from './factories/WithFaceLandmarks';
export function resizeResults<T>(results: T, dimensions: IDimensions): T { export function resizeResults<T>(results: T, dimensions: IDimensions): T {
const { width, height } = new Dimensions(dimensions.width, dimensions.height);
const { width, height } = new Dimensions(dimensions.width, dimensions.height)
if (width <= 0 || height <= 0) { if (width <= 0 || height <= 0) {
throw new Error(`resizeResults - invalid dimensions: ${JSON.stringify({ width, height })}`) throw new Error(`resizeResults - invalid dimensions: ${JSON.stringify({ width, height })}`);
} }
if (Array.isArray(results)) { if (Array.isArray(results)) {
// return results.map(obj => resizeResults(obj, { width, height })) as any as T // return results.map(obj => resizeResults(obj, { width, height })) as any as T
return (results as Array<any>).map(obj => resizeResults(obj, { width, height } as IDimensions)) as any as T return (results as Array<any>).map((obj) => resizeResults(obj, { width, height } as IDimensions)) as any as T;
} }
if (isWithFaceLandmarks(results)) { if (isWithFaceLandmarks(results)) {
const resizedDetection = results.detection.forSize(width, height) const resizedDetection = results.detection.forSize(width, height);
const resizedLandmarks = results.unshiftedLandmarks.forSize(resizedDetection.box.width, resizedDetection.box.height) const resizedLandmarks = results.unshiftedLandmarks.forSize(resizedDetection.box.width, resizedDetection.box.height);
return extendWithFaceLandmarks(extendWithFaceDetection(results, resizedDetection), resizedLandmarks) return extendWithFaceLandmarks(extendWithFaceDetection(results, resizedDetection), resizedLandmarks);
} }
if (isWithFaceDetection(results)) { if (isWithFaceDetection(results)) {
return extendWithFaceDetection(results, results.detection.forSize(width, height)) return extendWithFaceDetection(results, results.detection.forSize(width, height));
} }
if (results instanceof FaceLandmarks || results instanceof FaceDetection) { if (results instanceof FaceLandmarks || results instanceof FaceDetection) {
return (results as any).forSize(width, height) return (results as any).forSize(width, height);
} }
return results return results;
} }