implement model caching using indexdb

pull/356/head
Vladimir Mandic 2022-01-17 11:03:21 -05:00
parent c5911301e9
commit 2c0057cd30
28 changed files with 134 additions and 186 deletions

View File

@ -9,7 +9,10 @@
## Changelog
### **HEAD -> main** 2022/01/14 mandic00@live.com
### **HEAD -> main** 2022/01/16 mandic00@live.com
### **origin/main** 2022/01/15 mandic00@live.com
- fix face box and hand tracking when in front of face

TODO.md
View File

@ -19,9 +19,6 @@
Experimental support only until support is officially added in Chromium
- Performance issues:
<https://github.com/tensorflow/tfjs/issues/5689>
### Face Detection
Enhanced rotation correction for face detection is not working in NodeJS due to missing kernel op in TFJS
@ -34,4 +31,17 @@ Feature is automatically disabled in NodeJS without user impact
## Pending Release Notes
N/A
- Add global model cache handler using IndexedDB in browser environments
see `config.cacheModels` setting for details; a usage sketch follows this list
- Add additional demos
`human-motion` and `human-avatar`
- Updated samples image gallery
- Fix face box detections when face is partially occluded
- Fix face box scaling
- Fix hand tracking when hand is in front of face
- Fix compatibility with `ElectronJS`
- Fix interpolation for some body keypoints
- Updated blazepose calculations
- Changes to blazepose and handlandmarks annotations
- Strong typing for string enums
- Updated `TFJS`
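
For the model-cache item above, the observable effect is on repeat visits: with `config.cacheModels` enabled, a page reload pulls weights from IndexedDB instead of re-fetching them over HTTP. A minimal sketch of how one might verify this in a browser (the import path follows the published npm package; the timing comparison is purely illustrative):

```ts
import { Human } from '@vladmandic/human';

const human = new Human();
const t0 = performance.now();
await human.load(); // first visit: HTTP fetch, then save to indexeddb://; later visits: cache read
console.log('cacheModels:', human.config.cacheModels, '| load time:', Math.round(performance.now() - t0), 'ms');
```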

View File

@ -90,7 +90,7 @@ async function main() {
status("loading...");
await human.load();
log("backend:", human.tf.getBackend(), "| available:", human.env.backends);
log("loaded models:" + Object.values(human.models).filter((model) => model !== null).length);
log("loaded models:", Object.values(human.models).filter((model) => model !== null).length);
status("initializing...");
await human.warmup();
await webCam();

View File

@ -102,7 +102,7 @@ async function main() { // main entry point
status('loading...');
await human.load(); // preload all models
log('backend:', human.tf.getBackend(), '| available:', human.env.backends);
log('loaded models:' + Object.values(human.models).filter((model) => model !== null).length);
log('loaded models:', Object.values(human.models).filter((model) => model !== null).length);
status('initializing...');
await human.warmup(); // warmup function to initialize backend for future faster detection
await webCam(); // start webcam

View File

@ -1,6 +1,6 @@
{
"name": "@vladmandic/human",
"version": "2.5.8",
"version": "2.6.0",
"description": "Human: AI-powered 3D Face Detection & Rotation Tracking, Face Description & Recognition, Body Pose Tracking, 3D Hand & Finger Tracking, Iris Analysis, Age & Gender & Emotion Prediction, Gesture Recognition",
"sideEffects": false,
"main": "dist/human.node.js",
@ -65,7 +65,7 @@
"@tensorflow/tfjs-layers": "^3.13.0",
"@tensorflow/tfjs-node": "^3.13.0",
"@tensorflow/tfjs-node-gpu": "^3.13.0",
"@types/node": "^17.0.8",
"@types/node": "^17.0.9",
"@types/offscreencanvas": "^2019.6.4",
"@typescript-eslint/eslint-plugin": "^5.9.1",
"@typescript-eslint/parser": "^5.9.1",
@ -75,14 +75,14 @@
"canvas": "^2.8.0",
"dayjs": "^1.10.7",
"esbuild": "^0.14.11",
"eslint": "8.6.0",
"eslint": "8.7.0",
"eslint-config-airbnb-base": "^15.0.0",
"eslint-plugin-html": "^6.2.0",
"eslint-plugin-import": "^2.25.4",
"eslint-plugin-json": "^3.1.0",
"eslint-plugin-node": "^11.1.0",
"eslint-plugin-promise": "^6.0.0",
"node-fetch": "^3.1.0",
"node-fetch": "^3.1.1",
"rimraf": "^3.0.2",
"seedrandom": "^3.0.5",
"tslib": "^2.3.1",

View File

@ -5,7 +5,7 @@
import * as tf from '../../dist/tfjs.esm.js';
import { loadModel } from '../tfjs/load';
import { constants } from '../tfjs/constants';
import { log, join, now } from '../util/util';
import { log, now } from '../util/util';
import type { BodyKeypoint, BodyResult, BodyLandmark, Box, Point, BodyAnnotation } from '../result';
import type { GraphModel, Tensor } from '../tfjs/types';
import type { Config } from '../config';
@ -33,12 +33,10 @@ const sigmoid = (x) => (1 - (1 / (1 + Math.exp(x))));
export async function loadDetect(config: Config): Promise<GraphModel> {
if (env.initial) models.detector = null;
if (!models.detector && config.body['detector'] && config.body['detector']['modelPath'] || '') {
models.detector = await loadModel(join(config.modelBasePath, config.body['detector']['modelPath'] || '')) as unknown as GraphModel;
models.detector = await loadModel(config.body['detector']['modelPath']);
const inputs = Object.values(models.detector.modelSignature['inputs']);
inputSize.detector[0] = Array.isArray(inputs) ? parseInt(inputs[0].tensorShape.dim[1].size) : 0;
inputSize.detector[1] = Array.isArray(inputs) ? parseInt(inputs[0].tensorShape.dim[2].size) : 0;
if (!models.detector || !models.detector['modelUrl']) log('load model failed:', config.body['detector']['modelPath']);
else if (config.debug) log('load model:', models.detector['modelUrl']);
} else if (config.debug && models.detector) log('cached model:', models.detector['modelUrl']);
await detect.createAnchors();
return models.detector as GraphModel;
@ -47,12 +45,10 @@ export async function loadDetect(config: Config): Promise<GraphModel> {
export async function loadPose(config: Config): Promise<GraphModel> {
if (env.initial) models.landmarks = null;
if (!models.landmarks) {
models.landmarks = await loadModel(join(config.modelBasePath, config.body.modelPath || '')) as unknown as GraphModel;
models.landmarks = await loadModel(config.body.modelPath);
const inputs = Object.values(models.landmarks.modelSignature['inputs']);
inputSize.landmarks[0] = Array.isArray(inputs) ? parseInt(inputs[0].tensorShape.dim[1].size) : 0;
inputSize.landmarks[1] = Array.isArray(inputs) ? parseInt(inputs[0].tensorShape.dim[2].size) : 0;
if (!models.landmarks || !models.landmarks['modelUrl']) log('load model failed:', config.body.modelPath);
else if (config.debug) log('load model:', models.landmarks['modelUrl']);
} else if (config.debug) log('cached model:', models.landmarks['modelUrl']);
return models.landmarks;
}
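
The same refactor recurs across the model loaders in this commit: the `join(config.modelBasePath, ...)` call and the per-loader success/failure logging move into the shared `loadModel` in `tfjs/load` (diffed at the end of this commit), so each loader shrinks to a cache-or-load branch. A condensed before/after sketch of the pattern, with identifiers taken from the diffs:

```ts
// before: each loader joined the path and logged success/failure itself
model = await loadModel(join(config.modelBasePath, config.body.modelPath || '')) as unknown as GraphModel;
if (!model || !model['modelUrl']) log('load model failed:', config.body.modelPath);
else if (config.debug) log('load model:', model['modelUrl']);

// after: loadModel resolves the path against its captured modelBasePath and logs internally
if (!model) model = await loadModel(config.body.modelPath);
else if (config.debug) log('cached model:', model['modelUrl']);
```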

View File

@ -85,27 +85,3 @@ export async function decode(boxesTensor: Tensor, logitsTensor: Tensor, config:
Object.keys(t).forEach((tensor) => tf.dispose(t[tensor]));
return detected;
}
/*
const humanConfig: Partial<Config> = {
warmup: 'full' as const,
modelBasePath: '../../models',
cacheSensitivity: 0,
filter: { enabled: false },
face: { enabled: false },
hand: { enabled: false },
object: { enabled: false },
gesture: { enabled: false },
body: {
enabled: true,
minConfidence: 0.1,
modelPath: 'blazepose/blazepose-full.json',
detector: {
enabled: false,
modelPath: 'blazepose/blazepose-detector.json',
minConfidence: 0.1,
iouThreshold: 0.1,
},
},
};
*/

View File

@ -4,7 +4,7 @@
* Based on: [**EfficientPose**](https://github.com/daniegr/EfficientPose)
*/
import { log, join, now } from '../util/util';
import { log, now } from '../util/util';
import * as tf from '../../dist/tfjs.esm.js';
import { loadModel } from '../tfjs/load';
import * as coords from './efficientposecoords';
@ -26,11 +26,8 @@ let skipped = Number.MAX_SAFE_INTEGER;
export async function load(config: Config): Promise<GraphModel> {
if (env.initial) model = null;
if (!model) {
model = await loadModel(join(config.modelBasePath, config.body.modelPath || '')) as unknown as GraphModel;
if (!model || !model['modelUrl']) log('load model failed:', config.body.modelPath);
else if (config.debug) log('load model:', model['modelUrl']);
} else if (config.debug) log('cached model:', model['modelUrl']);
if (!model) model = await loadModel(config.body.modelPath);
else if (config.debug) log('cached model:', model['modelUrl']);
return model;
}

View File

@ -4,7 +4,7 @@
* Based on: [**MoveNet**](https://blog.tensorflow.org/2021/05/next-generation-pose-detection-with-movenet-and-tensorflowjs.html)
*/
import { log, join, now } from '../util/util';
import { log, now } from '../util/util';
import * as box from '../util/box';
import * as tf from '../../dist/tfjs.esm.js';
import * as coords from './movenetcoords';
@ -35,9 +35,7 @@ export async function load(config: Config): Promise<GraphModel> {
if (env.initial) model = null;
if (!model) {
fakeOps(['size'], config);
model = await loadModel(join(config.modelBasePath, config.body.modelPath || '')) as unknown as GraphModel;
if (!model || !model['modelUrl']) log('load model failed:', config.body.modelPath);
else if (config.debug) log('load model:', model['modelUrl']);
model = await loadModel(config.body.modelPath);
} else if (config.debug) log('cached model:', model['modelUrl']);
inputSize = model.inputs[0].shape ? model.inputs[0].shape[2] : 0;
if (inputSize < 64) inputSize = 256;

View File

@ -4,7 +4,7 @@
* Based on: [**PoseNet**](https://medium.com/tensorflow/real-time-human-pose-estimation-in-the-browser-with-tensorflow-js-7dd0bc881cd5)
*/
import { log, join } from '../util/util';
import { log } from '../util/util';
import * as tf from '../../dist/tfjs.esm.js';
import { loadModel } from '../tfjs/load';
import type { BodyResult, BodyLandmark, Box } from '../result';
@ -179,10 +179,7 @@ export async function predict(input: Tensor, config: Config): Promise<BodyResult
}
export async function load(config: Config): Promise<GraphModel> {
if (!model || env.initial) {
model = await loadModel(join(config.modelBasePath, config.body.modelPath || '')) as unknown as GraphModel;
if (!model || !model['modelUrl']) log('load model failed:', config.body.modelPath);
else if (config.debug) log('load model:', model['modelUrl']);
} else if (config.debug) log('cached model:', model['modelUrl']);
if (!model || env.initial) model = await loadModel(config.body.modelPath);
else if (config.debug) log('cached model:', model['modelUrl']);
return model;
}

View File

@ -248,6 +248,11 @@ export interface Config {
*/
modelBasePath: string,
/** Cache models in IndexedDB on first successful load
* default: true if IndexedDB is available (browsers), false if it is not (NodeJS)
*/
cacheModels: boolean,
/** Cache sensitivity
* - values 0..1 where 0.01 means reset cache if input changed more than 1%
* - set to 0 to disable caching
@ -288,6 +293,7 @@ export interface Config {
const config: Config = {
backend: '',
modelBasePath: '',
cacheModels: true,
wasmPath: '',
debug: true,
async: true,
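
Note that in this commit the `Human` constructor re-derives `cacheModels` from the environment after merging user config (see the `human.ts` hunk below), so the effective value is environment-driven. A minimal sketch of checking what was resolved:

```ts
import { Human } from '@vladmandic/human';

const human = new Human();
// true in browsers where IndexedDB exists, false in NodeJS, regardless of the passed userConfig
console.log('model caching enabled:', human.config.cacheModels);
```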

View File

@ -2,7 +2,7 @@
* Anti-spoofing model implementation
*/
import { log, join, now } from '../util/util';
import { log, now } from '../util/util';
import type { Config } from '../config';
import type { GraphModel, Tensor } from '../tfjs/types';
import * as tf from '../../dist/tfjs.esm.js';
@ -17,11 +17,8 @@ let lastTime = 0;
export async function load(config: Config): Promise<GraphModel> {
if (env.initial) model = null;
if (!model) {
model = await loadModel(join(config.modelBasePath, config.face.antispoof?.modelPath || '')) as unknown as GraphModel;
if (!model || !model['modelUrl']) log('load model failed:', config.face.antispoof?.modelPath);
else if (config.debug) log('load model:', model['modelUrl']);
} else if (config.debug) log('cached model:', model['modelUrl']);
if (!model) model = await loadModel(config.face.antispoof?.modelPath);
else if (config.debug) log('cached model:', model['modelUrl']);
return model;
}

View File

@ -3,7 +3,7 @@
* See `facemesh.ts` for entry point
*/
import { log, join } from '../util/util';
import { log } from '../util/util';
import * as tf from '../../dist/tfjs.esm.js';
import * as util from './facemeshutil';
import { loadModel } from '../tfjs/load';
@ -26,11 +26,8 @@ export const size = () => inputSize;
export async function load(config: Config): Promise<GraphModel> {
if (env.initial) model = null;
if (!model) {
model = await loadModel(join(config.modelBasePath, config.face.detector?.modelPath || '')) as unknown as GraphModel;
if (!model || !model['modelUrl']) log('load model failed:', config.face.detector?.modelPath);
else if (config.debug) log('load model:', model['modelUrl']);
} else if (config.debug) log('cached model:', model['modelUrl']);
if (!model) model = await loadModel(config.face.detector?.modelPath);
else if (config.debug) log('cached model:', model['modelUrl']);
inputSize = model.inputs[0].shape ? model.inputs[0].shape[2] : 0;
inputSizeT = tf.scalar(inputSize, 'int32') as Tensor;
anchors = tf.tensor2d(util.generateAnchors(inputSize)) as Tensor;

View File

@ -7,7 +7,7 @@
* - Eye Iris Details: [**MediaPipe Iris**](https://drive.google.com/file/d/1bsWbokp9AklH2ANjCfmjqEzzxO1CNbMu/view)
*/
import { log, join, now } from '../util/util';
import { log, now } from '../util/util';
import { loadModel } from '../tfjs/load';
import * as tf from '../../dist/tfjs.esm.js';
import * as blazeface from './blazeface';
@ -111,11 +111,8 @@ export async function predict(input: Tensor, config: Config): Promise<FaceResult
export async function load(config: Config): Promise<GraphModel> {
if (env.initial) model = null;
if (!model) {
model = await loadModel(join(config.modelBasePath, config.face.mesh?.modelPath || '')) as unknown as GraphModel;
if (!model || !model['modelUrl']) log('load model failed:', config.face.mesh?.modelPath);
else if (config.debug) log('load model:', model['modelUrl']);
} else if (config.debug) log('cached model:', model['modelUrl']);
if (!model) model = await loadModel(config.face.mesh?.modelPath);
else if (config.debug) log('cached model:', model['modelUrl']);
inputSize = model.inputs[0].shape ? model.inputs[0].shape[2] : 0;
return model;
}

View File

@ -7,7 +7,7 @@
* Based on: [**HSE-FaceRes**](https://github.com/HSE-asavchenko/HSE_FaceRec_tf)
*/
import { log, join, now } from '../util/util';
import { log, now } from '../util/util';
import { env } from '../util/env';
import * as tf from '../../dist/tfjs.esm.js';
import { loadModel } from '../tfjs/load';
@ -31,13 +31,9 @@ let lastCount = 0;
let skipped = Number.MAX_SAFE_INTEGER;
export async function load(config: Config): Promise<GraphModel> {
const modelUrl = join(config.modelBasePath, config.face.description?.modelPath || '');
if (env.initial) model = null;
if (!model) {
model = await loadModel(modelUrl) as unknown as GraphModel;
if (!model) log('load model failed:', config.face.description?.modelPath || '');
else if (config.debug) log('load model:', modelUrl);
} else if (config.debug) log('cached model:', modelUrl);
if (!model) model = await loadModel(config.face.description?.modelPath);
else if (config.debug) log('cached model:', model['modelUrl']);
return model;
}

View File

@ -3,7 +3,7 @@ import * as util from './facemeshutil';
import * as tf from '../../dist/tfjs.esm.js';
import type { Tensor, GraphModel } from '../tfjs/types';
import { env } from '../util/env';
import { log, join } from '../util/util';
import { log } from '../util/util';
import { loadModel } from '../tfjs/load';
import type { Config } from '../config';
import type { Point } from '../result';
@ -30,11 +30,8 @@ const irisLandmarks = {
export async function load(config: Config): Promise<GraphModel> {
if (env.initial) model = null;
if (!model) {
model = await loadModel(join(config.modelBasePath, config.face.iris?.modelPath || ''));
if (!model || !model['modelUrl']) log('load model failed:', config.face.iris?.modelPath);
else if (config.debug) log('load model:', model['modelUrl']);
} else if (config.debug) log('cached model:', model['modelUrl']);
if (!model) model = await loadModel(config.face.iris?.modelPath);
else if (config.debug) log('cached model:', model['modelUrl']);
inputSize = model.inputs[0].shape ? model.inputs[0].shape[2] : 0;
if (inputSize === -1) inputSize = 64;
return model;

View File

@ -2,7 +2,8 @@
* Anti-spoofing model implementation
*/
import { log, join, now } from '../util/util';
import { log, now } from '../util/util';
import { loadModel } from '../tfjs/load';
import type { Config } from '../config';
import type { GraphModel, Tensor } from '../tfjs/types';
import * as tf from '../../dist/tfjs.esm.js';
@ -16,11 +17,8 @@ let lastTime = 0;
export async function load(config: Config): Promise<GraphModel> {
if (env.initial) model = null;
if (!model) {
model = await loadModel(join(config.modelBasePath, config.face.liveness?.modelPath || '')) as unknown as GraphModel;
if (!model || !model['modelUrl']) log('load model failed:', config.face.liveness?.modelPath);
else if (config.debug) log('load model:', model['modelUrl']);
} else if (config.debug) log('cached model:', model['modelUrl']);
if (!model) model = await loadModel(config.face.liveness?.modelPath);
else if (config.debug) log('cached model:', model['modelUrl']);
return model;
}

View File

@ -6,7 +6,7 @@
* Obsolete and replaced by `faceres` that performs age/gender/descriptor analysis
*/
import { log, join, now } from '../util/util';
import { log, now } from '../util/util';
import * as tf from '../../dist/tfjs.esm.js';
import { loadModel } from '../tfjs/load';
import type { Tensor, GraphModel } from '../tfjs/types';
@ -20,13 +20,9 @@ let lastTime = 0;
let skipped = Number.MAX_SAFE_INTEGER;
export async function load(config: Config): Promise<GraphModel> {
const modelUrl = join(config.modelBasePath, config.face['mobilefacenet'].modelPath);
if (env.initial) model = null;
if (!model) {
model = await loadModel(modelUrl) as unknown as GraphModel;
if (!model) log('load model failed:', config.face['mobilefacenet'].modelPath);
else if (config.debug) log('load model:', modelUrl);
} else if (config.debug) log('cached model:', modelUrl);
if (!model) model = await loadModel(config.face['mobilefacenet'].modelPath);
else if (config.debug) log('cached model:', model['modelUrl']);
return model;
}

View File

@ -5,7 +5,7 @@
*/
import type { Emotion } from '../result';
import { log, join, now } from '../util/util';
import { log, now } from '../util/util';
import type { Config } from '../config';
import type { GraphModel, Tensor } from '../tfjs/types';
import * as tf from '../../dist/tfjs.esm.js';
@ -22,11 +22,8 @@ let skipped = Number.MAX_SAFE_INTEGER;
export async function load(config: Config): Promise<GraphModel> {
if (env.initial) model = null;
if (!model) {
model = await loadModel(join(config.modelBasePath, config.face.emotion?.modelPath || '')) as unknown as GraphModel;
if (!model || !model['modelUrl']) log('load model failed:', config.face.emotion?.modelPath);
else if (config.debug) log('load model:', model['modelUrl']);
} else if (config.debug) log('cached model:', model['modelUrl']);
if (!model) model = await loadModel(config.face.emotion?.modelPath);
else if (config.debug) log('cached model:', model['modelUrl']);
return model;
}

View File

@ -4,7 +4,7 @@
* Based on: [**GEAR Predictor**](https://github.com/Udolf15/GEAR-Predictor)
*/
import { log, join, now } from '../util/util';
import { log, now } from '../util/util';
import * as tf from '../../dist/tfjs.esm.js';
import { loadModel } from '../tfjs/load';
import type { Gender, Race } from '../result';
@ -24,11 +24,8 @@ let skipped = Number.MAX_SAFE_INTEGER;
// eslint-disable-next-line @typescript-eslint/no-explicit-any
export async function load(config: Config) {
if (env.initial) model = null;
if (!model) {
model = await loadModel(join(config.modelBasePath, config.face['gear'].modelPath)) as unknown as GraphModel;
if (!model || !model['modelUrl']) log('load model failed:', config.face['gear'].modelPath);
else if (config.debug) log('load model:', model['modelUrl']);
} else if (config.debug) log('cached model:', model['modelUrl']);
if (!model) model = await loadModel(config.face['gear'].modelPath);
else if (config.debug) log('cached model:', model['modelUrl']);
return model;
}

View File

@ -4,7 +4,7 @@
* Based on: [**SSR-Net**](https://github.com/shamangary/SSR-Net)
*/
import { log, join, now } from '../util/util';
import { log, now } from '../util/util';
import * as tf from '../../dist/tfjs.esm.js';
import { loadModel } from '../tfjs/load';
import { env } from '../util/env';
@ -21,13 +21,8 @@ let skipped = Number.MAX_SAFE_INTEGER;
// eslint-disable-next-line @typescript-eslint/no-explicit-any
export async function load(config: Config) {
if (env.initial) model = null;
if (!model) {
model = await loadModel(join(config.modelBasePath, config.face['ssrnet'].modelPathAge)) as unknown as GraphModel;
if (!model || !model['modelUrl']) log('load model failed:', config.face['ssrnet'].modelPathAge);
else if (config.debug) log('load model:', model['modelUrl']);
} else {
if (config.debug) log('cached model:', model['modelUrl']);
}
if (!model) model = await loadModel(config.face['ssrnet'].modelPathAge);
else if (config.debug) log('cached model:', model['modelUrl']);
return model;
}

View File

@ -4,7 +4,7 @@
* Based on: [**SSR-Net**](https://github.com/shamangary/SSR-Net)
*/
import { log, join, now } from '../util/util';
import { log, now } from '../util/util';
import * as tf from '../../dist/tfjs.esm.js';
import { loadModel } from '../tfjs/load';
import { constants } from '../tfjs/constants';
@ -25,11 +25,8 @@ const rgb = [0.2989, 0.5870, 0.1140]; // factors for red/green/blue colors when
// eslint-disable-next-line @typescript-eslint/no-explicit-any
export async function load(config: Config | any) {
if (env.initial) model = null;
if (!model) {
model = await loadModel(join(config.modelBasePath, config.face['ssrnet'].modelPathGender)) as unknown as GraphModel;
if (!model || !model['modelUrl']) log('load model failed:', config.face['ssrnet'].modelPathGender);
else if (config.debug) log('load model:', model['modelUrl']);
} else if (config.debug) log('cached model:', model['modelUrl']);
if (!model) model = await loadModel(config.face['ssrnet'].modelPathGender);
else if (config.debug) log('cached model:', model['modelUrl']);
return model;
}

View File

@ -4,7 +4,7 @@
* Based on: [**MediaPipe HandPose**](https://drive.google.com/file/d/1sv4sSb9BSNVZhLzxXJ0jBv9DqD-4jnAz/view)
*/
import { log, join } from '../util/util';
import { log } from '../util/util';
import * as handdetector from './handposedetector';
import * as handpipeline from './handposepipeline';
import * as fingerPose from './fingerpose';
@ -89,15 +89,9 @@ export async function load(config: Config): Promise<[GraphModel | null, GraphMod
}
if (!handDetectorModel || !handPoseModel) {
[handDetectorModel, handPoseModel] = await Promise.all([
config.hand.enabled ? loadModel(join(config.modelBasePath, config.hand.detector?.modelPath || '')) as unknown as GraphModel : null,
config.hand.landmarks ? loadModel(join(config.modelBasePath, config.hand.skeleton?.modelPath || '')) as unknown as GraphModel : null,
config.hand.enabled ? loadModel(config.hand.detector?.modelPath) : null,
config.hand.landmarks ? loadModel(config.hand.skeleton?.modelPath) : null,
]);
if (config.hand.enabled) {
if (!handDetectorModel || !handDetectorModel['modelUrl']) log('load model failed:', config.hand.detector?.modelPath || '');
else if (config.debug) log('load model:', handDetectorModel['modelUrl']);
if (!handPoseModel || !handPoseModel['modelUrl']) log('load model failed:', config.hand.skeleton?.modelPath || '');
else if (config.debug) log('load model:', handPoseModel['modelUrl']);
}
} else {
if (config.debug) log('cached model:', handDetectorModel['modelUrl']);
if (config.debug) log('cached model:', handPoseModel['modelUrl']);

View File

@ -6,7 +6,7 @@
* - Hand Tracking: [**HandTracking**](https://github.com/victordibia/handtracking)
*/
import { log, join, now } from '../util/util';
import { log, now } from '../util/util';
import * as box from '../util/box';
import * as tf from '../../dist/tfjs.esm.js';
import { loadModel } from '../tfjs/load';
@ -75,12 +75,10 @@ export async function loadDetect(config: Config): Promise<GraphModel> {
// handtrack model has some kernel ops defined in model but those are never referenced and non-existent in tfjs
// ideally need to prune the model itself
fakeOps(['tensorlistreserve', 'enter', 'tensorlistfromtensor', 'merge', 'loopcond', 'switch', 'exit', 'tensorliststack', 'nextiteration', 'tensorlistsetitem', 'tensorlistgetitem', 'reciprocal', 'shape', 'split', 'where'], config);
models[0] = await loadModel(join(config.modelBasePath, config.hand.detector?.modelPath || '')) as unknown as GraphModel;
models[0] = await loadModel(config.hand.detector?.modelPath);
const inputs = Object.values(models[0].modelSignature['inputs']);
inputSize[0][0] = Array.isArray(inputs) ? parseInt(inputs[0].tensorShape.dim[1].size) : 0;
inputSize[0][1] = Array.isArray(inputs) ? parseInt(inputs[0].tensorShape.dim[2].size) : 0;
if (!models[0] || !models[0]['modelUrl']) log('load model failed:', config.hand.detector?.modelPath);
else if (config.debug) log('load model:', models[0]['modelUrl']);
} else if (config.debug) log('cached model:', models[0]['modelUrl']);
return models[0];
}
@ -88,12 +86,10 @@ export async function loadDetect(config: Config): Promise<GraphModel> {
export async function loadSkeleton(config: Config): Promise<GraphModel> {
if (env.initial) models[1] = null;
if (!models[1]) {
models[1] = await loadModel(join(config.modelBasePath, config.hand.skeleton?.modelPath || '')) as unknown as GraphModel;
models[1] = await loadModel(config.hand.skeleton?.modelPath);
const inputs = Object.values(models[1].modelSignature['inputs']);
inputSize[1][0] = Array.isArray(inputs) ? parseInt(inputs[0].tensorShape.dim[1].size) : 0;
inputSize[1][1] = Array.isArray(inputs) ? parseInt(inputs[0].tensorShape.dim[2].size) : 0;
if (!models[1] || !models[1]['modelUrl']) log('load model failed:', config.hand.skeleton?.modelPath);
else if (config.debug) log('load model:', models[1]['modelUrl']);
} else if (config.debug) log('cached model:', models[1]['modelUrl']);
return models[1];
}

View File

@ -11,6 +11,7 @@
import { log, now, mergeDeep, validate } from './util/util';
import { defaults } from './config';
import { env, Env } from './util/env';
import { setModelLoadOptions } from './tfjs/load';
import * as tf from '../dist/tfjs.esm.js';
import * as app from '../package.json';
import * as backend from './tfjs/backend';
@ -134,6 +135,8 @@ export class Human {
this.config = JSON.parse(JSON.stringify(defaults));
Object.seal(this.config);
if (userConfig) this.config = mergeDeep(this.config, userConfig);
this.config.cacheModels = typeof indexedDB !== 'undefined';
setModelLoadOptions(this.config);
this.tf = tf;
this.state = 'idle';
this.#numTensors = 0;
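
The ordering in this hunk matters: defaults are deep-copied, user config is merged, `cacheModels` is then overwritten from IndexedDB availability, and finally `setModelLoadOptions` snapshots the result into the loader module. A condensed restatement of those four lines with comments (not the full constructor):

```ts
this.config = JSON.parse(JSON.stringify(defaults));               // deep-copy defaults
if (userConfig) this.config = mergeDeep(this.config, userConfig); // apply user overrides
this.config.cacheModels = typeof indexedDB !== 'undefined';       // environment detection wins here
setModelLoadOptions(this.config);                                 // loader captures cacheModels/debug/modelBasePath now
```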

View File

@ -4,7 +4,7 @@
* Based on: [**NanoDet**](https://github.com/RangiLyu/nanodet)
*/
import { log, join, now } from '../util/util';
import { log, now } from '../util/util';
import * as tf from '../../dist/tfjs.esm.js';
import { loadModel } from '../tfjs/load';
import { labels } from './labels';
@ -23,11 +23,9 @@ export async function load(config: Config): Promise<GraphModel> {
if (env.initial) model = null;
if (!model) {
// fakeOps(['floormod'], config);
model = await loadModel(join(config.modelBasePath, config.object.modelPath || '')) as unknown as GraphModel;
model = await loadModel(config.object.modelPath);
const inputs = Object.values(model.modelSignature['inputs']);
inputSize = Array.isArray(inputs) ? parseInt(inputs[0].tensorShape.dim[2].size) : 0;
if (!model || !model['modelUrl']) log('load model failed:', config.object.modelPath);
else if (config.debug) log('load model:', model['modelUrl']);
} else if (config.debug) log('cached model:', model['modelUrl']);
return model;
}

View File

@ -4,7 +4,7 @@
* Based on: [**MB3-CenterNet**](https://github.com/610265158/mobilenetv3_centernet)
*/
import { log, join, now } from '../util/util';
import { log, now } from '../util/util';
import * as tf from '../../dist/tfjs.esm.js';
import { loadModel } from '../tfjs/load';
import { constants } from '../tfjs/constants';
@ -24,11 +24,9 @@ const scaleBox = 2.5; // increase box size
export async function load(config: Config): Promise<GraphModel> {
if (!model || env.initial) {
model = await loadModel(join(config.modelBasePath, config.object.modelPath || '')) as unknown as GraphModel;
model = await loadModel(config.object.modelPath);
const inputs = Object.values(model.modelSignature['inputs']);
inputSize = Array.isArray(inputs) ? parseInt(inputs[0].tensorShape.dim[2].size) : 0;
if (!model || !model['modelUrl']) log('load model failed:', config.object.modelPath);
else if (config.debug) log('load model:', model['modelUrl']);
} else if (config.debug) log('cached model:', model['modelUrl']);
return model;
}

View File

@ -1,45 +1,57 @@
import { log, mergeDeep } from '../util/util';
import { log, join } from '../util/util';
import * as tf from '../../dist/tfjs.esm.js';
import type { GraphModel } from './types';
import type { Config } from '../config';
type FetchFunc = (url: RequestInfo, init?: RequestInit) => Promise<Response>;
type ProgressFunc = (...args) => void;
export type LoadOptions = {
appName: string,
autoSave: boolean,
verbose: boolean,
fetchFunc?: FetchFunc,
onProgress?: ProgressFunc,
}
let options: LoadOptions = {
appName: 'human',
autoSave: true,
const options = {
cacheModels: false,
verbose: true,
debug: false,
modelBasePath: '',
};
async function httpHandler(url: RequestInfo, init?: RequestInit): Promise<Response | null> {
if (options.fetchFunc) return options.fetchFunc(url, init);
else log('error: fetch function is not defined');
return null;
async function httpHandler(url, init?): Promise<Response | null> {
if (options.debug) log('load model fetch:', url, init);
if (typeof fetch === 'undefined') {
log('error loading model: fetch function is not defined');
return null;
}
return fetch(url, init);
}
const tfLoadOptions = {
onProgress: (...args) => {
if (options.onProgress) options.onProgress(...args);
else if (options.verbose) log('load model progress:', ...args);
},
fetchFunc: (url: RequestInfo, init?: RequestInit) => {
if (options.verbose) log('load model fetch:', url, init);
if (url.toString().toLowerCase().startsWith('http')) return httpHandler(url, init);
return null;
},
};
export function setModelLoadOptions(config: Config) {
options.cacheModels = config.cacheModels;
options.verbose = config.debug;
options.modelBasePath = config.modelBasePath;
}
export async function loadModel(modelUrl: string, loadOptions?: LoadOptions): Promise<GraphModel> {
if (loadOptions) options = mergeDeep(loadOptions);
if (!options.fetchFunc && (typeof globalThis.fetch !== 'undefined')) options.fetchFunc = globalThis.fetch;
const model = await tf.loadGraphModel(modelUrl, tfLoadOptions) as unknown as GraphModel;
export async function loadModel(modelPath: string | undefined): Promise<GraphModel> {
const modelUrl = join(options.modelBasePath, modelPath || '');
const modelPathSegments = modelUrl.split('/');
const cachedModelName = 'indexeddb://' + modelPathSegments[modelPathSegments.length - 1].replace('.json', ''); // generate short model name for cache
const cachedModels = await tf.io.listModels(); // list all models already in cache
const modelCached = options.cacheModels && Object.keys(cachedModels).includes(cachedModelName); // is model found in cache
// create model prototype and decide whether to load from the cache or from the original modelUrl
const model: GraphModel = new tf.GraphModel(modelCached ? cachedModelName : modelUrl, { fetchFunc: (url, init?) => httpHandler(url, init) }) as unknown as GraphModel;
try {
// @ts-ignore private function
model.findIOHandler(); // decide how to actually load a model
// @ts-ignore private property
if (options.debug) log('model load handler:', model.handler);
// @ts-ignore private property
const artifacts = await model.handler.load(); // load manifest
model.loadSync(artifacts); // load weights
if (options.verbose) log('load model:', model['modelUrl']);
} catch (err) {
log('error loading model:', modelUrl, err);
}
if (options.cacheModels && !modelCached) { // save model to cache
try {
const saveResult = await model.save(cachedModelName);
log('model saved:', cachedModelName, saveResult);
} catch (err) {
log('error saving model:', modelUrl, err);
}
}
return model;
}
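
The cache key is just the manifest filename with `.json` stripped, prefixed with the `indexeddb://` scheme, so entries can be inspected or evicted with stock TFJS io utilities. A hedged sketch (the `movenet-lightning` key is a hypothetical example; substitute whichever models you load):

```ts
import * as tf from '@tensorflow/tfjs';

async function inspectModelCache(): Promise<void> {
  // list everything saved so far; keys look like 'indexeddb://movenet-lightning'
  const cached = await tf.io.listModels();
  console.log('cached models:', Object.keys(cached));
  // evict one entry so the next loadModel() call fetches it over HTTP and re-saves it
  await tf.io.removeModel('indexeddb://movenet-lightning'); // hypothetical cache key
}
```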