2020-08-18 14:04:33 +02:00
|
|
|
import * as tf from '@tensorflow/tfjs-core';
|
|
|
|
import { ParamMapping } from './common';
|
|
|
|
export declare abstract class NeuralNetwork<TNetParams> {
|
|
|
|
protected _name: string;
|
|
|
|
protected _params: TNetParams | undefined;
|
|
|
|
protected _paramMappings: ParamMapping[];
|
|
|
|
constructor(_name: string);
|
|
|
|
get params(): TNetParams | undefined;
|
|
|
|
get paramMappings(): ParamMapping[];
|
|
|
|
get isLoaded(): boolean;
|
|
|
|
getParamFromPath(paramPath: string): tf.Tensor;
|
|
|
|
reassignParamFromPath(paramPath: string, tensor: tf.Tensor): void;
|
|
|
|
getParamList(): {
|
|
|
|
path: string;
|
|
|
|
tensor: tf.Tensor<tf.Rank>;
|
|
|
|
}[];
|
|
|
|
getTrainableParams(): {
|
|
|
|
path: string;
|
|
|
|
tensor: tf.Tensor<tf.Rank>;
|
|
|
|
}[];
|
|
|
|
getFrozenParams(): {
|
|
|
|
path: string;
|
|
|
|
tensor: tf.Tensor<tf.Rank>;
|
|
|
|
}[];
|
|
|
|
variable(): void;
|
|
|
|
freeze(): void;
|
|
|
|
dispose(throwOnRedispose?: boolean): void;
|
|
|
|
serializeParams(): Float32Array;
|
|
|
|
load(weightsOrUrl: Float32Array | string | undefined): Promise<void>;
|
|
|
|
loadFromUri(uri: string | undefined): Promise<void>;
|
|
|
|
loadFromDisk(filePath: string | undefined): Promise<void>;
|
|
|
|
loadFromWeightMap(weightMap: tf.NamedTensorMap): void;
|
|
|
|
extractWeights(weights: Float32Array): void;
|
|
|
|
private traversePropertyPath;
|
|
|
|
protected abstract getDefaultModelName(): string;
|
|
|
|
protected abstract extractParamsFromWeigthMap(weightMap: tf.NamedTensorMap): {
|
|
|
|
params: TNetParams;
|
|
|
|
paramMappings: ParamMapping[];
|
|
|
|
};
|
|
|
|
protected abstract extractParams(weights: Float32Array): {
|
|
|
|
params: TNetParams;
|
|
|
|
paramMappings: ParamMapping[];
|
|
|
|
};
|
|
|
|
}
|
2020-08-21 15:01:04 +02:00
|
|
|
//# sourceMappingURL=NeuralNetwork.d.ts.map
|