face-api/src/NeuralNetwork.ts

153 lines
5.2 KiB
TypeScript
Raw Normal View History

2020-12-23 18:58:47 +01:00
import * as tf from '../dist/tfjs.esm';
2020-08-18 13:54:53 +02:00
2020-12-19 17:46:41 +01:00
import { ParamMapping } from './common/index';
2020-08-18 13:54:53 +02:00
import { getModelUris } from './common/getModelUris';
2020-12-19 17:46:41 +01:00
import { loadWeightMap } from './dom/index';
import { env } from './env/index';
2020-08-18 13:54:53 +02:00
/**
 * Shared base for networks: holds a subclass-defined parameter tree
 * (`_params`) plus the list of slash-separated paths to the tensors inside it
 * (`_paramMappings`), and implements the generic load / serialize / freeze /
 * dispose plumbing on top of them. Subclasses supply the network-specific
 * parsing through the abstract hooks at the bottom of the class.
 */
export abstract class NeuralNetwork<TNetParams> {
  /** Name used to prefix error messages from `loadFromUri` / `loadFromDisk`. */
  public _name: string;

  protected _params: TNetParams | undefined = undefined;

  protected _paramMappings: ParamMapping[] = [];

  constructor(name: string) {
    this._name = name;
  }

  /** The loaded parameter tree, or `undefined` if nothing is loaded (or it was disposed). */
  public get params(): TNetParams | undefined { return this._params; }

  /** Tensor paths recorded by the most recent `loadFromWeightMap` / `extractWeights`. */
  public get paramMappings(): ParamMapping[] { return this._paramMappings; }

  /** True while a parameter tree is held. */
  public get isLoaded(): boolean { return !!this.params; }

  /**
   * Returns the tensor stored at `paramPath` (e.g. `"conv0/filters"`).
   * @throws Error if no params are loaded or the path does not end at a tensor.
   */
  public getParamFromPath(paramPath: string): tf.Tensor {
    const { obj, objProp } = this.traversePropertyPath(paramPath);
    return obj[objProp];
  }

  /**
   * Replaces the tensor stored at `paramPath` with `tensor`, disposing the
   * previous one. Reassigning the tensor that is already stored is a no-op
   * rather than leaving a disposed tensor in the tree.
   * @throws Error if no params are loaded or the path does not end at a tensor.
   */
  public reassignParamFromPath(paramPath: string, tensor: tf.Tensor) {
    const { obj, objProp } = this.traversePropertyPath(paramPath);
    const previous = obj[objProp];
    if (previous !== tensor) previous.dispose();
    obj[objProp] = tensor;
  }

  /** Every mapped parameter as `{ path, tensor }`, in `paramMappings` order. */
  public getParamList(): { path: string, tensor: tf.Tensor }[] {
    return this._paramMappings.map(({ paramPath }) => ({
      path: paramPath,
      tensor: this.getParamFromPath(paramPath),
    }));
  }

  /** Parameters currently held as `tf.Variable`. */
  public getTrainableParams() {
    return this.getParamList().filter((param) => param.tensor instanceof tf.Variable);
  }

  /** Parameters currently held as plain (non-Variable) tensors. */
  public getFrozenParams() {
    return this.getParamList().filter((param) => !(param.tensor instanceof tf.Variable));
  }

  /** Converts every frozen parameter into a `tf.Variable`; the replaced tensor is disposed. */
  public variable() {
    this.getFrozenParams().forEach(({ path, tensor }) => {
      this.reassignParamFromPath(path, tensor.variable());
    });
  }

  /**
   * Converts every `tf.Variable` parameter back into a plain tensor holding the
   * same values. The variable is disposed by `reassignParamFromPath`.
   */
  public freeze() {
    this.getTrainableParams().forEach(({ path, tensor: variable }) => {
      // dataSync() returns a flat typed array; pass shape and dtype so the
      // frozen tensor keeps the variable's layout instead of collapsing to 1-D.
      const tensor = tf.tensor(variable.dataSync(), variable.shape, variable.dtype);
      this.reassignParamFromPath(path, tensor);
    });
  }

  /**
   * Disposes every parameter tensor and clears `params`.
   * @param throwOnRedispose when true (default), throw if a parameter tensor was
   *   already disposed; when false, disposing an already-disposed (or never
   *   loaded) network is a silent no-op.
   * @throws Error on re-disposal when `throwOnRedispose` is true.
   */
  public dispose(throwOnRedispose: boolean = true) {
    // Once params are cleared, getParamList() would throw from
    // traversePropertyPath even though the caller opted out of redispose errors.
    if (!throwOnRedispose && !this.isLoaded) return;
    this.getParamList().forEach((param) => {
      if (throwOnRedispose && param.tensor.isDisposed) {
        throw new Error(`param tensor has already been disposed for path ${param.path}`);
      }
      param.tensor.dispose();
    });
    this._params = undefined;
  }

