2021-09-25 17:51:15 +02:00
/** TFJS backend initialization and customization */
2021-11-12 21:07:23 +01:00
import type { Human } from '../human' ;
2021-09-27 19:58:13 +02:00
import { log , now } from '../util/util' ;
2021-10-21 16:26:44 +02:00
import { env } from '../util/env' ;
2021-09-13 00:37:06 +02:00
import * as humangl from './humangl' ;
2020-12-13 00:34:30 +01:00
import * as tf from '../../dist/tfjs.esm.js' ;
2021-11-17 02:16:49 +01:00
import * as constants from './constants' ;
2020-12-13 00:34:30 +01:00
2021-11-05 18:36:53 +01:00
/**
 * Registers fallback `Mod` and `FloorMod` kernels for the currently active tfjs backend,
 * unless `env.kernels` already lists them (lowercase names 'mod' / 'floormod').
 *
 * Both fallbacks compute the floored modulo `a - floorDiv(a, b) * b`, so the result takes the sign of the divisor.
 * Each kernel registered here is appended to `env.kernels` so later calls do not register it again.
 */
function registerCustomOps() {
  // floor the quotient explicitly: a plain `tf.div` on float inputs would make `a - (a / b) * b` collapse to ~0
  const floorModFunc = (op) => tf.tidy(() => tf.sub(op.inputs.a, tf.mul(tf.floorDiv(op.inputs.a, op.inputs.b), op.inputs.b)));
  const fallbacks = [
    { name: 'mod', kernelName: 'Mod' },
    { name: 'floormod', kernelName: 'FloorMod' },
  ];
  for (const { name, kernelName } of fallbacks) {
    if (env.kernels.includes(name)) continue; // already provided by the backend or registered earlier
    tf.registerKernel({ kernelName, backendName: tf.getBackend(), kernelFunc: floorModFunc });
    env.kernels.push(name);
  }
}
2021-11-12 21:07:23 +01:00
/**
 * Applies runtime-dependent overrides to `instance.config.backend` (mutated in place):
 * - `tensorflow` requested in a browser becomes `humangl`
 * - `webgl` / `humangl` requested in nodejs becomes `tensorflow`
 * - `webgpu` requested in a browser without `navigator.gpu`, or without an available adapter, becomes `humangl`
 */
async function resolveBackendOverrides(instance: Human) {
  // detect web worker
  // @ts-ignore ignore missing type for WorkerGlobalScope as that is the point
  if (typeof window === 'undefined' && typeof WorkerGlobalScope !== 'undefined' && instance.config.debug) log('running inside web worker');

  // force browser vs node backend
  if (env.browser && instance.config.backend === 'tensorflow') {
    if (instance.config.debug) log('override: backend set to tensorflow while running in browser');
    instance.config.backend = 'humangl';
  }
  if (env.node && (instance.config.backend === 'webgl' || instance.config.backend === 'humangl')) {
    if (instance.config.debug) log(`override: backend set to ${instance.config.backend} while running in nodejs`);
    instance.config.backend = 'tensorflow';
  }

  // handle webgpu
  if (!env.browser || instance.config.backend !== 'webgpu') return;
  if (typeof navigator === 'undefined' || typeof navigator['gpu'] === 'undefined') {
    log('override: backend set to webgpu but browser does not support webgpu');
    instance.config.backend = 'humangl';
    return;
  }
  const adapter = await navigator['gpu'].requestAdapter();
  if (instance.config.debug) log('enumerated webgpu adapter:', adapter);
  if (!adapter) {
    log('override: backend set to webgpu but browser reports no available gpu');
    instance.config.backend = 'humangl';
    return;
  }
  // @ts-ignore requestAdapterInfo is not in tslib
  // eslint-disable-next-line no-undef
  const adapterInfo = 'requestAdapterInfo' in adapter ? await (adapter as GPUAdapter).requestAdapterInfo() : undefined;
  // if (adapter.features) adapter.features.forEach((feature) => log('webgpu features:', feature));
  if (instance.config.debug) log('webgpu adapter info:', adapterInfo); // informational, so gated like the other debug output
}

/**
 * Prepares the wasm backend: sets the canvas read-frequently flag when known, configures wasm binary paths,
 * and in debug mode reports SIMD / multithreading support.
 * Throws if `tf.setWasmPaths` is not present in the bundled tfjs.
 */
