add face.mesh.keepInvalid config flag

pull/293/head
Vladimir Mandic 2022-05-22 08:50:51 -04:00
parent cc18f16b2e
commit 465176e2dd
9 changed files with 833 additions and 740 deletions

View File

@ -1,19 +1,27 @@
const fs = require('fs'); const fs = require('fs');
const process = require('process');
// eslint-disable-next-line import/no-extraneous-dependencies, no-unused-vars, @typescript-eslint/no-unused-vars // eslint-disable-next-line import/no-extraneous-dependencies, no-unused-vars, @typescript-eslint/no-unused-vars
const tf = require('@tensorflow/tfjs-node'); // in nodejs environments tfjs-node is required to be loaded before human const tf = require('@tensorflow/tfjs-node'); // in nodejs environments tfjs-node is required to be loaded before human
// const faceapi = require('@vladmandic/face-api'); // use this when human is installed as module (majority of use cases) // const faceapi = require('@vladmandic/face-api'); // use this when human is installed as module (majority of use cases)
const Human = require('../../dist/human.node.js'); // use this when using human in dev mode const Human = require('../../dist/human.node.js'); // use this when using human in dev mode
async function main(inputFile) { const humanConfig = {
const human = new Human.Human(); // create instance of human using default configuration // add any custom config here
};
async function detect(inputFile) {
const human = new Human.Human(humanConfig); // create instance of human using default configuration
await human.load(); // optional as models would be loaded on-demand first time they are required await human.load(); // optional as models would be loaded on-demand first time they are required
await human.warmup(); // optional as model warmup is performed on-demand first time its executed await human.warmup(); // optional as model warmup is performed on-demand first time its executed
const buffer = fs.readFileSync(inputFile); // read file data into buffer const buffer = fs.readFileSync(inputFile); // read file data into buffer
const tensor = human.tf.node.decodeImage(buffer); // decode jpg data const tensor = human.tf.node.decodeImage(buffer); // decode jpg data
// eslint-disable-next-line no-console
console.log('loaded input file:', inputFile, 'resolution:', tensor.shape);
const result = await human.detect(tensor); // run detection; will initialize backend and on-demand load models const result = await human.detect(tensor); // run detection; will initialize backend and on-demand load models
// eslint-disable-next-line no-console // eslint-disable-next-line no-console
console.log(result); console.log(result);
} }
main('samples/in/ai-body.jpg'); if (process.argv.length === 3) detect(process.argv[2]); // if input file is provided as cmdline parameter use it
else detect('samples/in/ai-body.jpg'); // else use built-in test inputfile

View File

