2021-09-25 17:51:15 +02:00
/** TFJS backend initialization and customization */
2022-08-30 16:28:33 +02:00
import type { Human , Config } from '../human' ;
2021-09-27 19:58:13 +02:00
import { log , now } from '../util/util' ;
2021-10-21 16:26:44 +02:00
import { env } from '../util/env' ;
2021-09-13 00:37:06 +02:00
import * as humangl from './humangl' ;
2020-12-13 00:34:30 +01:00
import * as tf from '../../dist/tfjs.esm.js' ;
2021-11-17 02:16:49 +01:00
import * as constants from './constants' ;
2020-12-13 00:34:30 +01:00
2022-08-30 16:28:33 +02:00
/**
 * Registers software fallbacks for kernels that are not listed in `env.kernels` for the active backend
 *
 * - `Mod` and `FloorMod` are always registered when missing
 * - `RotateWithOffset` is registered only when `config.softwareKernels` is enabled
 *
 * Every registered kernel name is appended (lowercased) to `env.kernels` so repeated calls do not register it again
 *
 * @param config active configuration; reads `softwareKernels` and `debug`
 */
function registerCustomOps(config: Config) {
  const newKernels: string[] = [];
  if (!env.kernels.includes('mod')) {
    const kernelMod = {
      kernelName: 'Mod',
      backendName: tf.getBackend(),
      // NOTE(review): tf.div presumably floors only when both inputs are int32, so this looks like it assumes integer operands - confirm
      kernelFunc: (op) => tf.tidy(() => tf.sub(op.inputs.a, tf.mul(tf.div(op.inputs.a, op.inputs.b), op.inputs.b))),
    };
    tf.registerKernel(kernelMod);
    env.kernels.push('mod');
    newKernels.push('mod');
  }
  if (!env.kernels.includes('floormod')) {
    const kernelFloorMod = {
      kernelName: 'FloorMod',
      backendName: tf.getBackend(),
      // floored remainder by definition: a - floor(a / b) * b
      // floorDiv must receive both tensors as separate arguments; dividing tensor objects with the js `/` operator yields NaN
      kernelFunc: (op) => tf.tidy(() => tf.sub(op.inputs.a, tf.mul(tf.floorDiv(op.inputs.a, op.inputs.b), op.inputs.b))),
    };
    tf.registerKernel(kernelFloorMod);
    env.kernels.push('floormod');
    newKernels.push('floormod');
  }
  /*
  if (!env.kernels.includes('atan2') && config.softwareKernels) {
    const kernelAtan2 = {
      kernelName: 'Atan2',
      backendName: tf.getBackend(),
      kernelFunc: (op) => tf.tidy(() => {
        const backend = tf.getBackend();
        tf.setBackend('cpu');
        const t = tf.atan2(op.inputs.a, op.inputs.b);
        tf.setBackend(backend);
        return t;
      }),
    };
    if (config.debug) log('registered kernel:', 'atan2');
    log('registered kernel:', 'atan2');
    tf.registerKernel(kernelAtan2);
    env.kernels.push('atan2');
    newKernels.push('atan2');
  }
  */
  if (!env.kernels.includes('rotatewithoffset') && config.softwareKernels) {
    const kernelRotateWithOffset = {
      kernelName: 'RotateWithOffset',
      backendName: tf.getBackend(),
      kernelFunc: (op) => tf.tidy(() => {
        // run the op on the cpu backend, then switch back to whichever backend was active
        const backend = tf.getBackend();
        tf.setBackend('cpu');
        try {
          return tf.image.rotateWithOffset(op.inputs.image, op.attrs.radians, op.attrs.fillValue, op.attrs.center);
        } finally {
          tf.setBackend(backend); // restore the original backend even if the cpu op throws
        }
      }),
    };
    tf.registerKernel(kernelRotateWithOffset);
    env.kernels.push('rotatewithoffset');
    newKernels.push('rotatewithoffset');
  }
  if ((newKernels.length > 0) && config.debug) log('registered kernels:', newKernels);
}
2022-09-02 20:07:10 +02:00
// snapshot of tf.env().flags taken by check() right after the backend is set (debug mode only); check() later compares current flags against it to log which ones changed
let defaultFlags : Record < string , unknown > = { } ;
2021-11-12 21:07:23 +01:00
/** Switches a `webgpu` request to `webgl` when no gpu adapter can be obtained, otherwise logs the adapter info */
async function verifyWebGPU(instance: Human): Promise<void> {
  if (typeof navigator === 'undefined' || typeof navigator.gpu === 'undefined') {
    log('override: backend set to webgpu but browser does not support webgpu');
    instance.config.backend = 'webgl';
    return;
  }
  const adapter = await navigator.gpu.requestAdapter();
  if (instance.config.debug) log('enumerated webgpu adapter:', adapter);
  if (!adapter) {
    log('override: backend set to webgpu but browser reports no available gpu');
    instance.config.backend = 'webgl';
    return;
  }
  // @ts-ignore requestAdapterInfo is not in tslib
  const adapterInfo = 'requestAdapterInfo' in adapter ? await (adapter as GPUAdapter).requestAdapterInfo() : undefined;
  // if (adapter.features) adapter.features.forEach((feature) => log('webgpu features:', feature));
  log('webgpu adapter info:', adapterInfo);
}