  /**
   * Concatenates the values of all parameters (in `paramMappings` order) into a
   * single Float32Array. Returns an empty array when there are no parameters.
   */
  public serializeParams(): Float32Array {
    const chunks = this.getParamList().map(({ tensor }) => tensor.dataSync());
    const totalLength = chunks.reduce((sum, chunk) => sum + chunk.length, 0);
    const flat = new Float32Array(totalLength);
    let offset = 0;
    for (const chunk of chunks) {
      flat.set(chunk, offset);
      offset += chunk.length;
    }
    return flat;
  }

  /**
   * Loads parameters either from an in-memory weight array (parsed by
   * `extractWeights`) or, for anything else, via `loadFromUri`.
   */
  public async load(weightsOrUrl: Float32Array | string | undefined): Promise<void> {
    if (weightsOrUrl instanceof Float32Array) {
      this.extractWeights(weightsOrUrl);
      return;
    }
    await this.loadFromUri(weightsOrUrl);
  }

  /**
   * Loads a weight map through `loadWeightMap(uri, getDefaultModelName())` and
   * installs it. How an `undefined` uri is resolved is decided by
   * `loadWeightMap` (not visible here).
   * @throws Error if `uri` is a truthy non-string.
   */
  public async loadFromUri(uri: string | undefined) {
    if (uri && typeof uri !== 'string') {
      throw new Error(`${this._name}.loadFromUri - expected model uri`);
    }
    const weightMap = await loadWeightMap(uri, this.getDefaultModelName());
    this.loadFromWeightMap(weightMap);
  }

  /**
   * Loads the manifest and weight shards with the environment's `readFile`
   * (presumably a Node fs-backed reader — confirm in env) and installs them.
   * @throws Error if `filePath` is a truthy non-string.
   */
  public async loadFromDisk(filePath: string | undefined) {
    if (filePath && typeof filePath !== 'string') {
      throw new Error(`${this._name}.loadFromDisk - expected model file path`);
    }
    const { readFile } = env.getEnv();
    const { manifestUri, modelBaseUri } = getModelUris(filePath, this.getDefaultModelName());
    // A Node Buffer can be a view into a larger shared ArrayBuffer (non-zero
    // byteOffset), so copy out exactly this buffer's bytes rather than `buf.buffer`.
    const fetchWeightsFromDisk = (filePaths: string[]) => Promise.all(filePaths.map((fp) => readFile(fp)
      .then((buf) => buf.buffer.slice(buf.byteOffset, buf.byteOffset + buf.byteLength))));
    const loadWeights = tf.io.weightsLoaderFactory(fetchWeightsFromDisk);
    const manifest = JSON.parse((await readFile(manifestUri)).toString());
    const weightMap = await loadWeights(manifest, modelBaseUri);
    this.loadFromWeightMap(weightMap);
  }

  /** Builds params and mappings from a named tensor map via the subclass hook and stores them. */
  public loadFromWeightMap(weightMap: tf.NamedTensorMap) {
    const { paramMappings, params } = this.extractParamsFromWeightMap(weightMap);
    this._paramMappings = paramMappings;
    this._params = params;
  }

  /** Builds params and mappings from a flat weight array via the subclass hook and stores them. */
  public extractWeights(weights: Float32Array) {
    const { paramMappings, params } = this.extractParams(weights);
    this._paramMappings = paramMappings;
    this._params = params;
  }

  /**
   * Walks the `/`-separated `paramPath` through the params tree and returns the
   * owning object plus the final property name, so callers can read or replace
   * the tensor in place.
   * @throws Error if no params are loaded, a segment is missing, or the leaf is not a tensor.
   */
  private traversePropertyPath(paramPath: string) {
    if (!this.params) {
      throw new Error('traversePropertyPath - model has no loaded params');
    }

    const result = paramPath.split('/').reduce((res: { nextObj: any, obj?: any, objProp?: string }, objProp) => {
      // Guard null/undefined intermediates and avoid relying on the object's own
      // (possibly shadowed or absent) hasOwnProperty method.
      if (res.nextObj == null || !Object.prototype.hasOwnProperty.call(res.nextObj, objProp)) {
        throw new Error(`traversePropertyPath - object does not have property ${objProp}, for path ${paramPath}`);
      }
      return { obj: res.nextObj, objProp, nextObj: res.nextObj[objProp] };
    }, { nextObj: this.params });

    const { obj, objProp } = result;
    if (!obj || !objProp || !(obj[objProp] instanceof tf.Tensor)) {
      throw new Error(`traversePropertyPath - parameter is not a tensor, for path ${paramPath}`);
    }

    return { obj, objProp };
  }

  /** Model name handed to the URI / file-path resolution helpers. */
  protected abstract getDefaultModelName(): string;

  /** Subclass hook: build the params tree and its tensor paths from a named tensor map. */
  // eslint-disable-next-line no-unused-vars
  protected abstract extractParamsFromWeightMap(weightMap: tf.NamedTensorMap): { params: TNetParams, paramMappings: ParamMapping[] };

  /** Subclass hook: build the params tree and its tensor paths from a flat weight array. */
  // eslint-disable-next-line no-unused-vars
  protected abstract extractParams(weights: Float32Array): { params: TNetParams, paramMappings: ParamMapping[] };
}