@ -53,19 +53,19 @@
"tensorflow" "tensorflow"
], ],
"devDependencies": { "devDependencies": {
"@microsoft/api-extractor": "^7.24.0", "@microsoft/api-extractor": "^7.24.1",
"@tensorflow/tfjs": "^3.17.0", "@tensorflow/tfjs": "^3.18.0",
"@tensorflow/tfjs-backend-cpu": "^3.17.0", "@tensorflow/tfjs-backend-cpu": "^3.18.0",
"@tensorflow/tfjs-backend-wasm": "^3.17.0", "@tensorflow/tfjs-backend-wasm": "^3.18.0",
"@tensorflow/tfjs-backend-webgl": "^3.17.0", "@tensorflow/tfjs-backend-webgl": "^3.18.0",
"@tensorflow/tfjs-backend-webgpu": "0.0.1-alpha.10", "@tensorflow/tfjs-backend-webgpu": "0.0.1-alpha.11",
"@tensorflow/tfjs-converter": "^3.17.0", "@tensorflow/tfjs-converter": "^3.18.0",
"@tensorflow/tfjs-core": "^3.17.0", "@tensorflow/tfjs-core": "^3.18.0",
"@tensorflow/tfjs-data": "^3.17.0", "@tensorflow/tfjs-data": "^3.18.0",
"@tensorflow/tfjs-layers": "^3.17.0", "@tensorflow/tfjs-layers": "^3.18.0",
"@tensorflow/tfjs-node": "^3.17.0", "@tensorflow/tfjs-node": "^3.18.0",
"@tensorflow/tfjs-node-gpu": "^3.17.0", "@tensorflow/tfjs-node-gpu": "^3.18.0",
"@types/node": "^17.0.34", "@types/node": "^17.0.35",
"@types/offscreencanvas": "^2019.6.4", "@types/offscreencanvas": "^2019.6.4",
"@typescript-eslint/eslint-plugin": "^5.25.0", "@typescript-eslint/eslint-plugin": "^5.25.0",
"@typescript-eslint/parser": "^5.25.0", "@typescript-eslint/parser": "^5.25.0",
@ -75,7 +75,7 @@
"canvas": "^2.9.1", "canvas": "^2.9.1",
"dayjs": "^1.11.2", "dayjs": "^1.11.2",
"esbuild": "^0.14.39", "esbuild": "^0.14.39",
"eslint": "8.15.0", "eslint": "8.16.0",
"eslint-config-airbnb-base": "^15.0.0", "eslint-config-airbnb-base": "^15.0.0",
"eslint-plugin-html": "^6.2.0", "eslint-plugin-html": "^6.2.0",
"eslint-plugin-import": "^2.26.0", "eslint-plugin-import": "^2.26.0",

View File

@ -35,7 +35,10 @@ export interface FaceDetectorConfig extends GenericConfig {
} }
/** Mesh part of face configuration */ /** Mesh part of face configuration */
export interface FaceMeshConfig extends GenericConfig {} export interface FaceMeshConfig extends GenericConfig {
/** Keep detected faces that cannot be verified using facemesh */
keepInvalid: boolean
}
/** Iris part of face configuration */ /** Iris part of face configuration */
export interface FaceIrisConfig extends GenericConfig {} export interface FaceIrisConfig extends GenericConfig {}
@ -352,6 +355,7 @@ const config: Config = {
mesh: { mesh: {
enabled: true, enabled: true,
modelPath: 'facemesh.json', modelPath: 'facemesh.json',
keepInvalid: false,
}, },
attention: { attention: {
enabled: false, enabled: false,

View File

@ -94,6 +94,19 @@ export async function predict(input: Tensor, config: Config): Promise<FaceResult
let rawCoords = await coordsReshaped.array(); let rawCoords = await coordsReshaped.array();
if (face.faceScore < (config.face.detector?.minConfidence || 1)) { // low confidence in detected mesh if (face.faceScore < (config.face.detector?.minConfidence || 1)) { // low confidence in detected mesh
box.confidence = face.faceScore; // reset confidence of cached box box.confidence = face.faceScore; // reset confidence of cached box
if (config.face.mesh?.keepInvalid) {
face.box = util.clampBox(box, input);
face.boxRaw = util.getRawBox(box, input);
face.score = face.boxScore;
face.mesh = box.landmarks.map((pt) => [
((box.startPoint[0] + box.endPoint[0])) / 2 + ((box.endPoint[0] + box.startPoint[0]) * pt[0] / blazeface.size()),
((box.startPoint[1] + box.endPoint[1])) / 2 + ((box.endPoint[1] + box.startPoint[1]) * pt[1] / blazeface.size()),
]);
face.meshRaw = face.mesh.map((pt) => [pt[0] / (input.shape[2] || 0), pt[1] / (input.shape[1] || 0), (pt[2] || 0) / inputSize]);
for (const key of Object.keys(coords.blazeFaceLandmarks)) {
face.annotations[key] = [face.mesh[coords.blazeFaceLandmarks[key] as number]]; // add annotations
}
}
} else { } else {
if (config.face.attention?.enabled) { if (config.face.attention?.enabled) {
rawCoords = await attention.augment(rawCoords, results); // augment iris results using attention model results rawCoords = await attention.augment(rawCoords, results); // augment iris results using attention model results

View File

@ -121,7 +121,9 @@ export const rotatePoint = (homogeneousCoordinate, rotationMatrix) => [dot(homog
export const xyDistanceBetweenPoints = (a, b) => Math.sqrt(((a[0] - b[0]) ** 2) + ((a[1] - b[1]) ** 2)); export const xyDistanceBetweenPoints = (a, b) => Math.sqrt(((a[0] - b[0]) ** 2) + ((a[1] - b[1]) ** 2));
export function generateAnchors(inputSize) { export function generateAnchors(inputSize) {
const spec = { strides: [inputSize / 16, inputSize / 8], anchors: [2, 6] }; const spec = inputSize === 192
? { strides: [4], anchors: [1] } // facemesh-detector
: { strides: [inputSize / 16, inputSize / 8], anchors: [2, 6] }; // blazeface
const anchors: Array<[number, number]> = []; const anchors: Array<[number, number]> = [];
for (let i = 0; i < spec.strides.length; i++) { for (let i = 0; i < spec.strides.length; i++) {
const stride = spec.strides[i]; const stride = spec.strides[i];

View File

@ -22,7 +22,8 @@ export function setModelLoadOptions(config: Config) {
} }
export async function loadModel(modelPath: string | undefined): Promise<GraphModel> { export async function loadModel(modelPath: string | undefined): Promise<GraphModel> {
const modelUrl = join(options.modelBasePath, modelPath || ''); let modelUrl = join(options.modelBasePath, modelPath || '');
if (!modelUrl.toLowerCase().endsWith('.json')) modelUrl += '.json';
const modelPathSegments = modelUrl.split('/'); const modelPathSegments = modelUrl.split('/');
const cachedModelName = 'indexeddb://' + modelPathSegments[modelPathSegments.length - 1].replace('.json', ''); // generate short model name for cache const cachedModelName = 'indexeddb://' + modelPathSegments[modelPathSegments.length - 1].replace('.json', ''); // generate short model name for cache
const cachedModels = await tf.io.listModels(); // list all models already in cache const cachedModels = await tf.io.listModels(); // list all models already in cache

View File

@ -1,24 +1,24 @@
2022-05-18 17:41:21 INFO:  Application: {"name":"@vladmandic/human","version":"2.7.2"} 2022-05-22 08:49:40 INFO:  Application: {"name":"@vladmandic/human","version":"2.7.2"}
2022-05-18 17:41:21 INFO:  Environment: {"profile":"production","config":".build.json","package":"package.json","tsconfig":true,"eslintrc":true,"git":true} 2022-05-22 08:49:40 INFO:  Environment: {"profile":"production","config":".build.json","package":"package.json","tsconfig":true,"eslintrc":true,"git":true}
2022-05-18 17:41:21 INFO:  Toolchain: {"build":"0.7.3","esbuild":"0.14.39","typescript":"4.6.4","typedoc":"0.22.15","eslint":"8.15.0"} 2022-05-22 08:49:40 INFO:  Toolchain: {"build":"0.7.3","esbuild":"0.14.39","typescript":"4.6.4","typedoc":"0.22.15","eslint":"8.16.0"}
2022-05-18 17:41:21 INFO:  Build: {"profile":"production","steps":["clean","compile","typings","typedoc","lint","changelog"]} 2022-05-22 08:49:40 INFO:  Build: {"profile":"production","steps":["clean","compile","typings","typedoc","lint","changelog"]}
2022-05-18 17:41:21 STATE: Clean: {"locations":["dist/*","types/lib/*","typedoc/*"]} 2022-05-22 08:49:40 STATE: Clean: {"locations":["dist/*","types/lib/*","typedoc/*"]}
2022-05-18 17:41:21 STATE: Compile: {"name":"tfjs/nodejs/cpu","format":"cjs","platform":"node","input":"tfjs/tf-node.ts","output":"dist/tfjs.esm.js","files":1,"inputBytes":102,"outputBytes":595} 2022-05-22 08:49:40 STATE: Compile: {"name":"tfjs/nodejs/cpu","format":"cjs","platform":"node","input":"tfjs/tf-node.ts","output":"dist/tfjs.esm.js","files":1,"inputBytes":102,"outputBytes":595}
2022-05-18 17:41:21 STATE: Compile: {"name":"human/nodejs/cpu","format":"cjs","platform":"node","input":"src/human.ts","output":"dist/human.node.js","files":72,"inputBytes":606782,"outputBytes":297946} 2022-05-22 08:49:40 STATE: Compile: {"name":"human/nodejs/cpu","format":"cjs","platform":"node","input":"src/human.ts","output":"dist/human.node.js","files":72,"inputBytes":607902,"outputBytes":298472}
2022-05-18 17:41:21 STATE: Compile: {"name":"tfjs/nodejs/gpu","format":"cjs","platform":"node","input":"tfjs/tf-node-gpu.ts","output":"dist/tfjs.esm.js","files":1,"inputBytes":110,"outputBytes":599} 2022-05-22 08:49:40 STATE: Compile: {"name":"tfjs/nodejs/gpu","format":"cjs","platform":"node","input":"tfjs/tf-node-gpu.ts","output":"dist/tfjs.esm.js","files":1,"inputBytes":110,"outputBytes":599}
2022-05-18 17:41:21 STATE: Compile: {"name":"human/nodejs/gpu","format":"cjs","platform":"node","input":"src/human.ts","output":"dist/human.node-gpu.js","files":72,"inputBytes":606786,"outputBytes":297950} 2022-05-22 08:49:40 STATE: Compile: {"name":"human/nodejs/gpu","format":"cjs","platform":"node","input":"src/human.ts","output":"dist/human.node-gpu.js","files":72,"inputBytes":607906,"outputBytes":298476}
2022-05-18 17:41:21 STATE: Compile: {"name":"tfjs/nodejs/wasm","format":"cjs","platform":"node","input":"tfjs/tf-node-wasm.ts","output":"dist/tfjs.esm.js","files":1,"inputBytes":149,"outputBytes":651} 2022-05-22 08:49:40 STATE: Compile: {"name":"tfjs/nodejs/wasm","format":"cjs","platform":"node","input":"tfjs/tf-node-wasm.ts","output":"dist/tfjs.esm.js","files":1,"inputBytes":149,"outputBytes":651}
2022-05-18 17:41:21 STATE: Compile: {"name":"human/nodejs/wasm","format":"cjs","platform":"node","input":"src/human.ts","output":"dist/human.node-wasm.js","files":72,"inputBytes":606838,"outputBytes":298000} 2022-05-22 08:49:40 STATE: Compile: {"name":"human/nodejs/wasm","format":"cjs","platform":"node","input":"src/human.ts","output":"dist/human.node-wasm.js","files":72,"inputBytes":607958,"outputBytes":298526}
2022-05-18 17:41:21 STATE: Compile: {"name":"tfjs/browser/version","format":"esm","platform":"browser","input":"tfjs/tf-version.ts","output":"dist/tfjs.version.js","files":1,"inputBytes":1069,"outputBytes":358} 2022-05-22 08:49:40 STATE: Compile: {"name":"tfjs/browser/version","format":"esm","platform":"browser","input":"tfjs/tf-version.ts","output":"dist/tfjs.version.js","files":1,"inputBytes":1069,"outputBytes":358}
2022-05-18 17:41:21 STATE: Compile: {"name":"tfjs/browser/esm/nobundle","format":"esm","platform":"browser","input":"tfjs/tf-browser.ts","output":"dist/tfjs.esm.js","files":2,"inputBytes":1032,"outputBytes":583} 2022-05-22 08:49:40 STATE: Compile: {"name":"tfjs/browser/esm/nobundle","format":"esm","platform":"browser","input":"tfjs/tf-browser.ts","output":"dist/tfjs.esm.js","files":2,"inputBytes":1032,"outputBytes":583}
2022-05-18 17:41:21 STATE: Compile: {"name":"human/browser/esm/nobundle","format":"esm","platform":"browser","input":"src/human.ts","output":"dist/human.esm-nobundle.js","files":72,"inputBytes":606770,"outputBytes":296859} 2022-05-22 08:49:40 STATE: Compile: {"name":"human/browser/esm/nobundle","format":"esm","platform":"browser","input":"src/human.ts","output":"dist/human.esm-nobundle.js","files":72,"inputBytes":607890,"outputBytes":297382}
2022-05-18 17:41:21 STATE: Compile: {"name":"tfjs/browser/esm/custom","format":"esm","platform":"browser","input":"tfjs/tf-custom.ts","output":"dist/tfjs.esm.js","files":1,"inputBytes":110,"outputBytes":1352584} 2022-05-22 08:49:40 STATE: Compile: {"name":"tfjs/browser/esm/custom","format":"esm","platform":"browser","input":"tfjs/tf-custom.ts","output":"dist/tfjs.esm.js","files":1,"inputBytes":110,"outputBytes":1352913}
2022-05-18 17:41:21 STATE: Compile: {"name":"human/browser/iife/bundle","format":"iife","platform":"browser","input":"src/human.ts","output":"dist/human.js","files":72,"inputBytes":1958771,"outputBytes":1648490} 2022-05-22 08:49:40 STATE: Compile: {"name":"human/browser/iife/bundle","format":"iife","platform":"browser","input":"src/human.ts","output":"dist/human.js","files":72,"inputBytes":1960220,"outputBytes":1649341}
2022-05-18 17:41:21 STATE: Compile: {"name":"human/browser/esm/bundle","format":"esm","platform":"browser","input":"src/human.ts","output":"dist/human.esm.js","files":72,"inputBytes":1958771,"outputBytes":2131466} 2022-05-22 08:49:40 STATE: Compile: {"name":"human/browser/esm/bundle","format":"esm","platform":"browser","input":"src/human.ts","output":"dist/human.esm.js","files":72,"inputBytes":1960220,"outputBytes":2132978}
2022-05-18 17:41:26 STATE: Typings: {"input":"src/human.ts","output":"types/lib","files":114} 2022-05-22 08:49:45 STATE: Typings: {"input":"src/human.ts","output":"types/lib","files":114}
2022-05-18 17:41:28 STATE: TypeDoc: {"input":"src/human.ts","output":"typedoc","objects":73,"generated":true} 2022-05-22 08:49:47 STATE: TypeDoc: {"input":"src/human.ts","output":"typedoc","objects":73,"generated":true}
2022-05-18 17:41:28 STATE: Compile: {"name":"demo/typescript","format":"esm","platform":"browser","input":"demo/typescript/index.ts","output":"demo/typescript/index.js","files":1,"inputBytes":5967,"outputBytes":2980} 2022-05-22 08:49:47 STATE: Compile: {"name":"demo/typescript","format":"esm","platform":"browser","input":"demo/typescript/index.ts","output":"demo/typescript/index.js","files":1,"inputBytes":5967,"outputBytes":2980}
2022-05-18 17:41:28 STATE: Compile: {"name":"demo/faceid","format":"esm","platform":"browser","input":"demo/faceid/index.ts","output":"demo/faceid/index.js","files":2,"inputBytes":15174,"outputBytes":7820} 2022-05-22 08:49:47 STATE: Compile: {"name":"demo/faceid","format":"esm","platform":"browser","input":"demo/faceid/index.ts","output":"demo/faceid/index.js","files":2,"inputBytes":15174,"outputBytes":7820}
2022-05-18 17:41:36 STATE: Lint: {"locations":["*.json","src/**/*.ts","test/**/*.js","demo/**/*.js"],"files":104,"errors":0,"warnings":0} 2022-05-22 08:49:55 STATE: Lint: {"locations":["*.json","src/**/*.ts","test/**/*.js","demo/**/*.js"],"files":104,"errors":0,"warnings":0}
2022-05-18 17:41:36 STATE: ChangeLog: {"repository":"https://github.com/vladmandic/human","branch":"main","output":"CHANGELOG.md"} 2022-05-22 08:49:56 STATE: ChangeLog: {"repository":"https://github.com/vladmandic/human","branch":"main","output":"CHANGELOG.md"}
2022-05-18 17:41:36 INFO:  Done... 2022-05-22 08:49:56 INFO:  Done...

File diff suppressed because it is too large Load Diff

89
types/human.d.ts vendored
View File

@ -286,16 +286,12 @@ declare function copyModel(sourceURL: string, destURL: string): Promise<ModelArt
*/ */
declare type DataId = object; declare type DataId = object;
declare type DataToGPUOptions = DataToGPUWebGLOption | DataToGPUWebGPUOption; declare type DataToGPUOptions = DataToGPUWebGLOption;
declare interface DataToGPUWebGLOption { declare interface DataToGPUWebGLOption {
customTexShape?: [number, number]; customTexShape?: [number, number];
} }
declare interface DataToGPUWebGPUOption {
customBufSize?: number;
}
/** @docalias 'float32'|'int32'|'bool'|'complex64'|'string' */ /** @docalias 'float32'|'int32'|'bool'|'complex64'|'string' */
declare type DataType = keyof DataTypeMap; declare type DataType = keyof DataTypeMap;
@ -573,6 +569,8 @@ export declare interface FaceLivenessConfig extends GenericConfig {
/** Mesh part of face configuration */ /** Mesh part of face configuration */
export declare interface FaceMeshConfig extends GenericConfig { export declare interface FaceMeshConfig extends GenericConfig {
/** Keep detected faces that cannot be verified using facemesh */
keepInvalid: boolean;
} }
/** Face results /** Face results
@ -714,15 +712,38 @@ export declare type FingerDirection = 'verticalUp' | 'verticalDown' | 'horizonta
* @param modelArtifacts a object containing model topology (i.e., parsed from * @param modelArtifacts a object containing model topology (i.e., parsed from
* the JSON format). * the JSON format).
* @param weightSpecs An array of `WeightsManifestEntry` objects describing the * @param weightSpecs An array of `WeightsManifestEntry` objects describing the
* names, shapes, types, and quantization of the weight data. * names, shapes, types, and quantization of the weight data. Optional.
* @param weightData A single `ArrayBuffer` containing the weight data, * @param weightData A single `ArrayBuffer` containing the weight data,
* concatenated in the order described by the weightSpecs. * concatenated in the order described by the weightSpecs. Optional.
* @param trainingConfig Model training configuration. Optional. * @param trainingConfig Model training configuration. Optional.
* *
* @returns A passthrough `IOHandler` that simply loads the provided data. * @returns A passthrough `IOHandler` that simply loads the provided data.
*/ */
declare function fromMemory(modelArtifacts: {} | ModelArtifacts, weightSpecs?: WeightsManifestEntry[], weightData?: ArrayBuffer, trainingConfig?: TrainingConfig): IOHandler; declare function fromMemory(modelArtifacts: {} | ModelArtifacts, weightSpecs?: WeightsManifestEntry[], weightData?: ArrayBuffer, trainingConfig?: TrainingConfig): IOHandler;
/**
* Creates an IOHandler that loads model artifacts from memory.
*
* When used in conjunction with `tf.loadLayersModel`, an instance of
* `tf.LayersModel` (Keras-style) can be constructed from the loaded artifacts.
*
* ```js
* const model = await tf.loadLayersModel(tf.io.fromMemory(
* modelTopology, weightSpecs, weightData));
* ```
*
* @param modelArtifacts a object containing model topology (i.e., parsed from
* the JSON format).
* @param weightSpecs An array of `WeightsManifestEntry` objects describing the
* names, shapes, types, and quantization of the weight data. Optional.
* @param weightData A single `ArrayBuffer` containing the weight data,
* concatenated in the order described by the weightSpecs. Optional.
* @param trainingConfig Model training configuration. Optional.
*
* @returns A passthrough `IOHandlerSync` that simply loads the provided data.
*/
declare function fromMemorySync(modelArtifacts: {} | ModelArtifacts, weightSpecs?: WeightsManifestEntry[], weightData?: ArrayBuffer, trainingConfig?: TrainingConfig): IOHandlerSync;
export declare type Gender = 'male' | 'female' | 'unknown'; export declare type Gender = 'male' | 'female' | 'unknown';
/** Generic config type inherited by all module types */ /** Generic config type inherited by all module types */
@ -807,7 +828,7 @@ declare interface GPUData {
* *
* @doc {heading: 'Models', subheading: 'Classes'} * @doc {heading: 'Models', subheading: 'Classes'}
*/ */
export declare class GraphModel implements InferenceModel { export declare class GraphModel<ModelURL extends Url = string | io.IOHandler> implements InferenceModel {
private modelUrl; private modelUrl;
private loadOptions; private loadOptions;
private executor; private executor;
@ -834,13 +855,13 @@ export declare class GraphModel implements InferenceModel {
* @param onProgress Optional, progress callback function, fired periodically * @param onProgress Optional, progress callback function, fired periodically
* before the load is completed. * before the load is completed.
*/ */
constructor(modelUrl: string | io.IOHandler, loadOptions?: io.LoadOptions); constructor(modelUrl: ModelURL, loadOptions?: io.LoadOptions);
private findIOHandler; private findIOHandler;
/** /**
* Loads the model and weight files, construct the in memory weight map and * Loads the model and weight files, construct the in memory weight map and
* compile the inference graph. * compile the inference graph.
*/ */
load(): Promise<boolean>; load(): UrlIOHandler<ModelURL> extends io.IOHandlerSync ? boolean : Promise<boolean>;
/** /**
* Synchronously construct the in memory weight map and * Synchronously construct the in memory weight map and
* compile the inference graph. Also initialize hashtable if any. * compile the inference graph. Also initialize hashtable if any.
@ -1378,12 +1399,14 @@ declare namespace io {
decodeWeights, decodeWeights,
encodeWeights, encodeWeights,
fromMemory, fromMemory,
fromMemorySync,
getLoadHandlers, getLoadHandlers,
getModelArtifactsForJSON, getModelArtifactsForJSON,
getModelArtifactsInfoForJSON, getModelArtifactsInfoForJSON,
getSaveHandlers, getSaveHandlers,
http, http,
IOHandler, IOHandler,
IOHandlerSync,
isHTTPScheme, isHTTPScheme,
LoadHandler, LoadHandler,
LoadOptions, LoadOptions,
@ -1404,7 +1427,8 @@ declare namespace io {
weightsLoaderFactory, weightsLoaderFactory,
WeightsManifestConfig, WeightsManifestConfig,
WeightsManifestEntry, WeightsManifestEntry,
withSaveHandler withSaveHandler,
withSaveHandlerSync
} }
} }
@ -1419,6 +1443,10 @@ declare interface IOHandler {
load?: LoadHandler; load?: LoadHandler;
} }
declare type IOHandlerSync = {
[K in keyof IOHandler]: Syncify<IOHandler[K]>;
};
declare type IORouter = (url: string | string[], loadOptions?: LoadOptions) => IOHandler; declare type IORouter = (url: string | string[], loadOptions?: LoadOptions) => IOHandler;
/** iris gesture type */ /** iris gesture type */
@ -1985,6 +2013,8 @@ export declare interface PersonResult {
/** generic point as [x, y, z?] */ /** generic point as [x, y, z?] */
export declare type Point = [number, number, number?]; export declare type Point = [number, number, number?];
declare type PromiseFunction = (...args: unknown[]) => Promise<unknown>;
export declare type Race = 'white' | 'black' | 'asian' | 'indian' | 'other'; export declare type Race = 'white' | 'black' | 'asian' | 'indian' | 'other';
export declare enum Rank { export declare enum Rank {
@ -2182,6 +2212,8 @@ declare interface SingleValueMap {
string: string; string: string;
} }
declare type Syncify<T extends PromiseFunction> = T extends (...args: infer Args) => Promise<infer R> ? (...args: Args) => R : never;
export declare namespace Tensor { } export declare namespace Tensor { }
/** /**
@ -2265,6 +2297,9 @@ export declare class Tensor<R extends Rank = Rank> {
* For WebGL backend, the data will be stored on a densely packed texture. * For WebGL backend, the data will be stored on a densely packed texture.
* This means that the texture will use the RGBA channels to store value. * This means that the texture will use the RGBA channels to store value.
* *
* For WebGPU backend, the data will be stored on a buffer. There is no
* parameter, so can not use an user defined size to create the buffer.
*
* @param options: * @param options:
* For WebGL, * For WebGL,
* - customTexShape: Optional. If set, will use the user defined * - customTexShape: Optional. If set, will use the user defined
@ -2277,6 +2312,15 @@ export declare class Tensor<R extends Rank = Rank> {
* texture: WebGLTexture, * texture: WebGLTexture,
* texShape: [number, number] // [height, width] * texShape: [number, number] // [height, width]
* } * }
*
* For WebGPU backend, a GPUData contains the new buffer and
* its information.
* {
* tensorRef: The tensor that is associated with this buffer,
* buffer: GPUBuffer,
* bufSize: number
* }
*
* Remember to dispose the GPUData after it is used by * Remember to dispose the GPUData after it is used by
* `res.tensorRef.dispose()`. * `res.tensorRef.dispose()`.
* *
@ -2397,6 +2441,10 @@ declare interface TrainingConfig {
declare type TypedArray = Float32Array | Int32Array | Uint8Array; declare type TypedArray = Float32Array | Int32Array | Uint8Array;
declare type Url = string | io.IOHandler | io.IOHandlerSync;
declare type UrlIOHandler<T extends Url> = T extends string ? io.IOHandler : T;
declare function validate(instance: Human): Promise<void>; declare function validate(instance: Human): Promise<void>;
/** /**
@ -2536,8 +2584,25 @@ declare interface WeightsManifestGroupConfig {
* ``` * ```
* *
* @param saveHandler A function that accepts a `ModelArtifacts` and returns a * @param saveHandler A function that accepts a `ModelArtifacts` and returns a
* `SaveResult`. * promise that resolves to a `SaveResult`.
*/ */
declare function withSaveHandler(saveHandler: (artifacts: ModelArtifacts) => Promise<SaveResult>): IOHandler; declare function withSaveHandler(saveHandler: (artifacts: ModelArtifacts) => Promise<SaveResult>): IOHandler;
/**
* Creates an IOHandlerSync that passes saved model artifacts to a callback.
*
* ```js
* function handleSave(artifacts) {
* // ... do something with the artifacts ...
* return {modelArtifactsInfo: {...}, ...};
* }
*
* const saveResult = model.save(tf.io.withSaveHandler(handleSave));
* ```
*
* @param saveHandler A function that accepts a `ModelArtifacts` and returns a
* `SaveResult`.
*/
declare function withSaveHandlerSync(saveHandler: (artifacts: ModelArtifacts) => SaveResult): IOHandlerSync;
export { } export { }