2022-01-17 17:03:21 +01:00
import { log , join } from '../util/util' ;
2022-01-16 15:49:55 +01:00
import * as tf from '../../dist/tfjs.esm.js' ;
import type { GraphModel } from './types' ;
2022-01-17 17:03:21 +01:00
import type { Config } from '../config' ;
2022-07-18 03:31:08 +02:00
import * as modelsDefs from '../../models/models.json' ;
2022-08-10 19:44:38 +02:00
import { validateModel } from '../models' ;
2022-01-16 15:49:55 +01:00
2022-01-17 17:03:21 +01:00
/**
 * Module-level settings used while loading models.
 * `cacheModels`, `verbose` and `modelBasePath` are overwritten from user config by `setModelLoadOptions`;
 * `cacheSupported` is re-detected on every `loadModel` call.
 */
const options: {
  cacheModels: boolean; // read models from / write models to the IndexedDB cache when possible
  cacheSupported: boolean; // environment supports the IndexedDB cache (recomputed per load)
  verbose: boolean; // log a summary line after each model load
  debug: boolean; // trace fetch calls and the chosen load handler
  modelBasePath: string; // base path or url that model paths are joined onto
} = {
  cacheModels: true,
  cacheSupported: true,
  verbose: true,
  debug: false,
  modelBasePath: '',
};
2022-08-21 19:34:51 +02:00
/** Load statistics recorded by `loadModel` for a single model. */
export interface ModelInfo {
  /** short model name: last path segment of the model url without its `.json` extension */
  name: string;
  /** true when caching is enabled and supported and the model was found in the IndexedDB cache */
  inCache: boolean;
  /** expected size taken from `models/models.json` for this model name (NOTE(review): may be undefined for names missing from that file) */
  sizeDesired: number;
  /** byte length of the weight data returned by the load handler, or 0 when unavailable */
  sizeFromManifest: number;
  /** byte length of the weight data held by the model after loading, or 0 when unavailable */
  sizeLoadedWeights: number;
}

/** Statistics for every model passed to `loadModel`, keyed by short model name. */
export const modelStats: { [shortModelName: string]: ModelInfo } = {};
2022-07-02 09:39:40 +02:00
2022-08-21 21:23:03 +02:00
/**
 * Fetch wrapper passed to tfjs as `fetchFunc`, so model requests can be traced when `options.debug` is set.
 * @param url resource to fetch
 * @param init optional request options, forwarded unchanged
 * @returns the response produced by the global `fetch`
 */
async function httpHandler(url: string, init?: RequestInit): Promise<Response | null> {
  const traceRequests = options.debug;
  if (traceRequests) {
    log('load model fetch:', url, init);
  }
  const response = await fetch(url, init);
  return response;
}
2022-01-17 17:03:21 +01:00
/**
 * Copy model-loading settings from the user configuration into the module-level `options`.
 * `config.debug` controls `options.verbose` (the per-load summary log); `options.debug` is left unchanged.
 * @param config user configuration supplying `cacheModels`, `debug` and `modelBasePath`
 */
export function setModelLoadOptions(config: Config) {
  const { cacheModels, debug, modelBasePath } = config;
  options.modelBasePath = modelBasePath;
  options.verbose = debug;
  options.cacheModels = cacheModels;
}
2022-01-16 15:49:55 +01:00
2022-01-17 17:03:21 +01:00
/**
 * Extract the short model name from a model url: the last path segment with a trailing `.json` removed.
 * Both '/' and '\' are treated as separators (even when mixed), and the extension match is case-insensitive,
 * consistent with the case-insensitive `.json` check used when building the url in `loadModel`.
 */
function getShortModelName(modelUrl: string): string {
  const segments = modelUrl.split(/[\\/]/);
  const fileName = segments[segments.length - 1];
  return fileName.replace(/\.json$/i, ''); // strip only the trailing extension, not an earlier '.json' substring
}

/**
 * Load a tfjs graph model, preferring a copy cached in IndexedDB when caching is enabled and supported.
 *
 * `modelPath` is joined onto `options.modelBasePath` (undefined is treated as an empty string), and
 * `.json` is appended unless the url already ends with it (case-insensitive). Load statistics are
 * recorded in `modelStats` under the model's short name. A model fetched from its original url is
 * saved to `indexeddb://<shortName>` after a successful load when caching is enabled and supported.
 *
 * Load and save failures are logged rather than thrown, so the returned model may not be loaded.
 *
 * @param modelPath model path relative to the configured base path
 * @returns the graph model instance (after `validateModel` has been run on it)
 */
export async function loadModel(modelPath: string | undefined): Promise<GraphModel> {
  let modelUrl = join(options.modelBasePath, modelPath || '');
  if (!modelUrl.toLowerCase().endsWith('.json')) modelUrl += '.json';
  const shortModelName = getShortModelName(modelUrl);
  const cachedModelName = 'indexeddb://' + shortModelName; // generate short model name for cache
  modelStats[shortModelName] = {
    name: shortModelName,
    sizeFromManifest: 0,
    sizeLoadedWeights: 0,
    sizeDesired: modelsDefs[shortModelName],
    inCache: false,
  };
  // caching requires a browser environment exposing both localStorage and indexedDB
  options.cacheSupported = (typeof window !== 'undefined') && (typeof window.localStorage !== 'undefined') && (typeof window.indexedDB !== 'undefined');
  let cachedModels = {};
  try {
    // listing cached models can throw (e.g. in webview) even though localStorage is defined; treat that as no cache support
    cachedModels = (options.cacheSupported && options.cacheModels) ? await tf.io.listModels() : {};
  } catch {
    options.cacheSupported = false;
  }
  modelStats[shortModelName].inCache = (options.cacheSupported && options.cacheModels) && Object.keys(cachedModels).includes(cachedModelName); // is model found in cache
  // route requests through httpHandler only when a global fetch exists
  const tfLoadOptions = typeof fetch === 'undefined' ? {} : { fetchFunc: (url: string, init?: RequestInit) => httpHandler(url, init) };
  // create model prototype and decide whether to load from cache or from the original url
  const model: GraphModel = new tf.GraphModel(modelStats[shortModelName].inCache ? cachedModelName : modelUrl, tfLoadOptions) as unknown as GraphModel;
  let loaded = false;
  try {
    // @ts-ignore private function
    model.findIOHandler(); // decide how to actually load a model
    if (options.debug) log('model load handler:', model['handler']);
    // @ts-ignore private property
    const artifacts = await model.handler.load(); // load manifest
    modelStats[shortModelName].sizeFromManifest = artifacts?.weightData?.byteLength || 0;
    model.loadSync(artifacts); // load weights
    // @ts-ignore private property
    modelStats[shortModelName].sizeLoadedWeights = model.artifacts?.weightData?.byteLength || 0;
    if (options.verbose) log('load:', { model: shortModelName, url: model['modelUrl'], bytes: modelStats[shortModelName].sizeLoadedWeights });
    loaded = true;
  } catch (err) {
    log('error loading model:', modelUrl, err);
  }
  if (loaded && options.cacheModels && options.cacheSupported && !modelStats[shortModelName].inCache) { // save model to cache
    try {
      const saveResult = await model.save(cachedModelName);
      log('model saved:', cachedModelName, saveResult);
    } catch (err) {
      log('error saving model:', modelUrl, err);
    }
  }
  validateModel(null, model, `${modelPath || ''}`);
  return model;
}