/** Applies wasm-specific flags and binary paths, then probes and logs simd/multithread support */
async function configureWasm(instance: Human): Promise<void> {
  if (tf.env().flagRegistry.CANVAS2D_WILL_READ_FREQUENTLY) tf.env().set('CANVAS2D_WILL_READ_FREQUENTLY', true);
  if (instance.config.debug) log('wasm path:', instance.config.wasmPath);
  if (typeof tf.setWasmPaths === 'undefined') throw new Error('backend error: attempting to use wasm backend but wasm path is not set');
  tf.setWasmPaths(instance.config.wasmPath, instance.config.wasmPlatformFetch);
  let mt = false;
  let simd = false;
  try {
    mt = await tf.env().getAsync('WASM_HAS_MULTITHREAD_SUPPORT');
    simd = await tf.env().getAsync('WASM_HAS_SIMD_SUPPORT');
    if (instance.config.debug) log(`wasm execution: ${simd ? 'simd' : 'no simd'} ${mt ? 'multithreaded' : 'singlethreaded'}`);
    if (instance.config.debug && !simd) log('warning: wasm simd support is not enabled');
  } catch {
    // capability detection is best-effort: report and continue with backend setup
    log('wasm detection failed');
  }
}

/** Logs every tfjs flag whose current value differs from the snapshot stored in `defaultFlags` */
function logUpdatedFlags(): void {
  const newFlags = tf.env().flags;
  const updatedFlags: Record<string, unknown> = {};
  for (const key of Object.keys(newFlags)) {
    if (defaultFlags[key] === newFlags[key]) continue;
    updatedFlags[key] = newFlags[key];
  }
  if (Object.keys(updatedFlags).length > 0) log('backend:', tf.getBackend(), 'flags:', updatedFlags);
}

/**
 * Validates the configured backend and (re)initializes tfjs with it
 *
 * Initialization runs when `force` is set, when `env.initial` is still true,
 * or when a backend is configured and differs from the currently active tfjs backend.
 * Side effects: sets `instance.state` to `'backend'`, may rewrite `instance.config.backend`,
 * records `instance.performance.initBackend`, registers custom kernels and clears `env.initial`.
 *
 * @param instance Human instance whose `config` selects the backend
 * @param force run initialization even when the active backend already matches
 * @returns `false` when `tf.setBackend` or `tf.ready` throws, `true` otherwise (including when no initialization was needed)
 */