async function configureWasm(instance: Human) {
  if (tf.env().flagRegistry['CANVAS2D_WILL_READ_FREQUENTLY']) tf.env().set('CANVAS2D_WILL_READ_FREQUENTLY', true);
  if (instance.config.debug) log('wasm path:', instance.config.wasmPath);
  if (typeof tf?.setWasmPaths !== 'undefined') await tf.setWasmPaths(instance.config.wasmPath, instance.config.wasmPlatformFetch);
  else throw new Error('backend error: attempting to use wasm backend but wasm path is not set');
  const simd = await tf.env().getAsync('WASM_HAS_SIMD_SUPPORT');
  const mt = await tf.env().getAsync('WASM_HAS_MULTITHREAD_SUPPORT');
  if (instance.config.debug) log(`wasm execution: ${simd ? 'SIMD' : 'no SIMD'} ${mt ? 'multithreaded' : 'singlethreaded'}`);
  if (instance.config.debug && !simd) log('warning: wasm simd support is not enabled');
}

/**
 * Tunes tfjs environment flags for the humangl backend.
 * Only flags present in `tf.env().flagRegistry` are set.
 * In debug mode it also logs the gl version and renderer.
 */
function configureHumanGL(instance: Human) {
  const flags: Record<string, boolean | number> = {
    CHECK_COMPUTATION_FOR_ERRORS: false,
    WEBGL_CPU_FORWARD: true,
    WEBGL_USE_SHAPES_UNIFORMS: true,
    CPU_HANDOFF_SIZE_THRESHOLD: 256,
    WEBGL_EXP_CONV: true, // <https://github.com/tensorflow/tfjs/issues/6678>
    USE_SETTIMEOUTCUSTOM: true, // <https://github.com/tensorflow/tfjs/issues/6687>
    // WEBGL_PACK_DEPTHWISECONV: false,
    // WEBGL_FORCE_F16_TEXTURES: true, // safe to use 16bit precision, only when !instance.config.object.enabled
  };
  for (const [flag, value] of Object.entries(flags)) {
    if (tf.env().flagRegistry[flag]) tf.env().set(flag, value);
  }
  if (instance.config['deallocate']) { // hidden param
    log('changing webgl: WEBGL_DELETE_TEXTURE_THRESHOLD:', 0); // report the value actually being set
    tf.env().set('WEBGL_DELETE_TEXTURE_THRESHOLD', 0);
  }
  if (tf.backend().getGPGPUContext) {
    const gl = tf.backend().getGPGPUContext().gl;
    if (instance.config.debug) log(`gl version: ${gl.getParameter(gl.VERSION)} renderer: ${gl.getParameter(gl.RENDERER)}`);
  }
}

/**
 * Resolves and activates the configured backend:
 * - applies runtime overrides
 * - registers humangl if it was requested
 * - falls back to `tensorflow` (nodejs) or `webgl` (otherwise) when the backend is not in the tfjs registry
 * - configures wasm if it was requested
 * - calls `tf.setBackend`
 * Returns false if setting the backend throws.
 * Propagates the error thrown by `configureWasm`.
 */
async function activateBackend(instance: Human): Promise<boolean> {
  await resolveBackendOverrides(instance);
  // check available backends
  if (instance.config.backend === 'humangl') await humangl.register(instance);
  const available = Object.keys(tf.engine().registryFactory);
  if (instance.config.debug) log('available backends:', available);
  if (!available.includes(instance.config.backend)) {
    log(`error: backend ${instance.config.backend} not found in registry`);
    instance.config.backend = env.node ? 'tensorflow' : 'webgl';
    if (instance.config.debug) log(`override: setting backend ${instance.config.backend}`);
  }
  if (instance.config.debug) log('setting backend:', instance.config.backend);
  if (instance.config.backend === 'wasm') await configureWasm(instance);
  try {
    const initialized = await tf.setBackend(instance.config.backend);
    // tf.setBackend signals a failed initialization by resolving to false rather than throwing;
    // the caller later records whichever backend tfjs actually ends up using, so just make the fallback visible
    if (!initialized) log('error: backend initialization failed:', instance.config.backend);
    await tf.ready();
    constants.init();
  } catch (err) {
    log('error: cannot set backend:', instance.config.backend, err);
    return false;
  }
  return true;
}

