import { log , join } from '../util/util' ;
import * as tf from '../../dist/tfjs.esm.js' ;
import type { GraphModel } from './types' ;
import type { Config } from '../config' ;
2022-01-16 15:49:55 +01:00
const options = {
cacheModels : false ,
2022-01-16 15:49:55 +01:00
verbose : true ,
2022-01-17 17:03:21 +01:00
debug : false ,
modelBasePath : '' ,
2022-01-16 15:49:55 +01:00
} ;
async function httpHandler ( url , init ? ) : Promise < Response | null > {
if ( options . debug ) log ( 'load model fetch:' , url , init ) ;
return fetch ( url , init ) ;
2022-01-16 15:49:55 +01:00
}
/**
 * Copies the relevant fields of a library configuration into the
 * module-local model-load options used by loadModel().
 * @param config active library configuration
 */
export function setModelLoadOptions(config: Config) {
  Object.assign(options, {
    cacheModels: config.cacheModels,
    verbose: config.debug, // NOTE(review): verbose deliberately mirrors config.debug — confirm intent
    modelBasePath: config.modelBasePath,
  });
}
export async function loadModel ( modelPath : string | undefined ) : Promise < GraphModel > {
2022-05-22 14:50:51 +02:00
let modelUrl = join ( options . modelBasePath , modelPath || '' ) ;
if ( ! modelUrl . toLowerCase ( ) . endsWith ( '.json' ) ) modelUrl += '.json' ;
2022-01-17 17:03:21 +01:00
const modelPathSegments = modelUrl . split ( '/' ) ;
const cachedModelName = 'indexeddb://' + modelPathSegments [ modelPathSegments . length - 1 ] . replace ( '.json' , '' ) ; // generate short model name for cache
const cachedModels = await tf . io . listModels ( ) ; // list all models already in cache
const modelCached = options . cacheModels && Object . keys ( cachedModels ) . includes ( cachedModelName ) ; // is model found in cache
2022-01-20 14:17:06 +01:00
const tfLoadOptions = typeof fetch === 'undefined' ? { } : { fetchFunc : ( url , init ? ) = > httpHandler ( url , init ) } ;
const model : GraphModel = new tf . GraphModel ( modelCached ? cachedModelName : modelUrl , tfLoadOptions ) as unknown as GraphModel ; // create model prototype and decide if load from cache or from original modelurl
2022-04-10 16:13:13 +02:00
let loaded = false ;
2022-01-17 17:03:21 +01:00
try {
// @ts-ignore private function
model . findIOHandler ( ) ; // decide how to actually load a model
// @ts-ignore private property
if ( options . debug ) log ( 'model load handler:' , model . handler ) ;
// @ts-ignore private property
const artifacts = await model . handler . load ( ) ; // load manifest
model . loadSync ( artifacts ) ; // load weights
if ( options . verbose ) log ( 'load model:' , model [ 'modelUrl' ] ) ;
2022-04-10 16:13:13 +02:00
loaded = true ;
2022-01-17 17:03:21 +01:00
} catch ( err ) {
log ( 'error loading model:' , modelUrl , err ) ;
}
2022-04-10 16:13:13 +02:00
if ( loaded && options . cacheModels && ! modelCached ) { // save model to cache
2022-01-17 17:03:21 +01:00
try {
const saveResult = await model . save ( cachedModelName ) ;
log ( 'model saved:' , cachedModelName , saveResult ) ;
} catch ( err ) {
log ( 'error saving model:' , modelUrl , err ) ;
}
}
2022-01-16 15:49:55 +01:00
return model ;
}