export async function check(instance: Human, force = false) {
  instance.state = 'backend';
  const needsInit = force || env.initial || (instance.config.backend && (instance.config.backend.length > 0) && (tf.getBackend() !== instance.config.backend));
  if (!needsInit) return true;

  const timeStamp = now();

  if (instance.config.backend && instance.config.backend.length > 0) {
    // detect web worker
    // @ts-ignore ignore missing type for WorkerGlobalScope as that is the point
    if (typeof window === 'undefined' && typeof WorkerGlobalScope !== 'undefined' && instance.config.debug) {
      log('running inside web worker');
    }

    // force browser vs node backend
    if (env.browser && instance.config.backend === 'tensorflow') {
      if (instance.config.debug) log('override: backend set to tensorflow while running in browser');
      instance.config.backend = 'webgl';
    }
    if (env.node && (instance.config.backend === 'webgl' || instance.config.backend === 'humangl')) {
      if (instance.config.debug) log(`override: backend set to ${instance.config.backend} while running in nodejs`);
      instance.config.backend = 'tensorflow';
    }

    // handle webgpu
    if (env.browser && instance.config.backend === 'webgpu') await verifyWebGPU(instance);

    // check available backends
    let available = Object.keys(tf.engine().registryFactory as Record<string, unknown>);
    if (instance.config.backend === 'humangl' && !available.includes('humangl')) {
      humangl.register(instance);
      available = Object.keys(tf.engine().registryFactory as Record<string, unknown>); // re-read registry after registering humangl
    }
    if (instance.config.debug) log('available backends:', available);
    if (!available.includes(instance.config.backend)) {
      log(`error: backend ${instance.config.backend} not found in registry`);
      instance.config.backend = env.node ? 'tensorflow' : 'webgl';
      if (instance.config.debug) log(`override: setting backend ${instance.config.backend}`);
    }
    if (instance.config.debug) log('setting backend:', [instance.config.backend]);

    // customize wasm
    if (instance.config.backend === 'wasm') await configureWasm(instance);

    try {
      await tf.setBackend(instance.config.backend);
      await tf.ready();
    } catch (err) {
      log('error: cannot set backend:', instance.config.backend, err);
      return false;
    }
    // deep-copy the flags so later tf.env().set() calls do not alter the snapshot
    if (instance.config.debug) defaultFlags = JSON.parse(JSON.stringify(tf.env().flags));
  }

  // customize humangl
  if (tf.getBackend() === 'humangl' || tf.getBackend() === 'webgl') {
    if (tf.env().flagRegistry.WEBGL_USE_SHAPES_UNIFORMS) tf.env().set('WEBGL_USE_SHAPES_UNIFORMS', true); // default=false <https://github.com/tensorflow/tfjs/issues/5205>
    if (tf.env().flagRegistry.WEBGL_EXP_CONV) tf.env().set('WEBGL_EXP_CONV', true); // default=false <https://github.com/tensorflow/tfjs/issues/6678>
    // if (tf.env().flagRegistry['WEBGL_PACK_DEPTHWISECONV']) tf.env().set('WEBGL_PACK_DEPTHWISECONV', false); // default=true <https://github.com/tensorflow/tfjs/pull/4909>
    // if (tf.env().flagRegistry.USE_SETTIMEOUTCUSTOM) tf.env().set('USE_SETTIMEOUTCUSTOM', true); // default=false <https://github.com/tensorflow/tfjs/issues/6687>
    // if (tf.env().flagRegistry.CPU_HANDOFF_SIZE_THRESHOLD) tf.env().set('CPU_HANDOFF_SIZE_THRESHOLD', 1024); // default=1000
    // if (tf.env().flagRegistry['WEBGL_FORCE_F16_TEXTURES'] && !instance.config.object.enabled) tf.env().set('WEBGL_FORCE_F16_TEXTURES', true); // safe to use 16bit precision
    if (typeof instance.config.deallocate !== 'undefined' && instance.config.deallocate) { // hidden param
      const deleteTextureThreshold = 0;
      log('changing webgl: WEBGL_DELETE_TEXTURE_THRESHOLD:', deleteTextureThreshold); // log the value that is actually applied
      tf.env().set('WEBGL_DELETE_TEXTURE_THRESHOLD', deleteTextureThreshold);
    }
  }

  // customize webgpu
  if (tf.getBackend() === 'webgpu') {
    // if (tf.env().flagRegistry['WEBGPU_CPU_HANDOFF_SIZE_THRESHOLD']) tf.env().set('WEBGPU_CPU_HANDOFF_SIZE_THRESHOLD', 512);
    // if (tf.env().flagRegistry['WEBGPU_DEFERRED_SUBMIT_BATCH_SIZE']) tf.env().set('WEBGPU_DEFERRED_SUBMIT_BATCH_SIZE', 0);
    // if (tf.env().flagRegistry['WEBGPU_CPU_FORWARD']) tf.env().set('WEBGPU_CPU_FORWARD', true);
  }

  if (instance.config.debug) logUpdatedFlags();

  tf.enableProdMode();
  constants.init();
  instance.performance.initBackend = Math.trunc(now() - timeStamp);
  instance.config.backend = tf.getBackend(); // store the backend tfjs actually activated
  await env.updateBackend(); // update env on backend init
  registerCustomOps(instance.config);
  // await env.updateBackend(); // update env on backend init
  env.initial = false;
  return true;
}
2021-09-23 20:09:41 +02:00
/**
 * Register fake (placeholder) kernels for missing tfjs ops
 *
 * Each placeholder is registered for `config.backend`; when invoked it only logs (in debug mode) and returns nothing.
 * Afterwards `env.kernels` is rebuilt from the kernels registered for the currently active tfjs backend.
 *
 * @param kernelNames names of the kernels to register placeholders for
 * @param config configuration object; reads `backend` and `debug`
 */
export function fakeOps(kernelNames: string[], config) {
  // if (config.debug) log('registerKernel:', kernelNames);
  const registerPlaceholder = (kernelName: string) => {
    tf.registerKernel({
      kernelName,
      backendName: config.backend,
      kernelFunc: () => { if (config.debug) log('kernelFunc', kernelName, config.backend); },
      // setupFunc: () => { if (config.debug) log('kernelFunc', kernelName, config.backend); },
      // disposeFunc: () => { if (config.debug) log('kernelFunc', kernelName, config.backend); },
    });
  };
  kernelNames.forEach((name) => registerPlaceholder(name));
  // re-scan registered ops
  const registered = tf.getKernelsForBackend(tf.getBackend());
  env.kernels = registered.map((kernel) => (kernel.kernelName as string).toLowerCase());
}