From 2c0057cd300348af873df48a226e4a4bda25b488 Mon Sep 17 00:00:00 2001 From: Vladimir Mandic Date: Mon, 17 Jan 2022 11:03:21 -0500 Subject: [PATCH] implement model caching using indexdb --- CHANGELOG.md | 5 ++- TODO.md | 18 ++++++-- demo/typescript/index.js | 2 +- demo/typescript/index.ts | 2 +- package.json | 8 ++-- src/body/blazepose.ts | 10 ++--- src/body/blazeposedetector.ts | 24 ----------- src/body/efficientpose.ts | 9 ++-- src/body/movenet.ts | 6 +-- src/body/posenet.ts | 9 ++-- src/config.ts | 6 +++ src/face/antispoof.ts | 9 ++-- src/face/blazeface.ts | 9 ++-- src/face/facemesh.ts | 9 ++-- src/face/faceres.ts | 10 ++--- src/face/iris.ts | 9 ++-- src/face/liveness.ts | 10 ++--- src/face/mobilefacenet.ts | 10 ++--- src/gear/emotion.ts | 9 ++-- src/gear/gear.ts | 9 ++-- src/gear/ssrnet-age.ts | 11 ++--- src/gear/ssrnet-gender.ts | 9 ++-- src/hand/handpose.ts | 12 ++---- src/hand/handtrack.ts | 10 ++--- src/human.ts | 3 ++ src/object/centernet.ts | 6 +-- src/object/nanodet.ts | 6 +-- src/tfjs/load.ts | 80 ++++++++++++++++++++--------------- 28 files changed, 134 insertions(+), 186 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index d2dad0e4..95ed833c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,7 +9,10 @@ ## Changelog -### **HEAD -> main** 2022/01/14 mandic00@live.com +### **HEAD -> main** 2022/01/16 mandic00@live.com + + +### **origin/main** 2022/01/15 mandic00@live.com - fix face box and hand tracking when in front of face diff --git a/TODO.md b/TODO.md index 9e3cf8c4..822ca080 100644 --- a/TODO.md +++ b/TODO.md @@ -19,9 +19,6 @@ Experimental support only until support is officially added in Chromium -- Performance issues: - - ### Face Detection Enhanced rotation correction for face detection is not working in NodeJS due to missing kernel op in TFJS @@ -34,4 +31,17 @@ Feature is automatically disabled in NodeJS without user impact ## Pending Release Notes -N/A +- Add global model cache hander using indexdb in browser environments + see `config.cacheModels` setting for details +- Add additional demos + `human-motion` and `human-avatar` +- Updated samples image gallery +- Fix face box detections when face is partially occluded +- Fix face box scaling +- Fix hand tracking when hand is in front of face +- Fix compatibility with `ElectronJS` +- Fix interpolation for some body keypoints +- Updated blazepose calculations +- Changes to blazepose and handlandmarks annotations +- Strong typing for string enums +- Updated `TFJS` diff --git a/demo/typescript/index.js b/demo/typescript/index.js index 5f46142c..5f1bb67c 100644 --- a/demo/typescript/index.js +++ b/demo/typescript/index.js @@ -90,7 +90,7 @@ async function main() { status("loading..."); await human.load(); log("backend:", human.tf.getBackend(), "| available:", human.env.backends); - log("loaded models:" + Object.values(human.models).filter((model) => model !== null).length); + log("loaded models:", Object.values(human.models).filter((model) => model !== null).length); status("initializing..."); await human.warmup(); await webCam(); diff --git a/demo/typescript/index.ts b/demo/typescript/index.ts index 784dd0e0..b68a54f6 100644 --- a/demo/typescript/index.ts +++ b/demo/typescript/index.ts @@ -102,7 +102,7 @@ async function main() { // main entry point status('loading...'); await human.load(); // preload all models log('backend:', human.tf.getBackend(), '| available:', human.env.backends); - log('loaded models:' + Object.values(human.models).filter((model) => model !== null).length); + log('loaded models:', Object.values(human.models).filter((model) => model !== null).length); status('initializing...'); await human.warmup(); // warmup function to initialize backend for future faster detection await webCam(); // start webcam diff --git a/package.json b/package.json index f54e3abf..3407ed34 100644 --- a/package.json +++ b/package.json @@ -1,6 +1,6 @@ { "name": "@vladmandic/human", - "version": "2.5.8", + "version": "2.6.0", "description": "Human: AI-powered 3D Face Detection & Rotation Tracking, Face Description & Recognition, Body Pose Tracking, 3D Hand & Finger Tracking, Iris Analysis, Age & Gender & Emotion Prediction, Gesture Recognition", "sideEffects": false, "main": "dist/human.node.js", @@ -65,7 +65,7 @@ "@tensorflow/tfjs-layers": "^3.13.0", "@tensorflow/tfjs-node": "^3.13.0", "@tensorflow/tfjs-node-gpu": "^3.13.0", - "@types/node": "^17.0.8", + "@types/node": "^17.0.9", "@types/offscreencanvas": "^2019.6.4", "@typescript-eslint/eslint-plugin": "^5.9.1", "@typescript-eslint/parser": "^5.9.1", @@ -75,14 +75,14 @@ "canvas": "^2.8.0", "dayjs": "^1.10.7", "esbuild": "^0.14.11", - "eslint": "8.6.0", + "eslint": "8.7.0", "eslint-config-airbnb-base": "^15.0.0", "eslint-plugin-html": "^6.2.0", "eslint-plugin-import": "^2.25.4", "eslint-plugin-json": "^3.1.0", "eslint-plugin-node": "^11.1.0", "eslint-plugin-promise": "^6.0.0", - "node-fetch": "^3.1.0", + "node-fetch": "^3.1.1", "rimraf": "^3.0.2", "seedrandom": "^3.0.5", "tslib": "^2.3.1", diff --git a/src/body/blazepose.ts b/src/body/blazepose.ts index 390e3e3a..2482a3d1 100644 --- a/src/body/blazepose.ts +++ b/src/body/blazepose.ts @@ -5,7 +5,7 @@ import * as tf from '../../dist/tfjs.esm.js'; import { loadModel } from '../tfjs/load'; import { constants } from '../tfjs/constants'; -import { log, join, now } from '../util/util'; +import { log, now } from '../util/util'; import type { BodyKeypoint, BodyResult, BodyLandmark, Box, Point, BodyAnnotation } from '../result'; import type { GraphModel, Tensor } from '../tfjs/types'; import type { Config } from '../config'; @@ -33,12 +33,10 @@ const sigmoid = (x) => (1 - (1 / (1 + Math.exp(x)))); export async function loadDetect(config: Config): Promise { if (env.initial) models.detector = null; if (!models.detector && config.body['detector'] && config.body['detector']['modelPath'] || '') { - models.detector = await loadModel(join(config.modelBasePath, config.body['detector']['modelPath'] || '')) as unknown as GraphModel; + models.detector = await loadModel(config.body['detector']['modelPath']); const inputs = Object.values(models.detector.modelSignature['inputs']); inputSize.detector[0] = Array.isArray(inputs) ? parseInt(inputs[0].tensorShape.dim[1].size) : 0; inputSize.detector[1] = Array.isArray(inputs) ? parseInt(inputs[0].tensorShape.dim[2].size) : 0; - if (!models.detector || !models.detector['modelUrl']) log('load model failed:', config.body['detector']['modelPath']); - else if (config.debug) log('load model:', models.detector['modelUrl']); } else if (config.debug && models.detector) log('cached model:', models.detector['modelUrl']); await detect.createAnchors(); return models.detector as GraphModel; @@ -47,12 +45,10 @@ export async function loadDetect(config: Config): Promise { export async function loadPose(config: Config): Promise { if (env.initial) models.landmarks = null; if (!models.landmarks) { - models.landmarks = await loadModel(join(config.modelBasePath, config.body.modelPath || '')) as unknown as GraphModel; + models.landmarks = await loadModel(config.body.modelPath); const inputs = Object.values(models.landmarks.modelSignature['inputs']); inputSize.landmarks[0] = Array.isArray(inputs) ? parseInt(inputs[0].tensorShape.dim[1].size) : 0; inputSize.landmarks[1] = Array.isArray(inputs) ? parseInt(inputs[0].tensorShape.dim[2].size) : 0; - if (!models.landmarks || !models.landmarks['modelUrl']) log('load model failed:', config.body.modelPath); - else if (config.debug) log('load model:', models.landmarks['modelUrl']); } else if (config.debug) log('cached model:', models.landmarks['modelUrl']); return models.landmarks; } diff --git a/src/body/blazeposedetector.ts b/src/body/blazeposedetector.ts index 2662278c..50fac822 100644 --- a/src/body/blazeposedetector.ts +++ b/src/body/blazeposedetector.ts @@ -85,27 +85,3 @@ export async function decode(boxesTensor: Tensor, logitsTensor: Tensor, config: Object.keys(t).forEach((tensor) => tf.dispose(t[tensor])); return detected; } - -/* -const humanConfig: Partial = { - warmup: 'full' as const, - modelBasePath: '../../models', - cacheSensitivity: 0, - filter: { enabled: false }, - face: { enabled: false }, - hand: { enabled: false }, - object: { enabled: false }, - gesture: { enabled: false }, - body: { - enabled: true, - minConfidence: 0.1, - modelPath: 'blazepose/blazepose-full.json', - detector: { - enabled: false, - modelPath: 'blazepose/blazepose-detector.json', - minConfidence: 0.1, - iouThreshold: 0.1, - }, - }, -}; -*/ diff --git a/src/body/efficientpose.ts b/src/body/efficientpose.ts index cd7f8cb7..0c3c7c8b 100644 --- a/src/body/efficientpose.ts +++ b/src/body/efficientpose.ts @@ -4,7 +4,7 @@ * Based on: [**EfficientPose**](https://github.com/daniegr/EfficientPose) */ -import { log, join, now } from '../util/util'; +import { log, now } from '../util/util'; import * as tf from '../../dist/tfjs.esm.js'; import { loadModel } from '../tfjs/load'; import * as coords from './efficientposecoords'; @@ -26,11 +26,8 @@ let skipped = Number.MAX_SAFE_INTEGER; export async function load(config: Config): Promise { if (env.initial) model = null; - if (!model) { - model = await loadModel(join(config.modelBasePath, config.body.modelPath || '')) as unknown as GraphModel; - if (!model || !model['modelUrl']) log('load model failed:', config.body.modelPath); - else if (config.debug) log('load model:', model['modelUrl']); - } else if (config.debug) log('cached model:', model['modelUrl']); + if (!model) model = await loadModel(config.body.modelPath); + else if (config.debug) log('cached model:', model['modelUrl']); return model; } diff --git a/src/body/movenet.ts b/src/body/movenet.ts index 95879522..fa97d6dd 100644 --- a/src/body/movenet.ts +++ b/src/body/movenet.ts @@ -4,7 +4,7 @@ * Based on: [**MoveNet**](https://blog.tensorflow.org/2021/05/next-generation-pose-detection-with-movenet-and-tensorflowjs.html) */ -import { log, join, now } from '../util/util'; +import { log, now } from '../util/util'; import * as box from '../util/box'; import * as tf from '../../dist/tfjs.esm.js'; import * as coords from './movenetcoords'; @@ -35,9 +35,7 @@ export async function load(config: Config): Promise { if (env.initial) model = null; if (!model) { fakeOps(['size'], config); - model = await loadModel(join(config.modelBasePath, config.body.modelPath || '')) as unknown as GraphModel; - if (!model || !model['modelUrl']) log('load model failed:', config.body.modelPath); - else if (config.debug) log('load model:', model['modelUrl']); + model = await loadModel(config.body.modelPath); } else if (config.debug) log('cached model:', model['modelUrl']); inputSize = model.inputs[0].shape ? model.inputs[0].shape[2] : 0; if (inputSize < 64) inputSize = 256; diff --git a/src/body/posenet.ts b/src/body/posenet.ts index 0945af06..7154c0eb 100644 --- a/src/body/posenet.ts +++ b/src/body/posenet.ts @@ -4,7 +4,7 @@ * Based on: [**PoseNet**](https://medium.com/tensorflow/real-time-human-pose-estimation-in-the-browser-with-tensorflow-js-7dd0bc881cd5) */ -import { log, join } from '../util/util'; +import { log } from '../util/util'; import * as tf from '../../dist/tfjs.esm.js'; import { loadModel } from '../tfjs/load'; import type { BodyResult, BodyLandmark, Box } from '../result'; @@ -179,10 +179,7 @@ export async function predict(input: Tensor, config: Config): Promise { - if (!model || env.initial) { - model = await loadModel(join(config.modelBasePath, config.body.modelPath || '')) as unknown as GraphModel; - if (!model || !model['modelUrl']) log('load model failed:', config.body.modelPath); - else if (config.debug) log('load model:', model['modelUrl']); - } else if (config.debug) log('cached model:', model['modelUrl']); + if (!model || env.initial) model = await loadModel(config.body.modelPath); + else if (config.debug) log('cached model:', model['modelUrl']); return model; } diff --git a/src/config.ts b/src/config.ts index 79aae0d5..9b938a1e 100644 --- a/src/config.ts +++ b/src/config.ts @@ -248,6 +248,11 @@ export interface Config { */ modelBasePath: string, + /** Cache models in IndexDB on first sucessfull load + * default: true if indexdb is available (browsers), false if its not (nodejs) + */ + cacheModels: boolean, + /** Cache sensitivity * - values 0..1 where 0.01 means reset cache if input changed more than 1% * - set to 0 to disable caching @@ -288,6 +293,7 @@ export interface Config { const config: Config = { backend: '', modelBasePath: '', + cacheModels: true, wasmPath: '', debug: true, async: true, diff --git a/src/face/antispoof.ts b/src/face/antispoof.ts index c7cd945a..029370c4 100644 --- a/src/face/antispoof.ts +++ b/src/face/antispoof.ts @@ -2,7 +2,7 @@ * Anti-spoofing model implementation */ -import { log, join, now } from '../util/util'; +import { log, now } from '../util/util'; import type { Config } from '../config'; import type { GraphModel, Tensor } from '../tfjs/types'; import * as tf from '../../dist/tfjs.esm.js'; @@ -17,11 +17,8 @@ let lastTime = 0; export async function load(config: Config): Promise { if (env.initial) model = null; - if (!model) { - model = await loadModel(join(config.modelBasePath, config.face.antispoof?.modelPath || '')) as unknown as GraphModel; - if (!model || !model['modelUrl']) log('load model failed:', config.face.antispoof?.modelPath); - else if (config.debug) log('load model:', model['modelUrl']); - } else if (config.debug) log('cached model:', model['modelUrl']); + if (!model) model = await loadModel(config.face.antispoof?.modelPath); + else if (config.debug) log('cached model:', model['modelUrl']); return model; } diff --git a/src/face/blazeface.ts b/src/face/blazeface.ts index f3ccff48..4b2f3730 100644 --- a/src/face/blazeface.ts +++ b/src/face/blazeface.ts @@ -3,7 +3,7 @@ * See `facemesh.ts` for entry point */ -import { log, join } from '../util/util'; +import { log } from '../util/util'; import * as tf from '../../dist/tfjs.esm.js'; import * as util from './facemeshutil'; import { loadModel } from '../tfjs/load'; @@ -26,11 +26,8 @@ export const size = () => inputSize; export async function load(config: Config): Promise { if (env.initial) model = null; - if (!model) { - model = await loadModel(join(config.modelBasePath, config.face.detector?.modelPath || '')) as unknown as GraphModel; - if (!model || !model['modelUrl']) log('load model failed:', config.face.detector?.modelPath); - else if (config.debug) log('load model:', model['modelUrl']); - } else if (config.debug) log('cached model:', model['modelUrl']); + if (!model) model = await loadModel(config.face.detector?.modelPath); + else if (config.debug) log('cached model:', model['modelUrl']); inputSize = model.inputs[0].shape ? model.inputs[0].shape[2] : 0; inputSizeT = tf.scalar(inputSize, 'int32') as Tensor; anchors = tf.tensor2d(util.generateAnchors(inputSize)) as Tensor; diff --git a/src/face/facemesh.ts b/src/face/facemesh.ts index f55091af..07fd63d9 100644 --- a/src/face/facemesh.ts +++ b/src/face/facemesh.ts @@ -7,7 +7,7 @@ * - Eye Iris Details: [**MediaPipe Iris**](https://drive.google.com/file/d/1bsWbokp9AklH2ANjCfmjqEzzxO1CNbMu/view) */ -import { log, join, now } from '../util/util'; +import { log, now } from '../util/util'; import { loadModel } from '../tfjs/load'; import * as tf from '../../dist/tfjs.esm.js'; import * as blazeface from './blazeface'; @@ -111,11 +111,8 @@ export async function predict(input: Tensor, config: Config): Promise { if (env.initial) model = null; - if (!model) { - model = await loadModel(join(config.modelBasePath, config.face.mesh?.modelPath || '')) as unknown as GraphModel; - if (!model || !model['modelUrl']) log('load model failed:', config.face.mesh?.modelPath); - else if (config.debug) log('load model:', model['modelUrl']); - } else if (config.debug) log('cached model:', model['modelUrl']); + if (!model) model = await loadModel(config.face.mesh?.modelPath); + else if (config.debug) log('cached model:', model['modelUrl']); inputSize = model.inputs[0].shape ? model.inputs[0].shape[2] : 0; return model; } diff --git a/src/face/faceres.ts b/src/face/faceres.ts index 80753174..ef4a7c86 100644 --- a/src/face/faceres.ts +++ b/src/face/faceres.ts @@ -7,7 +7,7 @@ * Based on: [**HSE-FaceRes**](https://github.com/HSE-asavchenko/HSE_FaceRec_tf) */ -import { log, join, now } from '../util/util'; +import { log, now } from '../util/util'; import { env } from '../util/env'; import * as tf from '../../dist/tfjs.esm.js'; import { loadModel } from '../tfjs/load'; @@ -31,13 +31,9 @@ let lastCount = 0; let skipped = Number.MAX_SAFE_INTEGER; export async function load(config: Config): Promise { - const modelUrl = join(config.modelBasePath, config.face.description?.modelPath || ''); if (env.initial) model = null; - if (!model) { - model = await loadModel(modelUrl) as unknown as GraphModel; - if (!model) log('load model failed:', config.face.description?.modelPath || ''); - else if (config.debug) log('load model:', modelUrl); - } else if (config.debug) log('cached model:', modelUrl); + if (!model) model = await loadModel(config.face.description?.modelPath); + else if (config.debug) log('cached model:', model['modelUrl']); return model; } diff --git a/src/face/iris.ts b/src/face/iris.ts index 7bdec347..fde0a910 100644 --- a/src/face/iris.ts +++ b/src/face/iris.ts @@ -3,7 +3,7 @@ import * as util from './facemeshutil'; import * as tf from '../../dist/tfjs.esm.js'; import type { Tensor, GraphModel } from '../tfjs/types'; import { env } from '../util/env'; -import { log, join } from '../util/util'; +import { log } from '../util/util'; import { loadModel } from '../tfjs/load'; import type { Config } from '../config'; import type { Point } from '../result'; @@ -30,11 +30,8 @@ const irisLandmarks = { export async function load(config: Config): Promise { if (env.initial) model = null; - if (!model) { - model = await loadModel(join(config.modelBasePath, config.face.iris?.modelPath || '')); - if (!model || !model['modelUrl']) log('load model failed:', config.face.iris?.modelPath); - else if (config.debug) log('load model:', model['modelUrl']); - } else if (config.debug) log('cached model:', model['modelUrl']); + if (!model) model = await loadModel(config.face.iris?.modelPath); + else if (config.debug) log('cached model:', model['modelUrl']); inputSize = model.inputs[0].shape ? model.inputs[0].shape[2] : 0; if (inputSize === -1) inputSize = 64; return model; diff --git a/src/face/liveness.ts b/src/face/liveness.ts index d558c708..c2d1d9ce 100644 --- a/src/face/liveness.ts +++ b/src/face/liveness.ts @@ -2,7 +2,8 @@ * Anti-spoofing model implementation */ -import { log, join, now } from '../util/util'; +import { log, now } from '../util/util'; +import { loadModel } from '../tfjs/load'; import type { Config } from '../config'; import type { GraphModel, Tensor } from '../tfjs/types'; import * as tf from '../../dist/tfjs.esm.js'; @@ -16,11 +17,8 @@ let lastTime = 0; export async function load(config: Config): Promise { if (env.initial) model = null; - if (!model) { - model = await loadModel(join(config.modelBasePath, config.face.liveness?.modelPath || '')) as unknown as GraphModel; - if (!model || !model['modelUrl']) log('load model failed:', config.face.liveness?.modelPath); - else if (config.debug) log('load model:', model['modelUrl']); - } else if (config.debug) log('cached model:', model['modelUrl']); + if (!model) model = await loadModel(config.face.liveness?.modelPath); + else if (config.debug) log('cached model:', model['modelUrl']); return model; } diff --git a/src/face/mobilefacenet.ts b/src/face/mobilefacenet.ts index 068baf15..d85923e8 100644 --- a/src/face/mobilefacenet.ts +++ b/src/face/mobilefacenet.ts @@ -6,7 +6,7 @@ * Obsolete and replaced by `faceres` that performs age/gender/descriptor analysis */ -import { log, join, now } from '../util/util'; +import { log, now } from '../util/util'; import * as tf from '../../dist/tfjs.esm.js'; import { loadModel } from '../tfjs/load'; import type { Tensor, GraphModel } from '../tfjs/types'; @@ -20,13 +20,9 @@ let lastTime = 0; let skipped = Number.MAX_SAFE_INTEGER; export async function load(config: Config): Promise { - const modelUrl = join(config.modelBasePath, config.face['mobilefacenet'].modelPath); if (env.initial) model = null; - if (!model) { - model = await loadModel(modelUrl) as unknown as GraphModel; - if (!model) log('load model failed:', config.face['mobilefacenet'].modelPath); - else if (config.debug) log('load model:', modelUrl); - } else if (config.debug) log('cached model:', modelUrl); + if (!model) model = await loadModel(config.face['mobilefacenet'].modelPath); + else if (config.debug) log('cached model:', model['modelUrl']); return model; } diff --git a/src/gear/emotion.ts b/src/gear/emotion.ts index 2219db6a..77d51c1b 100644 --- a/src/gear/emotion.ts +++ b/src/gear/emotion.ts @@ -5,7 +5,7 @@ */ import type { Emotion } from '../result'; -import { log, join, now } from '../util/util'; +import { log, now } from '../util/util'; import type { Config } from '../config'; import type { GraphModel, Tensor } from '../tfjs/types'; import * as tf from '../../dist/tfjs.esm.js'; @@ -22,11 +22,8 @@ let skipped = Number.MAX_SAFE_INTEGER; export async function load(config: Config): Promise { if (env.initial) model = null; - if (!model) { - model = await loadModel(join(config.modelBasePath, config.face.emotion?.modelPath || '')) as unknown as GraphModel; - if (!model || !model['modelUrl']) log('load model failed:', config.face.emotion?.modelPath); - else if (config.debug) log('load model:', model['modelUrl']); - } else if (config.debug) log('cached model:', model['modelUrl']); + if (!model) model = await loadModel(config.face.emotion?.modelPath); + else if (config.debug) log('cached model:', model['modelUrl']); return model; } diff --git a/src/gear/gear.ts b/src/gear/gear.ts index 11396f3e..a82a7310 100644 --- a/src/gear/gear.ts +++ b/src/gear/gear.ts @@ -4,7 +4,7 @@ * Based on: [**GEAR Predictor**](https://github.com/Udolf15/GEAR-Predictor) */ -import { log, join, now } from '../util/util'; +import { log, now } from '../util/util'; import * as tf from '../../dist/tfjs.esm.js'; import { loadModel } from '../tfjs/load'; import type { Gender, Race } from '../result'; @@ -24,11 +24,8 @@ let skipped = Number.MAX_SAFE_INTEGER; // eslint-disable-next-line @typescript-eslint/no-explicit-any export async function load(config: Config) { if (env.initial) model = null; - if (!model) { - model = await loadModel(join(config.modelBasePath, config.face['gear'].modelPath)) as unknown as GraphModel; - if (!model || !model['modelUrl']) log('load model failed:', config.face['gear'].modelPath); - else if (config.debug) log('load model:', model['modelUrl']); - } else if (config.debug) log('cached model:', model['modelUrl']); + if (!model) model = await loadModel(config.face['gear']); + else if (config.debug) log('cached model:', model['modelUrl']); return model; } diff --git a/src/gear/ssrnet-age.ts b/src/gear/ssrnet-age.ts index 3b675b5d..904640b0 100644 --- a/src/gear/ssrnet-age.ts +++ b/src/gear/ssrnet-age.ts @@ -4,7 +4,7 @@ * Based on: [**SSR-Net**](https://github.com/shamangary/SSR-Net) */ -import { log, join, now } from '../util/util'; +import { log, now } from '../util/util'; import * as tf from '../../dist/tfjs.esm.js'; import { loadModel } from '../tfjs/load'; import { env } from '../util/env'; @@ -21,13 +21,8 @@ let skipped = Number.MAX_SAFE_INTEGER; // eslint-disable-next-line @typescript-eslint/no-explicit-any export async function load(config: Config) { if (env.initial) model = null; - if (!model) { - model = await loadModel(join(config.modelBasePath, config.face['ssrnet'].modelPathAge)) as unknown as GraphModel; - if (!model || !model['modelUrl']) log('load model failed:', config.face['ssrnet'].modelPathAge); - else if (config.debug) log('load model:', model['modelUrl']); - } else { - if (config.debug) log('cached model:', model['modelUrl']); - } + if (!model) model = await loadModel(config.face['ssrnet'].modelPathAge); + else if (config.debug) log('cached model:', model['modelUrl']); return model; } diff --git a/src/gear/ssrnet-gender.ts b/src/gear/ssrnet-gender.ts index e7c1c82b..f5f9da1a 100644 --- a/src/gear/ssrnet-gender.ts +++ b/src/gear/ssrnet-gender.ts @@ -4,7 +4,7 @@ * Based on: [**SSR-Net**](https://github.com/shamangary/SSR-Net) */ -import { log, join, now } from '../util/util'; +import { log, now } from '../util/util'; import * as tf from '../../dist/tfjs.esm.js'; import { loadModel } from '../tfjs/load'; import { constants } from '../tfjs/constants'; @@ -25,11 +25,8 @@ const rgb = [0.2989, 0.5870, 0.1140]; // factors for red/green/blue colors when // eslint-disable-next-line @typescript-eslint/no-explicit-any export async function load(config: Config | any) { if (env.initial) model = null; - if (!model) { - model = await loadModel(join(config.modelBasePath, config.face['ssrnet'].modelPathGender)) as unknown as GraphModel; - if (!model || !model['modelUrl']) log('load model failed:', config.face['ssrnet'].modelPathGender); - else if (config.debug) log('load model:', model['modelUrl']); - } else if (config.debug) log('cached model:', model['modelUrl']); + if (!model) model = await loadModel(config.face['ssrnet'].modelPathGender); + else if (config.debug) log('cached model:', model['modelUrl']); return model; } diff --git a/src/hand/handpose.ts b/src/hand/handpose.ts index 719c1f7c..aa24a09f 100644 --- a/src/hand/handpose.ts +++ b/src/hand/handpose.ts @@ -4,7 +4,7 @@ * Based on: [**MediaPipe HandPose**](https://drive.google.com/file/d/1sv4sSb9BSNVZhLzxXJ0jBv9DqD-4jnAz/view) */ -import { log, join } from '../util/util'; +import { log } from '../util/util'; import * as handdetector from './handposedetector'; import * as handpipeline from './handposepipeline'; import * as fingerPose from './fingerpose'; @@ -89,15 +89,9 @@ export async function load(config: Config): Promise<[GraphModel | null, GraphMod } if (!handDetectorModel || !handPoseModel) { [handDetectorModel, handPoseModel] = await Promise.all([ - config.hand.enabled ? loadModel(join(config.modelBasePath, config.hand.detector?.modelPath || '')) as unknown as GraphModel : null, - config.hand.landmarks ? loadModel(join(config.modelBasePath, config.hand.skeleton?.modelPath || '')) as unknown as GraphModel : null, + config.hand.enabled ? loadModel(config.hand.detector?.modelPath) : null, + config.hand.landmarks ? loadModel(config.hand.skeleton?.modelPath) : null, ]); - if (config.hand.enabled) { - if (!handDetectorModel || !handDetectorModel['modelUrl']) log('load model failed:', config.hand.detector?.modelPath || ''); - else if (config.debug) log('load model:', handDetectorModel['modelUrl']); - if (!handPoseModel || !handPoseModel['modelUrl']) log('load model failed:', config.hand.skeleton?.modelPath || ''); - else if (config.debug) log('load model:', handPoseModel['modelUrl']); - } } else { if (config.debug) log('cached model:', handDetectorModel['modelUrl']); if (config.debug) log('cached model:', handPoseModel['modelUrl']); diff --git a/src/hand/handtrack.ts b/src/hand/handtrack.ts index 0734d4d2..71f83755 100644 --- a/src/hand/handtrack.ts +++ b/src/hand/handtrack.ts @@ -6,7 +6,7 @@ * - Hand Tracking: [**HandTracking**](https://github.com/victordibia/handtracking) */ -import { log, join, now } from '../util/util'; +import { log, now } from '../util/util'; import * as box from '../util/box'; import * as tf from '../../dist/tfjs.esm.js'; import { loadModel } from '../tfjs/load'; @@ -75,12 +75,10 @@ export async function loadDetect(config: Config): Promise { // handtrack model has some kernel ops defined in model but those are never referenced and non-existent in tfjs // ideally need to prune the model itself fakeOps(['tensorlistreserve', 'enter', 'tensorlistfromtensor', 'merge', 'loopcond', 'switch', 'exit', 'tensorliststack', 'nextiteration', 'tensorlistsetitem', 'tensorlistgetitem', 'reciprocal', 'shape', 'split', 'where'], config); - models[0] = await loadModel(join(config.modelBasePath, config.hand.detector?.modelPath || '')) as unknown as GraphModel; + models[0] = await loadModel(config.hand.detector?.modelPath); const inputs = Object.values(models[0].modelSignature['inputs']); inputSize[0][0] = Array.isArray(inputs) ? parseInt(inputs[0].tensorShape.dim[1].size) : 0; inputSize[0][1] = Array.isArray(inputs) ? parseInt(inputs[0].tensorShape.dim[2].size) : 0; - if (!models[0] || !models[0]['modelUrl']) log('load model failed:', config.hand.detector?.modelPath); - else if (config.debug) log('load model:', models[0]['modelUrl']); } else if (config.debug) log('cached model:', models[0]['modelUrl']); return models[0]; } @@ -88,12 +86,10 @@ export async function loadDetect(config: Config): Promise { export async function loadSkeleton(config: Config): Promise { if (env.initial) models[1] = null; if (!models[1]) { - models[1] = await loadModel(join(config.modelBasePath, config.hand.skeleton?.modelPath || '')) as unknown as GraphModel; + models[1] = await loadModel(config.hand.skeleton?.modelPath); const inputs = Object.values(models[1].modelSignature['inputs']); inputSize[1][0] = Array.isArray(inputs) ? parseInt(inputs[0].tensorShape.dim[1].size) : 0; inputSize[1][1] = Array.isArray(inputs) ? parseInt(inputs[0].tensorShape.dim[2].size) : 0; - if (!models[1] || !models[1]['modelUrl']) log('load model failed:', config.hand.skeleton?.modelPath); - else if (config.debug) log('load model:', models[1]['modelUrl']); } else if (config.debug) log('cached model:', models[1]['modelUrl']); return models[1]; } diff --git a/src/human.ts b/src/human.ts index aee9cd9b..8730c59f 100644 --- a/src/human.ts +++ b/src/human.ts @@ -11,6 +11,7 @@ import { log, now, mergeDeep, validate } from './util/util'; import { defaults } from './config'; import { env, Env } from './util/env'; +import { setModelLoadOptions } from './tfjs/load'; import * as tf from '../dist/tfjs.esm.js'; import * as app from '../package.json'; import * as backend from './tfjs/backend'; @@ -134,6 +135,8 @@ export class Human { this.config = JSON.parse(JSON.stringify(defaults)); Object.seal(this.config); if (userConfig) this.config = mergeDeep(this.config, userConfig); + this.config.cacheModels = typeof indexedDB !== 'undefined'; + setModelLoadOptions(this.config); this.tf = tf; this.state = 'idle'; this.#numTensors = 0; diff --git a/src/object/centernet.ts b/src/object/centernet.ts index 15497bef..e7cc0b04 100644 --- a/src/object/centernet.ts +++ b/src/object/centernet.ts @@ -4,7 +4,7 @@ * Based on: [**NanoDet**](https://github.com/RangiLyu/nanodet) */ -import { log, join, now } from '../util/util'; +import { log, now } from '../util/util'; import * as tf from '../../dist/tfjs.esm.js'; import { loadModel } from '../tfjs/load'; import { labels } from './labels'; @@ -23,11 +23,9 @@ export async function load(config: Config): Promise { if (env.initial) model = null; if (!model) { // fakeOps(['floormod'], config); - model = await loadModel(join(config.modelBasePath, config.object.modelPath || '')) as unknown as GraphModel; + model = await loadModel(config.object.modelPath); const inputs = Object.values(model.modelSignature['inputs']); inputSize = Array.isArray(inputs) ? parseInt(inputs[0].tensorShape.dim[2].size) : 0; - if (!model || !model['modelUrl']) log('load model failed:', config.object.modelPath); - else if (config.debug) log('load model:', model['modelUrl']); } else if (config.debug) log('cached model:', model['modelUrl']); return model; } diff --git a/src/object/nanodet.ts b/src/object/nanodet.ts index b6659eb9..85f8f7c9 100644 --- a/src/object/nanodet.ts +++ b/src/object/nanodet.ts @@ -4,7 +4,7 @@ * Based on: [**MB3-CenterNet**](https://github.com/610265158/mobilenetv3_centernet) */ -import { log, join, now } from '../util/util'; +import { log, now } from '../util/util'; import * as tf from '../../dist/tfjs.esm.js'; import { loadModel } from '../tfjs/load'; import { constants } from '../tfjs/constants'; @@ -24,11 +24,9 @@ const scaleBox = 2.5; // increase box size export async function load(config: Config): Promise { if (!model || env.initial) { - model = await loadModel(join(config.modelBasePath, config.object.modelPath || '')) as unknown as GraphModel; + model = await loadModel(config.object.modelPath); const inputs = Object.values(model.modelSignature['inputs']); inputSize = Array.isArray(inputs) ? parseInt(inputs[0].tensorShape.dim[2].size) : 0; - if (!model || !model['modelUrl']) log('load model failed:', config.object.modelPath); - else if (config.debug) log('load model:', model['modelUrl']); } else if (config.debug) log('cached model:', model['modelUrl']); return model; } diff --git a/src/tfjs/load.ts b/src/tfjs/load.ts index c08f0405..956ade02 100644 --- a/src/tfjs/load.ts +++ b/src/tfjs/load.ts @@ -1,45 +1,57 @@ -import { log, mergeDeep } from '../util/util'; +import { log, join } from '../util/util'; import * as tf from '../../dist/tfjs.esm.js'; import type { GraphModel } from './types'; +import type { Config } from '../config'; -type FetchFunc = (url: RequestInfo, init?: RequestInit) => Promise; -type ProgressFunc = (...args) => void; - -export type LoadOptions = { - appName: string, - autoSave: boolean, - verbose: boolean, - fetchFunc?: FetchFunc, - onProgress?: ProgressFunc, -} - -let options: LoadOptions = { - appName: 'human', - autoSave: true, +const options = { + cacheModels: false, verbose: true, + debug: false, + modelBasePath: '', }; -async function httpHandler(url: RequestInfo, init?: RequestInit): Promise { - if (options.fetchFunc) return options.fetchFunc(url, init); - else log('error: fetch function is not defined'); - return null; +async function httpHandler(url, init?): Promise { + if (options.debug) log('load model fetch:', url, init); + if (typeof fetch === 'undefined') { + log('error loading model: fetch function is not defined:'); + return null; + } + return fetch(url, init); } -const tfLoadOptions = { - onProgress: (...args) => { - if (options.onProgress) options.onProgress(...args); - else if (options.verbose) log('load model progress:', ...args); - }, - fetchFunc: (url: RequestInfo, init?: RequestInit) => { - if (options.verbose) log('load model fetch:', url, init); - if (url.toString().toLowerCase().startsWith('http')) return httpHandler(url, init); - return null; - }, -}; +export function setModelLoadOptions(config: Config) { + options.cacheModels = config.cacheModels; + options.verbose = config.debug; + options.modelBasePath = config.modelBasePath; +} -export async function loadModel(modelUrl: string, loadOptions?: LoadOptions): Promise { - if (loadOptions) options = mergeDeep(loadOptions); - if (!options.fetchFunc && (typeof globalThis.fetch !== 'undefined')) options.fetchFunc = globalThis.fetch; - const model = await tf.loadGraphModel(modelUrl, tfLoadOptions) as unknown as GraphModel; +export async function loadModel(modelPath: string | undefined): Promise { + const modelUrl = join(options.modelBasePath, modelPath || ''); + const modelPathSegments = modelUrl.split('/'); + const cachedModelName = 'indexeddb://' + modelPathSegments[modelPathSegments.length - 1].replace('.json', ''); // generate short model name for cache + const cachedModels = await tf.io.listModels(); // list all models already in cache + const modelCached = options.cacheModels && Object.keys(cachedModels).includes(cachedModelName); // is model found in cache + // create model prototype and decide if load from cache or from original modelurl + const model: GraphModel = new tf.GraphModel(modelCached ? cachedModelName : modelUrl, { fetchFunc: (url, init?) => httpHandler(url, init) }) as unknown as GraphModel; + try { + // @ts-ignore private function + model.findIOHandler(); // decide how to actually load a model + // @ts-ignore private property + if (options.debug) log('model load handler:', model.handler); + // @ts-ignore private property + const artifacts = await model.handler.load(); // load manifest + model.loadSync(artifacts); // load weights + if (options.verbose) log('load model:', model['modelUrl']); + } catch (err) { + log('error loading model:', modelUrl, err); + } + if (options.cacheModels && !modelCached) { // save model to cache + try { + const saveResult = await model.save(cachedModelName); + log('model saved:', cachedModelName, saveResult); + } catch (err) { + log('error saving model:', modelUrl, err); + } + } return model; }