human/src/tfjs/load.ts

82 lines
3.4 KiB
TypeScript
Raw Normal View History

2022-01-17 17:03:21 +01:00
import { log, join } from '../util/util';
2022-01-16 15:49:55 +01:00
import * as tf from '../../dist/tfjs.esm.js';
import type { GraphModel } from './types';
2022-01-17 17:03:21 +01:00
import type { Config } from '../config';
2022-01-16 15:49:55 +01:00
2022-01-17 17:03:21 +01:00
/**
 * Module-level settings consulted by the model loader.
 * `cacheModels`, `verbose` and `modelBasePath` are overwritten by `setModelLoadOptions`,
 * `cacheSupported` is re-evaluated on every `loadModel` call, and nothing in this module
 * assigns `debug`, so it keeps this default unless edited here.
 */
const options = {
  cacheModels: true, // save downloaded models to, and load them from, the `indexeddb://` model cache
  cacheSupported: true, // whether the current environment can use the model cache
  verbose: true, // log each successfully loaded model
  debug: false, // log fetch requests and the selected tfjs IO handler
  modelBasePath: '', // prefix joined with the model path passed to loadModel
};

/** Per-model load statistics recorded by `loadModel`. */
type ModelStats = {
  name: string, // model file name with the .json extension removed
  cached: boolean, // true when the model was requested from the IndexedDB cache
  manifest: number, // byte length of weightData returned by the IO handler's load(), 0 if absent
  weights: number, // byte length of weightData on the model after loadSync(), 0 if absent
}

/** Load statistics keyed by short model name; an entry is (re)initialized on each loadModel call. */
export const modelStats: Record<string, ModelStats> = {};
2022-01-17 17:03:21 +01:00
/**
 * Fetch wrapper handed to tfjs as `fetchFunc` so model downloads can be traced.
 * When `options.debug` is set the request is logged before being delegated,
 * unchanged, to the global `fetch`.
 *
 * @param url resource requested by the tfjs IO handler
 * @param init optional request settings forwarded as-is to `fetch`
 * @returns the response produced by `fetch`
 */
async function httpHandler(url: RequestInfo, init?: RequestInit): Promise<Response | null> {
  if (options.debug) log('load model fetch:', url, init);
  return fetch(url, init);
}
2022-01-17 17:03:21 +01:00
/**
 * Apply the model-loading settings from the user configuration to the
 * module-level `options` that `loadModel` reads.
 * `verbose` logging follows `config.debug`; `options.debug` is not touched here.
 *
 * @param config user configuration providing `cacheModels`, `debug` and `modelBasePath`
 */
export function setModelLoadOptions(config: Config) {
  const { cacheModels, debug, modelBasePath } = config;
  options.cacheModels = cacheModels;
  options.verbose = debug;
  options.modelBasePath = modelBasePath;
}
2022-01-16 15:49:55 +01:00
2022-01-17 17:03:21 +01:00
/**
 * Report whether the model cache can be used: requires a browser `window` exposing
 * both `localStorage` and `indexedDB`.
 * Merely reading `window.localStorage` can throw (e.g. a SecurityError when storage is
 * blocked), so the probe is guarded and any exception means "not supported".
 */
function isModelCacheSupported(): boolean {
  try {
    return (typeof window !== 'undefined') && (typeof window.localStorage !== 'undefined') && (typeof window.indexedDB !== 'undefined');
  } catch {
    return false;
  }
}

/**
 * Load a tfjs graph model, preferring a copy stored in the IndexedDB model cache.
 *
 * `modelPath` is joined onto `options.modelBasePath` and `.json` is appended unless the
 * path already ends with it (case-insensitive). The cache key is `indexeddb://<name>`,
 * where `<name>` is the last path segment without its `.json` suffix. When caching is
 * enabled and supported and the model was not found in the cache, a successfully loaded
 * model is saved there. Per-model statistics are recorded in `modelStats`.
 *
 * Errors are logged rather than thrown: if loading fails the returned model is not loaded.
 *
 * @param modelPath path of the model relative to `options.modelBasePath`; treated as '' when undefined
 * @returns the GraphModel instance (possibly unloaded on failure)
 */
export async function loadModel(modelPath: string | undefined): Promise<GraphModel> {
  let modelUrl = join(options.modelBasePath, modelPath || '');
  if (!modelUrl.toLowerCase().endsWith('.json')) modelUrl += '.json';
  const modelPathSegments = modelUrl.split('/');
  // strip only the trailing extension, matching the case-insensitive suffix check above
  const shortModelName = modelPathSegments[modelPathSegments.length - 1].replace(/\.json$/i, '');
  const cachedModelName = 'indexeddb://' + shortModelName; // short model name used as cache key
  modelStats[shortModelName] = {
    name: shortModelName,
    manifest: 0,
    weights: 0,
    cached: false,
  };

  options.cacheSupported = isModelCacheSupported();
  let cachedModels: Record<string, unknown> = {};
  try {
    // listing can still fail (e.g. in webviews) even when the storage APIs are defined
    cachedModels = (options.cacheSupported && options.cacheModels) ? await tf.io.listModels() : {};
  } catch {
    options.cacheSupported = false;
  }
  modelStats[shortModelName].cached = (options.cacheSupported && options.cacheModels) && Object.keys(cachedModels).includes(cachedModelName);

  const tfLoadOptions = typeof fetch === 'undefined' ? {} : { fetchFunc: (url: RequestInfo, init?: RequestInit) => httpHandler(url, init) };
  // load from cache when available, otherwise from the original model url; cast bridges bundled tfjs types to the local GraphModel type
  const model: GraphModel = new tf.GraphModel(modelStats[shortModelName].cached ? cachedModelName : modelUrl, tfLoadOptions) as unknown as GraphModel;
  let loaded = false;
  try {
    // @ts-ignore private function
    model.findIOHandler(); // select the IO handler that will perform the load
    if (options.debug) log('model load handler:', model['handler']);
    // @ts-ignore private property
    const artifacts = await model.handler.load(); // fetch model artifacts through the selected handler
    modelStats[shortModelName].manifest = artifacts?.weightData?.byteLength || 0;
    model.loadSync(artifacts); // build the model from the fetched artifacts
    // @ts-ignore private property
    modelStats[shortModelName].weights = model?.artifacts?.weightData?.byteLength || 0;
    if (options.verbose) log('load model:', model['modelUrl'], { bytes: modelStats[shortModelName].weights }, options);
    loaded = true;
  } catch (err) {
    log('error loading model:', modelUrl, err);
  }

  if (loaded && options.cacheModels && options.cacheSupported && !modelStats[shortModelName].cached) { // save model to cache
    try {
      const saveResult = await model.save(cachedModelName);
      log('model saved:', cachedModelName, saveResult);
    } catch (err) {
      log('error saving model:', modelUrl, err);
    }
  }
  return model;
}