From c5911301e9e4ef07a2071a73f526f54b19d91f71 Mon Sep 17 00:00:00 2001 From: Vladimir Mandic Date: Sun, 16 Jan 2022 09:49:55 -0500 Subject: [PATCH] prototype global fetch handler --- demo/nodejs/node-fetch.js | 25 ++++++++++++++++++++++ src/body/blazepose.ts | 5 +++-- src/body/efficientpose.ts | 3 ++- src/body/movenet.ts | 3 ++- src/body/posenet.ts | 3 ++- src/face/antispoof.ts | 3 ++- src/face/blazeface.ts | 3 ++- src/face/facemesh.ts | 3 ++- src/face/faceres.ts | 3 ++- src/face/iris.ts | 3 ++- src/face/liveness.ts | 2 +- src/face/mobilefacenet.ts | 3 ++- src/gear/emotion.ts | 3 ++- src/gear/gear.ts | 3 ++- src/gear/ssrnet-age.ts | 3 ++- src/gear/ssrnet-gender.ts | 3 ++- src/hand/handpose.ts | 6 +++--- src/hand/handtrack.ts | 5 +++-- src/object/centernet.ts | 3 ++- src/object/nanodet.ts | 3 ++- src/tfjs/load.ts | 45 +++++++++++++++++++++++++++++++++++++++ 21 files changed, 110 insertions(+), 23 deletions(-) create mode 100644 demo/nodejs/node-fetch.js create mode 100644 src/tfjs/load.ts diff --git a/demo/nodejs/node-fetch.js b/demo/nodejs/node-fetch.js new file mode 100644 index 00000000..9c78d92f --- /dev/null +++ b/demo/nodejs/node-fetch.js @@ -0,0 +1,25 @@ +const fs = require('fs'); + +// eslint-disable-next-line import/no-extraneous-dependencies, no-unused-vars, @typescript-eslint/no-unused-vars +const tf = require('@tensorflow/tfjs-node'); // in nodejs environments tfjs-node is required to be loaded before human +// const Human = require('@vladmandic/human'); // use this when human is installed as module (majority of use cases) +const Human = require('../../dist/human.node.js'); // use this when using human in dev mode +
+const humanConfig = { +  modelBasePath: 'https://vladmandic.github.io/human/models/', +}; +
+async function main(inputFile) { +  // @ts-ignore +  global.fetch = (await import('node-fetch')).default; +  const human = new Human.Human(humanConfig); // create instance of human using default configuration +  await human.load(); // optional as
models would be loaded on-demand first time they are required +  await human.warmup(); // optional as model warmup is performed on-demand first time it's executed +  const buffer = fs.readFileSync(inputFile); // read file data into buffer +  const tensor = human.tf.node.decodeImage(buffer); // decode jpg data +  const result = await human.detect(tensor); // run detection; will initialize backend and on-demand load models +  // eslint-disable-next-line no-console +  console.log(result.gesture); +} +
+main('samples/in/ai-body.jpg'); diff --git a/src/body/blazepose.ts b/src/body/blazepose.ts index 399213e2..390e3e3a 100644 --- a/src/body/blazepose.ts +++ b/src/body/blazepose.ts @@ -3,6 +3,7 @@ */ import * as tf from '../../dist/tfjs.esm.js'; +import { loadModel } from '../tfjs/load'; import { constants } from '../tfjs/constants'; import { log, join, now } from '../util/util'; import type { BodyKeypoint, BodyResult, BodyLandmark, Box, Point, BodyAnnotation } from '../result'; @@ -32,7 +33,7 @@ const sigmoid = (x) => (1 - (1 / (1 + Math.exp(x)))); export async function loadDetect(config: Config): Promise { if (env.initial) models.detector = null; if (!models.detector && config.body['detector'] && config.body['detector']['modelPath'] || '') { - models.detector = await tf.loadGraphModel(join(config.modelBasePath, config.body['detector']['modelPath'] || '')) as unknown as GraphModel; + models.detector = await loadModel(join(config.modelBasePath, config.body['detector']['modelPath'] || '')) as unknown as GraphModel; const inputs = Object.values(models.detector.modelSignature['inputs']); inputSize.detector[0] = Array.isArray(inputs) ? parseInt(inputs[0].tensorShape.dim[1].size) : 0; inputSize.detector[1] = Array.isArray(inputs) ?
parseInt(inputs[0].tensorShape.dim[2].size) : 0; @@ -46,7 +47,7 @@ export async function loadDetect(config: Config): Promise { export async function loadPose(config: Config): Promise { if (env.initial) models.landmarks = null; if (!models.landmarks) { - models.landmarks = await tf.loadGraphModel(join(config.modelBasePath, config.body.modelPath || '')) as unknown as GraphModel; + models.landmarks = await loadModel(join(config.modelBasePath, config.body.modelPath || '')) as unknown as GraphModel; const inputs = Object.values(models.landmarks.modelSignature['inputs']); inputSize.landmarks[0] = Array.isArray(inputs) ? parseInt(inputs[0].tensorShape.dim[1].size) : 0; inputSize.landmarks[1] = Array.isArray(inputs) ? parseInt(inputs[0].tensorShape.dim[2].size) : 0; diff --git a/src/body/efficientpose.ts b/src/body/efficientpose.ts index 2ddd946e..cd7f8cb7 100644 --- a/src/body/efficientpose.ts +++ b/src/body/efficientpose.ts @@ -6,6 +6,7 @@ import { log, join, now } from '../util/util'; import * as tf from '../../dist/tfjs.esm.js'; +import { loadModel } from '../tfjs/load'; import * as coords from './efficientposecoords'; import { constants } from '../tfjs/constants'; import type { BodyResult, Point, BodyLandmark, BodyAnnotation } from '../result'; @@ -26,7 +27,7 @@ let skipped = Number.MAX_SAFE_INTEGER; export async function load(config: Config): Promise { if (env.initial) model = null; if (!model) { - model = await tf.loadGraphModel(join(config.modelBasePath, config.body.modelPath || '')) as unknown as GraphModel; + model = await loadModel(join(config.modelBasePath, config.body.modelPath || '')) as unknown as GraphModel; if (!model || !model['modelUrl']) log('load model failed:', config.body.modelPath); else if (config.debug) log('load model:', model['modelUrl']); } else if (config.debug) log('cached model:', model['modelUrl']); diff --git a/src/body/movenet.ts b/src/body/movenet.ts index e8ab67cf..95879522 100644 --- a/src/body/movenet.ts +++ b/src/body/movenet.ts @@ 
-9,6 +9,7 @@ import * as box from '../util/box'; import * as tf from '../../dist/tfjs.esm.js'; import * as coords from './movenetcoords'; import * as fix from './movenetfix'; +import { loadModel } from '../tfjs/load'; import type { BodyKeypoint, BodyResult, BodyLandmark, BodyAnnotation, Box, Point } from '../result'; import type { GraphModel, Tensor } from '../tfjs/types'; import type { Config } from '../config'; @@ -34,7 +35,7 @@ export async function load(config: Config): Promise { if (env.initial) model = null; if (!model) { fakeOps(['size'], config); - model = await tf.loadGraphModel(join(config.modelBasePath, config.body.modelPath || '')) as unknown as GraphModel; + model = await loadModel(join(config.modelBasePath, config.body.modelPath || '')) as unknown as GraphModel; if (!model || !model['modelUrl']) log('load model failed:', config.body.modelPath); else if (config.debug) log('load model:', model['modelUrl']); } else if (config.debug) log('cached model:', model['modelUrl']); diff --git a/src/body/posenet.ts b/src/body/posenet.ts index 1b30a042..0945af06 100644 --- a/src/body/posenet.ts +++ b/src/body/posenet.ts @@ -6,6 +6,7 @@ import { log, join } from '../util/util'; import * as tf from '../../dist/tfjs.esm.js'; +import { loadModel } from '../tfjs/load'; import type { BodyResult, BodyLandmark, Box } from '../result'; import type { Tensor, GraphModel } from '../tfjs/types'; import type { Config } from '../config'; @@ -179,7 +180,7 @@ export async function predict(input: Tensor, config: Config): Promise { if (!model || env.initial) { - model = await tf.loadGraphModel(join(config.modelBasePath, config.body.modelPath || '')) as unknown as GraphModel; + model = await loadModel(join(config.modelBasePath, config.body.modelPath || '')) as unknown as GraphModel; if (!model || !model['modelUrl']) log('load model failed:', config.body.modelPath); else if (config.debug) log('load model:', model['modelUrl']); } else if (config.debug) log('cached model:', 
model['modelUrl']); diff --git a/src/face/antispoof.ts b/src/face/antispoof.ts index 658392b0..c7cd945a 100644 --- a/src/face/antispoof.ts +++ b/src/face/antispoof.ts @@ -6,6 +6,7 @@ import { log, join, now } from '../util/util'; import type { Config } from '../config'; import type { GraphModel, Tensor } from '../tfjs/types'; import * as tf from '../../dist/tfjs.esm.js'; +import { loadModel } from '../tfjs/load'; import { env } from '../util/env'; let model: GraphModel | null; @@ -17,7 +18,7 @@ let lastTime = 0; export async function load(config: Config): Promise { if (env.initial) model = null; if (!model) { - model = await tf.loadGraphModel(join(config.modelBasePath, config.face.antispoof?.modelPath || '')) as unknown as GraphModel; + model = await loadModel(join(config.modelBasePath, config.face.antispoof?.modelPath || '')) as unknown as GraphModel; if (!model || !model['modelUrl']) log('load model failed:', config.face.antispoof?.modelPath); else if (config.debug) log('load model:', model['modelUrl']); } else if (config.debug) log('cached model:', model['modelUrl']); diff --git a/src/face/blazeface.ts b/src/face/blazeface.ts index 718d7606..f3ccff48 100644 --- a/src/face/blazeface.ts +++ b/src/face/blazeface.ts @@ -6,6 +6,7 @@ import { log, join } from '../util/util'; import * as tf from '../../dist/tfjs.esm.js'; import * as util from './facemeshutil'; +import { loadModel } from '../tfjs/load'; import { constants } from '../tfjs/constants'; import type { Config } from '../config'; import type { Tensor, GraphModel } from '../tfjs/types'; @@ -26,7 +27,7 @@ export const size = () => inputSize; export async function load(config: Config): Promise { if (env.initial) model = null; if (!model) { - model = await tf.loadGraphModel(join(config.modelBasePath, config.face.detector?.modelPath || '')) as unknown as GraphModel; + model = await loadModel(join(config.modelBasePath, config.face.detector?.modelPath || '')) as unknown as GraphModel; if (!model || 
!model['modelUrl']) log('load model failed:', config.face.detector?.modelPath); else if (config.debug) log('load model:', model['modelUrl']); } else if (config.debug) log('cached model:', model['modelUrl']); diff --git a/src/face/facemesh.ts b/src/face/facemesh.ts index e37c6cd1..f55091af 100644 --- a/src/face/facemesh.ts +++ b/src/face/facemesh.ts @@ -8,6 +8,7 @@ */ import { log, join, now } from '../util/util'; +import { loadModel } from '../tfjs/load'; import * as tf from '../../dist/tfjs.esm.js'; import * as blazeface from './blazeface'; import * as util from './facemeshutil'; @@ -111,7 +112,7 @@ export async function predict(input: Tensor, config: Config): Promise { if (env.initial) model = null; if (!model) { - model = await tf.loadGraphModel(join(config.modelBasePath, config.face.mesh?.modelPath || '')) as unknown as GraphModel; + model = await loadModel(join(config.modelBasePath, config.face.mesh?.modelPath || '')) as unknown as GraphModel; if (!model || !model['modelUrl']) log('load model failed:', config.face.mesh?.modelPath); else if (config.debug) log('load model:', model['modelUrl']); } else if (config.debug) log('cached model:', model['modelUrl']); diff --git a/src/face/faceres.ts b/src/face/faceres.ts index fccf81a4..80753174 100644 --- a/src/face/faceres.ts +++ b/src/face/faceres.ts @@ -10,6 +10,7 @@ import { log, join, now } from '../util/util'; import { env } from '../util/env'; import * as tf from '../../dist/tfjs.esm.js'; +import { loadModel } from '../tfjs/load'; import { constants } from '../tfjs/constants'; import type { Tensor, GraphModel } from '../tfjs/types'; import type { Config } from '../config'; @@ -33,7 +34,7 @@ export async function load(config: Config): Promise { const modelUrl = join(config.modelBasePath, config.face.description?.modelPath || ''); if (env.initial) model = null; if (!model) { - model = await tf.loadGraphModel(modelUrl) as unknown as GraphModel; + model = await loadModel(modelUrl) as unknown as GraphModel; if 
(!model) log('load model failed:', config.face.description?.modelPath || ''); else if (config.debug) log('load model:', modelUrl); } else if (config.debug) log('cached model:', modelUrl); diff --git a/src/face/iris.ts b/src/face/iris.ts index 31e981f6..7bdec347 100644 --- a/src/face/iris.ts +++ b/src/face/iris.ts @@ -4,6 +4,7 @@ import * as tf from '../../dist/tfjs.esm.js'; import type { Tensor, GraphModel } from '../tfjs/types'; import { env } from '../util/env'; import { log, join } from '../util/util'; +import { loadModel } from '../tfjs/load'; import type { Config } from '../config'; import type { Point } from '../result'; @@ -30,7 +31,7 @@ const irisLandmarks = { export async function load(config: Config): Promise { if (env.initial) model = null; if (!model) { - model = await tf.loadGraphModel(join(config.modelBasePath, config.face.iris?.modelPath || '')) as unknown as GraphModel; + model = await loadModel(join(config.modelBasePath, config.face.iris?.modelPath || '')); if (!model || !model['modelUrl']) log('load model failed:', config.face.iris?.modelPath); else if (config.debug) log('load model:', model['modelUrl']); } else if (config.debug) log('cached model:', model['modelUrl']); diff --git a/src/face/liveness.ts b/src/face/liveness.ts index 49ba3be9..d558c708 100644 --- a/src/face/liveness.ts +++ b/src/face/liveness.ts @@ -17,7 +17,7 @@ let lastTime = 0; export async function load(config: Config): Promise { if (env.initial) model = null; if (!model) { - model = await tf.loadGraphModel(join(config.modelBasePath, config.face.liveness?.modelPath || '')) as unknown as GraphModel; + model = await loadModel(join(config.modelBasePath, config.face.liveness?.modelPath || '')) as unknown as GraphModel; if (!model || !model['modelUrl']) log('load model failed:', config.face.liveness?.modelPath); else if (config.debug) log('load model:', model['modelUrl']); } else if (config.debug) log('cached model:', model['modelUrl']); diff --git a/src/face/mobilefacenet.ts 
b/src/face/mobilefacenet.ts index 2f460bce..068baf15 100644 --- a/src/face/mobilefacenet.ts +++ b/src/face/mobilefacenet.ts @@ -8,6 +8,7 @@ import { log, join, now } from '../util/util'; import * as tf from '../../dist/tfjs.esm.js'; +import { loadModel } from '../tfjs/load'; import type { Tensor, GraphModel } from '../tfjs/types'; import type { Config } from '../config'; import { env } from '../util/env'; @@ -22,7 +23,7 @@ export async function load(config: Config): Promise { const modelUrl = join(config.modelBasePath, config.face['mobilefacenet'].modelPath); if (env.initial) model = null; if (!model) { - model = await tf.loadGraphModel(modelUrl) as unknown as GraphModel; + model = await loadModel(modelUrl) as unknown as GraphModel; if (!model) log('load model failed:', config.face['mobilefacenet'].modelPath); else if (config.debug) log('load model:', modelUrl); } else if (config.debug) log('cached model:', modelUrl); diff --git a/src/gear/emotion.ts b/src/gear/emotion.ts index 4992b545..2219db6a 100644 --- a/src/gear/emotion.ts +++ b/src/gear/emotion.ts @@ -9,6 +9,7 @@ import { log, join, now } from '../util/util'; import type { Config } from '../config'; import type { GraphModel, Tensor } from '../tfjs/types'; import * as tf from '../../dist/tfjs.esm.js'; +import { loadModel } from '../tfjs/load'; import { env } from '../util/env'; import { constants } from '../tfjs/constants'; @@ -22,7 +23,7 @@ let skipped = Number.MAX_SAFE_INTEGER; export async function load(config: Config): Promise { if (env.initial) model = null; if (!model) { - model = await tf.loadGraphModel(join(config.modelBasePath, config.face.emotion?.modelPath || '')) as unknown as GraphModel; + model = await loadModel(join(config.modelBasePath, config.face.emotion?.modelPath || '')) as unknown as GraphModel; if (!model || !model['modelUrl']) log('load model failed:', config.face.emotion?.modelPath); else if (config.debug) log('load model:', model['modelUrl']); } else if (config.debug) log('cached 
model:', model['modelUrl']); diff --git a/src/gear/gear.ts b/src/gear/gear.ts index 8bc36f29..11396f3e 100644 --- a/src/gear/gear.ts +++ b/src/gear/gear.ts @@ -6,6 +6,7 @@ import { log, join, now } from '../util/util'; import * as tf from '../../dist/tfjs.esm.js'; +import { loadModel } from '../tfjs/load'; import type { Gender, Race } from '../result'; import type { Config } from '../config'; import type { GraphModel, Tensor } from '../tfjs/types'; @@ -24,7 +25,7 @@ let skipped = Number.MAX_SAFE_INTEGER; export async function load(config: Config) { if (env.initial) model = null; if (!model) { - model = await tf.loadGraphModel(join(config.modelBasePath, config.face['gear'].modelPath)) as unknown as GraphModel; + model = await loadModel(join(config.modelBasePath, config.face['gear'].modelPath)) as unknown as GraphModel; if (!model || !model['modelUrl']) log('load model failed:', config.face['gear'].modelPath); else if (config.debug) log('load model:', model['modelUrl']); } else if (config.debug) log('cached model:', model['modelUrl']); diff --git a/src/gear/ssrnet-age.ts b/src/gear/ssrnet-age.ts index 96bb998f..3b675b5d 100644 --- a/src/gear/ssrnet-age.ts +++ b/src/gear/ssrnet-age.ts @@ -6,6 +6,7 @@ import { log, join, now } from '../util/util'; import * as tf from '../../dist/tfjs.esm.js'; +import { loadModel } from '../tfjs/load'; import { env } from '../util/env'; import { constants } from '../tfjs/constants'; import type { Config } from '../config'; @@ -21,7 +22,7 @@ let skipped = Number.MAX_SAFE_INTEGER; export async function load(config: Config) { if (env.initial) model = null; if (!model) { - model = await tf.loadGraphModel(join(config.modelBasePath, config.face['ssrnet'].modelPathAge)) as unknown as GraphModel; + model = await loadModel(join(config.modelBasePath, config.face['ssrnet'].modelPathAge)) as unknown as GraphModel; if (!model || !model['modelUrl']) log('load model failed:', config.face['ssrnet'].modelPathAge); else if (config.debug) log('load 
model:', model['modelUrl']); } else { diff --git a/src/gear/ssrnet-gender.ts b/src/gear/ssrnet-gender.ts index 7cd38fb8..e7c1c82b 100644 --- a/src/gear/ssrnet-gender.ts +++ b/src/gear/ssrnet-gender.ts @@ -6,6 +6,7 @@ import { log, join, now } from '../util/util'; import * as tf from '../../dist/tfjs.esm.js'; +import { loadModel } from '../tfjs/load'; import { constants } from '../tfjs/constants'; import type { Gender } from '../result'; import type { Config } from '../config'; @@ -25,7 +26,7 @@ const rgb = [0.2989, 0.5870, 0.1140]; // factors for red/green/blue colors when export async function load(config: Config | any) { if (env.initial) model = null; if (!model) { - model = await tf.loadGraphModel(join(config.modelBasePath, config.face['ssrnet'].modelPathGender)) as unknown as GraphModel; + model = await loadModel(join(config.modelBasePath, config.face['ssrnet'].modelPathGender)) as unknown as GraphModel; if (!model || !model['modelUrl']) log('load model failed:', config.face['ssrnet'].modelPathGender); else if (config.debug) log('load model:', model['modelUrl']); } else if (config.debug) log('cached model:', model['modelUrl']); diff --git a/src/hand/handpose.ts b/src/hand/handpose.ts index 6a6f94b0..719c1f7c 100644 --- a/src/hand/handpose.ts +++ b/src/hand/handpose.ts @@ -5,10 +5,10 @@ */ import { log, join } from '../util/util'; -import * as tf from '../../dist/tfjs.esm.js'; import * as handdetector from './handposedetector'; import * as handpipeline from './handposepipeline'; import * as fingerPose from './fingerpose'; +import { loadModel } from '../tfjs/load'; import type { HandResult, Box, Point } from '../result'; import type { Tensor, GraphModel } from '../tfjs/types'; import type { Config } from '../config'; @@ -89,8 +89,8 @@ export async function load(config: Config): Promise<[GraphModel | null, GraphMod } if (!handDetectorModel || !handPoseModel) { [handDetectorModel, handPoseModel] = await Promise.all([ - config.hand.enabled ? 
tf.loadGraphModel(join(config.modelBasePath, config.hand.detector?.modelPath || ''), { fromTFHub: (config.hand.detector?.modelPath || '').includes('tfhub.dev') }) as unknown as GraphModel : null, - config.hand.landmarks ? tf.loadGraphModel(join(config.modelBasePath, config.hand.skeleton?.modelPath || ''), { fromTFHub: (config.hand.skeleton?.modelPath || '').includes('tfhub.dev') }) as unknown as GraphModel : null, + config.hand.enabled ? loadModel(join(config.modelBasePath, config.hand.detector?.modelPath || '')) as unknown as GraphModel : null, + config.hand.landmarks ? loadModel(join(config.modelBasePath, config.hand.skeleton?.modelPath || '')) as unknown as GraphModel : null, ]); if (config.hand.enabled) { if (!handDetectorModel || !handDetectorModel['modelUrl']) log('load model failed:', config.hand.detector?.modelPath || ''); diff --git a/src/hand/handtrack.ts b/src/hand/handtrack.ts index f6621e66..0734d4d2 100644 --- a/src/hand/handtrack.ts +++ b/src/hand/handtrack.ts @@ -9,6 +9,7 @@ import { log, join, now } from '../util/util'; import * as box from '../util/box'; import * as tf from '../../dist/tfjs.esm.js'; +import { loadModel } from '../tfjs/load'; import type { HandResult, HandType, Box, Point } from '../result'; import type { GraphModel, Tensor } from '../tfjs/types'; import type { Config } from '../config'; @@ -74,7 +75,7 @@ export async function loadDetect(config: Config): Promise { // handtrack model has some kernel ops defined in model but those are never referenced and non-existent in tfjs // ideally need to prune the model itself fakeOps(['tensorlistreserve', 'enter', 'tensorlistfromtensor', 'merge', 'loopcond', 'switch', 'exit', 'tensorliststack', 'nextiteration', 'tensorlistsetitem', 'tensorlistgetitem', 'reciprocal', 'shape', 'split', 'where'], config); - models[0] = await tf.loadGraphModel(join(config.modelBasePath, config.hand.detector?.modelPath || '')) as unknown as GraphModel; + models[0] = await loadModel(join(config.modelBasePath, 
config.hand.detector?.modelPath || '')) as unknown as GraphModel; const inputs = Object.values(models[0].modelSignature['inputs']); inputSize[0][0] = Array.isArray(inputs) ? parseInt(inputs[0].tensorShape.dim[1].size) : 0; inputSize[0][1] = Array.isArray(inputs) ? parseInt(inputs[0].tensorShape.dim[2].size) : 0; @@ -87,7 +88,7 @@ export async function loadDetect(config: Config): Promise { export async function loadSkeleton(config: Config): Promise { if (env.initial) models[1] = null; if (!models[1]) { - models[1] = await tf.loadGraphModel(join(config.modelBasePath, config.hand.skeleton?.modelPath || '')) as unknown as GraphModel; + models[1] = await loadModel(join(config.modelBasePath, config.hand.skeleton?.modelPath || '')) as unknown as GraphModel; const inputs = Object.values(models[1].modelSignature['inputs']); inputSize[1][0] = Array.isArray(inputs) ? parseInt(inputs[0].tensorShape.dim[1].size) : 0; inputSize[1][1] = Array.isArray(inputs) ? parseInt(inputs[0].tensorShape.dim[2].size) : 0; diff --git a/src/object/centernet.ts b/src/object/centernet.ts index c0858bd8..15497bef 100644 --- a/src/object/centernet.ts +++ b/src/object/centernet.ts @@ -6,6 +6,7 @@ import { log, join, now } from '../util/util'; import * as tf from '../../dist/tfjs.esm.js'; +import { loadModel } from '../tfjs/load'; import { labels } from './labels'; import type { ObjectResult, ObjectType, Box } from '../result'; import type { GraphModel, Tensor } from '../tfjs/types'; @@ -22,7 +23,7 @@ export async function load(config: Config): Promise { if (env.initial) model = null; if (!model) { // fakeOps(['floormod'], config); - model = await tf.loadGraphModel(join(config.modelBasePath, config.object.modelPath || '')) as unknown as GraphModel; + model = await loadModel(join(config.modelBasePath, config.object.modelPath || '')) as unknown as GraphModel; const inputs = Object.values(model.modelSignature['inputs']); inputSize = Array.isArray(inputs) ? 
parseInt(inputs[0].tensorShape.dim[2].size) : 0; if (!model || !model['modelUrl']) log('load model failed:', config.object.modelPath); diff --git a/src/object/nanodet.ts b/src/object/nanodet.ts index 6882b5c9..b6659eb9 100644 --- a/src/object/nanodet.ts +++ b/src/object/nanodet.ts @@ -6,6 +6,7 @@ import { log, join, now } from '../util/util'; import * as tf from '../../dist/tfjs.esm.js'; +import { loadModel } from '../tfjs/load'; import { constants } from '../tfjs/constants'; import { labels } from './labels'; import type { ObjectResult, ObjectType, Box } from '../result'; @@ -23,7 +24,7 @@ const scaleBox = 2.5; // increase box size export async function load(config: Config): Promise { if (!model || env.initial) { - model = await tf.loadGraphModel(join(config.modelBasePath, config.object.modelPath || '')) as unknown as GraphModel; + model = await loadModel(join(config.modelBasePath, config.object.modelPath || '')) as unknown as GraphModel; const inputs = Object.values(model.modelSignature['inputs']); inputSize = Array.isArray(inputs) ? 
parseInt(inputs[0].tensorShape.dim[2].size) : 0; if (!model || !model['modelUrl']) log('load model failed:', config.object.modelPath); diff --git a/src/tfjs/load.ts b/src/tfjs/load.ts new file mode 100644 index 00000000..c08f0405 --- /dev/null +++ b/src/tfjs/load.ts @@ -0,0 +1,45 @@ +import { log, mergeDeep } from '../util/util'; +import * as tf from '../../dist/tfjs.esm.js'; +import type { GraphModel } from './types'; + +type FetchFunc = (url: RequestInfo, init?: RequestInit) => Promise; +type ProgressFunc = (...args) => void; + +export type LoadOptions = { + appName: string, + autoSave: boolean, + verbose: boolean, + fetchFunc?: FetchFunc, + onProgress?: ProgressFunc, +} + +let options: LoadOptions = { + appName: 'human', + autoSave: true, + verbose: true, +}; + +async function httpHandler(url: RequestInfo, init?: RequestInit): Promise { + if (options.fetchFunc) return options.fetchFunc(url, init); + else log('error: fetch function is not defined'); + return null; +} + +const tfLoadOptions = { + onProgress: (...args) => { + if (options.onProgress) options.onProgress(...args); + else if (options.verbose) log('load model progress:', ...args); + }, + fetchFunc: (url: RequestInfo, init?: RequestInit) => { + if (options.verbose) log('load model fetch:', url, init); + if (url.toString().toLowerCase().startsWith('http')) return httpHandler(url, init); + return null; + }, +}; + +export async function loadModel(modelUrl: string, loadOptions?: LoadOptions): Promise { + if (loadOptions) options = mergeDeep(loadOptions); + if (!options.fetchFunc && (typeof globalThis.fetch !== 'undefined')) options.fetchFunc = globalThis.fetch; + const model = await tf.loadGraphModel(modelUrl, tfLoadOptions) as unknown as GraphModel; + return model; +}