/**
 * Ensures the tfjs backend configured in `instance.config.backend` is initialized and tuned.
 *
 * Work happens only when `force` is set, on the initial call (`env.initial`), or when a non-empty configured
 * backend differs from the active tfjs backend. Otherwise the function returns true immediately,
 * after setting `instance.state`.
 *
 * When initialization runs:
 * - `instance.config.backend` is set to the backend tfjs actually uses
 * - `instance.performance.initBackend` holds the truncated elapsed time measured with `now()`
 * - `env` is refreshed
 * - fallback Mod/FloorMod kernels are registered if missing
 *
 * @param instance Human instance whose `state`, `config` and `performance` fields are read and updated
 * @param force re-run initialization even if the backend appears unchanged
 * @returns false if setting the backend throws, otherwise true
 * @throws Error if the wasm backend is requested but `tf.setWasmPaths` is not available
 */
export async function check(instance: Human, force = false): Promise<boolean> {
  instance.state = 'backend';
  const backendChangeRequested = instance.config.backend && (instance.config.backend.length > 0) && (tf.getBackend() !== instance.config.backend);
  if (!force && !env.initial && !backendChangeRequested) return true;

  const timeStamp = now();
  if (instance.config.backend && instance.config.backend.length > 0) {
    const activated = await activateBackend(instance);
    if (!activated) return false;
  }

  // customize humangl
  if (tf.getBackend() === 'humangl') configureHumanGL(instance);
  // customize webgpu: no flags are changed at the moment, candidates kept for reference
  // if (tf.env().flagRegistry['WEBGPU_CPU_HANDOFF_SIZE_THRESHOLD']) tf.env().set('WEBGPU_CPU_HANDOFF_SIZE_THRESHOLD', 512);
  // if (tf.env().flagRegistry['WEBGPU_DEFERRED_SUBMIT_BATCH_SIZE']) tf.env().set('WEBGPU_DEFERRED_SUBMIT_BATCH_SIZE', 0);
  // if (tf.env().flagRegistry['WEBGPU_CPU_FORWARD']) tf.env().set('WEBGPU_CPU_FORWARD', true);

  // wait for ready
  tf.enableProdMode();
  await tf.ready();

  instance.performance.initBackend = Math.trunc(now() - timeStamp);
  instance.config.backend = tf.getBackend();
  await env.updateBackend(); // update env on backend init
  registerCustomOps();
  return true;
}
2021-09-23 20:09:41 +02:00
/**
 * Registers placeholder ("fake") kernels for tfjs ops that are missing on the backend named by `config.backend`.
 *
 * A placeholder kernel performs no computation and returns `undefined`. When `config.debug` is set, it logs each call.
 * Kernels already registered for that backend are skipped: `tf.registerKernel` would otherwise overwrite
 * a real implementation with the placeholder. Afterwards, `env.kernels` is rebuilt as the lowercase names
 * of all kernels registered for the active tfjs backend.
 *
 * @param kernelNames kernel registry names to fill in, e.g. 'FloorMod' (matched case-sensitively)
 * @param config object providing `backend` (target backend name) and `debug` (enables call logging)
 */
export function fakeOps(kernelNames: Array<string>, config) {
  const registered = new Set(tf.getKernelsForBackend(config.backend).map((kernel) => kernel.kernelName));
  for (const kernelName of kernelNames) {
    if (registered.has(kernelName)) continue; // never replace an existing (real or previously faked) kernel
    tf.registerKernel({
      kernelName,
      backendName: config.backend,
      kernelFunc: () => { if (config.debug) log('kernelFunc', kernelName, config.backend); },
      // setupFunc: () => { if (config.debug) log('kernelFunc', kernelName, config.backend); },
      // disposeFunc: () => { if (config.debug) log('kernelFunc', kernelName, config.backend); },
    });
    registered.add(kernelName); // also guards against duplicates within kernelNames
  }
  env.kernels = tf.getKernelsForBackend(tf.getBackend()).map((kernel) => kernel.kernelName.toLowerCase()); // re-scan registered ops
}