cache frequent tf constants

pull/280/head
Vladimir Mandic 2021-11-16 18:31:07 -05:00
parent ce2e669dad
commit 98a1817d13
16 changed files with 78 additions and 61 deletions

View File

@ -3,6 +3,7 @@
*/
import * as tf from '../../dist/tfjs.esm.js';
import * as constants from '../tfjs/constants';
import { log, join, now } from '../util/util';
import type { BodyKeypoint, BodyResult, Box, Point } from '../result';
import type { GraphModel, Tensor } from '../tfjs/types';
@ -71,7 +72,7 @@ async function prepareImage(input: Tensor): Promise<Tensor> {
];
t.pad = tf.pad(input, padding);
t.resize = tf.image.resizeBilinear(t.pad, [inputSize[1][0], inputSize[1][1]]);
const final = tf.div(t.resize, 255);
const final = tf.div(t.resize, constants.tf255);
Object.keys(t).forEach((tensor) => tf.dispose(t[tensor]));
return final;
}

View File

@ -7,6 +7,7 @@
import { log, join, now } from '../util/util';
import * as tf from '../../dist/tfjs.esm.js';
import * as coords from './efficientposecoords';
import * as constants from '../tfjs/constants';
import type { BodyResult, Point } from '../result';
import type { GraphModel, Tensor } from '../tfjs/types';
import type { Config } from '../config';
@ -33,19 +34,22 @@ export async function load(config: Config): Promise<GraphModel> {
}
// performs argmax and max functions on a 2d tensor
/**
 * Finds the maximum value of a 2d tensor and, when it is high enough, its position.
 *
 * The tensor is flattened, its maximum is read back asynchronously and compared against
 * `minScore`. Only when the maximum exceeds `minScore` is the flat argmax index computed
 * and split into coordinates: `x = index mod width`, `y = index div width`, where
 * `width` is the first entry of `inputs.shape`.
 *
 * Every intermediate tensor created here is disposed before returning; `inputs` itself
 * is left untouched and remains owned by the caller.
 *
 * @param inputs 2d tensor; its shape is read as `[width, height]`
 * @param minScore threshold the maximum value must exceed for coordinates to be computed
 * @returns `[x, y, score]`, or `[0, 0, score]` when `score` does not exceed `minScore`
 */
async function max2d(inputs, minScore): Promise<[number, number, number]> {
  const [width, height] = inputs.shape;
  const reshaped = tf.reshape(inputs, [height * width]); // combine all data
  try {
    const max = tf.max(reshaped, 0);
    const newScore = (await max.data())[0]; // get highest score
    tf.dispose(max);
    if (newScore > minScore) { // skip coordinate calculation if score is too low
      // `reshaped` must still be alive here: it is the input of argMax
      const coordinates = tf.argMax(reshaped, 0);
      const widthT = tf.scalar(width, 'int32'); // int32 divisor; presumably yields integer division with the int32 argmax index
      const mod = tf.mod(coordinates, width);
      const div = tf.div(coordinates, widthT);
      const x = (await mod.data())[0];
      const y = (await div.data())[0];
      tf.dispose([coordinates, widthT, mod, div]);
      return [x, y, newScore];
    }
    return [0, 0, newScore];
  } finally {
    tf.dispose(reshaped); // released on every path, but only after its last use
  }
}
export async function predict(image: Tensor, config: Config): Promise<BodyResult[]> {
@ -60,8 +64,8 @@ export async function predict(image: Tensor, config: Config): Promise<BodyResult
const tensor = tf.tidy(() => {
if (!model?.inputs[0].shape) return null;
const resize = tf.image.resizeBilinear(image, [model.inputs[0].shape[2], model.inputs[0].shape[1]], false);
const enhance = tf.mul(resize, 2);
const norm = enhance.sub(1);
const enhance = tf.mul(resize, constants.tf2);
const norm = tf.sub(enhance, constants.tf1);
return norm;
});
@ -80,7 +84,7 @@ export async function predict(image: Tensor, config: Config): Promise<BodyResult
// process each unstacked tensor as a separate body part
for (let id = 0; id < stack.length; id++) {
// actual processing to get coordinates and score
const [x, y, partScore] = max2d(stack[id], config.body.minConfidence);
const [x, y, partScore] = await max2d(stack[id], config.body.minConfidence);
if (partScore > (config.body?.minConfidence || 0)) {
cache.keypoints.push({
score: Math.round(100 * partScore) / 100,

View File

@ -39,21 +39,17 @@ export async function load(config: Config): Promise<GraphModel> {
else if (config.debug) log('load model:', model['modelUrl']);
} else if (config.debug) log('cached model:', model['modelUrl']);
inputSize = model.inputs[0].shape ? model.inputs[0].shape[2] : 0;
if (inputSize === -1) inputSize = 256;
return model;
}
async function parseSinglePose(res, config, image, inputBox) {
async function parseSinglePose(res, config, image) {
const kpt = res[0][0];
const keypoints: Array<BodyKeypoint> = [];
let score = 0;
for (let id = 0; id < kpt.length; id++) {
score = kpt[id][2];
if (score > config.body.minConfidence) {
const positionRaw: Point = [
(inputBox[3] - inputBox[1]) * kpt[id][1] + inputBox[1],
(inputBox[2] - inputBox[0]) * kpt[id][0] + inputBox[0],
];
const positionRaw: Point = [kpt[id][1], kpt[id][0]];
keypoints.push({
score: Math.round(100 * score) / 100,
part: coords.kpt[id],
@ -84,7 +80,7 @@ async function parseSinglePose(res, config, image, inputBox) {
return bodies;
}
async function parseMultiPose(res, config, image, inputBox) {
async function parseMultiPose(res, config, image) {
const bodies: Array<BodyResult> = [];
for (let id = 0; id < res[0].length; id++) {
const kpt = res[0][id];
@ -94,10 +90,7 @@ async function parseMultiPose(res, config, image, inputBox) {
for (let i = 0; i < 17; i++) {
const score = kpt[3 * i + 2];
if (score > config.body.minConfidence) {
const positionRaw: Point = [
(inputBox[3] - inputBox[1]) * kpt[3 * i + 1] + inputBox[1],
(inputBox[2] - inputBox[0]) * kpt[3 * i + 0] + inputBox[0],
];
const positionRaw: Point = [kpt[3 * i + 1], kpt[3 * i + 0]];
keypoints.push({
part: coords.kpt[i],
score: Math.round(100 * score) / 100,
@ -181,8 +174,8 @@ export async function predict(input: Tensor, config: Config): Promise<BodyResult
cache.last = now();
const res = await t.res.array();
cache.bodies = (t.res.shape[2] === 17)
? await parseSinglePose(res, config, input, [0, 0, 1, 1])
: await parseMultiPose(res, config, input, [0, 0, 1, 1]);
? await parseSinglePose(res, config, input)
: await parseMultiPose(res, config, input);
for (const body of cache.bodies) {
fix.rescaleBody(body, [input.shape[2] || 1, input.shape[1] || 1]);
fix.jitter(body.keypoints);

View File

@ -6,6 +6,7 @@
import { log, join } from '../util/util';
import * as tf from '../../dist/tfjs.esm.js';
import * as util from './facemeshutil';
import * as constants from '../tfjs/constants';
import type { Config } from '../config';
import type { Tensor, GraphModel } from '../tfjs/types';
import { env } from '../util/env';
@ -15,6 +16,7 @@ const keypointsCount = 6;
let model: GraphModel | null;
let anchors: Tensor | null = null;
let inputSize = 0;
let inputSizeT: Tensor | null = null;
export const size = () => inputSize;
@ -26,7 +28,8 @@ export async function load(config: Config): Promise<GraphModel> {
else if (config.debug) log('load model:', model['modelUrl']);
} else if (config.debug) log('cached model:', model['modelUrl']);
inputSize = model.inputs[0].shape ? model.inputs[0].shape[2] : 0;
anchors = tf.tensor2d(util.generateAnchors(inputSize));
inputSizeT = tf.scalar(inputSize, 'int32') as Tensor;
anchors = tf.tensor2d(util.generateAnchors(inputSize)) as Tensor;
return model;
}
@ -35,13 +38,13 @@ function decodeBounds(boxOutputs) {
t.boxStarts = tf.slice(boxOutputs, [0, 1], [-1, 2]);
t.centers = tf.add(t.boxStarts, anchors);
t.boxSizes = tf.slice(boxOutputs, [0, 3], [-1, 2]);
t.boxSizesNormalized = tf.div(t.boxSizes, inputSize);
t.centersNormalized = tf.div(t.centers, inputSize);
t.halfBoxSize = tf.div(t.boxSizesNormalized, 2);
t.boxSizesNormalized = tf.div(t.boxSizes, inputSizeT);
t.centersNormalized = tf.div(t.centers, inputSizeT);
t.halfBoxSize = tf.div(t.boxSizesNormalized, constants.tf2);
t.starts = tf.sub(t.centersNormalized, t.halfBoxSize);
t.ends = tf.add(t.centersNormalized, t.halfBoxSize);
t.startNormalized = tf.mul(t.starts, inputSize);
t.endNormalized = tf.mul(t.ends, inputSize);
t.startNormalized = tf.mul(t.starts, inputSizeT);
t.endNormalized = tf.mul(t.ends, inputSizeT);
const boxes = tf.concat2d([t.startNormalized, t.endNormalized], 1);
Object.keys(t).forEach((tensor) => tf.dispose(t[tensor]));
return boxes;
@ -53,8 +56,8 @@ export async function getBoxes(inputImage: Tensor, config: Config) {
const t: Record<string, Tensor> = {};
t.resized = tf.image.resizeBilinear(inputImage, [inputSize, inputSize]);
t.div = tf.div(t.resized, 127.5);
t.normalized = tf.sub(t.div, 0.5);
t.div = tf.div(t.resized, constants.tf127);
t.normalized = tf.sub(t.div, constants.tf05);
const res = model?.execute(t.normalized) as Tensor[];
if (Array.isArray(res)) { // are we using tfhub or pinto converted model?
const sorted = res.sort((a, b) => a.size - b.size);

View File

@ -126,7 +126,6 @@ export async function load(config: Config): Promise<GraphModel> {
else if (config.debug) log('load model:', model['modelUrl']);
} else if (config.debug) log('cached model:', model['modelUrl']);
inputSize = model.inputs[0].shape ? model.inputs[0].shape[2] : 0;
if (inputSize === -1) inputSize = 64;
return model;
}

View File

@ -5,6 +5,7 @@
import * as tf from '../../dist/tfjs.esm.js';
import * as coords from './facemeshcoords';
import * as constants from '../tfjs/constants';
import type { Box, Point } from '../result';
import { env } from '../util/env';
@ -40,7 +41,7 @@ export const cutBoxFromImageAndResize = (box, image, cropSize) => {
const h = image.shape[1];
const w = image.shape[2];
const crop = tf.image.cropAndResize(image, [[box.startPoint[1] / h, box.startPoint[0] / w, box.endPoint[1] / h, box.endPoint[0] / w]], [0], cropSize);
const norm = tf.div(crop, 255);
const norm = tf.div(crop, constants.tf255);
tf.dispose(crop);
return norm;
};

View File

@ -8,10 +8,11 @@
*/
import { log, join, now } from '../util/util';
import { env } from '../util/env';
import * as tf from '../../dist/tfjs.esm.js';
import * as constants from '../tfjs/constants';
import type { Tensor, GraphModel } from '../tfjs/types';
import type { Config } from '../config';
import { env } from '../util/env';
let model: GraphModel | null;
const last: Array<{
@ -40,7 +41,7 @@ export function enhance(input): Tensor {
const tensor = (input.image || input.tensor || input) as Tensor; // input received from detector is already normalized to 0..1, input is also assumed to be straightened
if (!model?.inputs[0].shape) return tensor; // model has no shape so no point continuing
const crop = tf.image.resizeBilinear(tensor, [model.inputs[0].shape[2], model.inputs[0].shape[1]], false);
const norm = tf.mul(crop, 255);
const norm = tf.mul(crop, constants.tf255);
tf.dispose(crop);
return norm;
/*

View File

@ -9,6 +9,7 @@ import type { Config } from '../config';
import type { GraphModel, Tensor } from '../tfjs/types';
import * as tf from '../../dist/tfjs.esm.js';
import { env } from '../util/env';
import * as constants from '../tfjs/constants';
const annotations = ['angry', 'disgust', 'fear', 'happy', 'sad', 'surprise', 'neutral'];
let model: GraphModel | null;
@ -17,9 +18,6 @@ let lastCount = 0;
let lastTime = 0;
let skipped = Number.MAX_SAFE_INTEGER;
// tuning values
const rgb = [0.2989, 0.5870, 0.1140]; // factors for red/green/blue colors when converting to grayscale
export async function load(config: Config): Promise<GraphModel> {
if (env.initial) model = null;
if (!model) {
@ -47,14 +45,16 @@ export async function predict(image: Tensor, config: Config, idx, count): Promis
t.resize = tf.image.resizeBilinear(image, [inputSize, inputSize], false);
// const box = [[0.15, 0.15, 0.85, 0.85]]; // empirical values for top, left, bottom, right
// const resize = tf.image.cropAndResize(image, box, [0], [inputSize, inputSize]);
[t.red, t.green, t.blue] = tf.split(t.resize, 3, 3);
// [t.red, t.green, t.blue] = tf.split(t.resize, 3, 3);
// weighted rgb to grayscale: https://www.mathworks.com/help/matlab/ref/rgb2gray.html
t.redNorm = tf.mul(t.red, rgb[0]);
t.greenNorm = tf.mul(t.green, rgb[1]);
t.blueNorm = tf.mul(t.blue, rgb[2]);
t.grayscale = tf.addN([t.redNorm, t.greenNorm, t.blueNorm]);
t.grayscaleSub = tf.sub(t.grayscale, 0.5);
t.grayscaleMul = tf.mul(t.grayscaleSub, 2);
// t.redNorm = tf.mul(t.red, rgb[0]);
// t.greenNorm = tf.mul(t.green, rgb[1]);
// t.blueNorm = tf.mul(t.blue, rgb[2]);
// t.grayscale = tf.addN([t.redNorm, t.greenNorm, t.blueNorm]);
t.channels = tf.mul(t.resize, constants.rgb);
t.grayscale = tf.sum(t.channels, 3, true);
t.grayscaleSub = tf.sub(t.grayscale, constants.tf05);
t.grayscaleMul = tf.mul(t.grayscaleSub, constants.tf2);
t.emotion = model?.execute(t.grayscaleMul) as Tensor; // result is already in range 0..1, no need for additional activation
lastTime = now();
const data = await t.emotion.data();

View File

@ -6,9 +6,10 @@
import { log, join, now } from '../util/util';
import * as tf from '../../dist/tfjs.esm.js';
import { env } from '../util/env';
import * as constants from '../tfjs/constants';
import type { Config } from '../config';
import type { GraphModel, Tensor } from '../tfjs/types';
import { env } from '../util/env';
let model: GraphModel | null;
const last: Array<{ age: number }> = [];
@ -43,7 +44,7 @@ export async function predict(image: Tensor, config: Config, idx, count): Promis
if (!model?.inputs || !model.inputs[0] || !model.inputs[0].shape) return;
const t: Record<string, Tensor> = {};
t.resize = tf.image.resizeBilinear(image, [model.inputs[0].shape[2], model.inputs[0].shape[1]], false);
t.enhance = tf.mul(t.resize, 255);
t.enhance = tf.mul(t.resize, constants.tf255);
const obj = { age: 0 };
if (config.face['ssrnet'].enabled) t.age = model.execute(t.enhance) as Tensor;
if (t.age) {

View File

@ -6,6 +6,7 @@
import { log, join, now } from '../util/util';
import * as tf from '../../dist/tfjs.esm.js';
import * as constants from '../tfjs/constants';
import type { Config } from '../config';
import type { GraphModel, Tensor } from '../tfjs/types';
import { env } from '../util/env';
@ -50,7 +51,7 @@ export async function predict(image: Tensor, config: Config, idx, count): Promis
const greenNorm = tf.mul(green, rgb[1]);
const blueNorm = tf.mul(blue, rgb[2]);
const grayscale = tf.addN([redNorm, greenNorm, blueNorm]);
const normalize = tf.mul(tf.sub(grayscale, 0.5), 2); // range grayscale:-1..1
const normalize = tf.mul(tf.sub(grayscale, constants.tf05), 2); // range grayscale:-1..1
return normalize;
});
const obj = { gender: '', genderScore: 0 };

View File

@ -6,6 +6,7 @@
import * as tf from '../../dist/tfjs.esm.js';
import * as util from './handposeutil';
import * as anchors from './handposeanchors';
import * as constants from '../tfjs/constants';
import type { Tensor, GraphModel } from '../tfjs/types';
import type { Point } from '../result';
@ -55,8 +56,8 @@ export class HandDetector {
async predict(input, config): Promise<{ startPoint: Point; endPoint: Point, palmLandmarks: Point[]; confidence: number }[]> {
const t: Record<string, Tensor> = {};
t.resize = tf.image.resizeBilinear(input, [this.inputSize, this.inputSize]);
t.div = tf.div(t.resize, 127.5);
t.image = tf.sub(t.div, 1);
t.div = tf.div(t.resize, constants.tf127);
t.image = tf.sub(t.div, constants.tf1);
t.batched = this.model.execute(t.image) as Tensor;
t.predictions = tf.squeeze(t.batched);
t.slice = tf.slice(t.predictions, [0, 0], [-1, 1]);

View File

@ -6,6 +6,7 @@
import * as tf from '../../dist/tfjs.esm.js';
import * as util from './handposeutil';
import type * as detector from './handposedetector';
import * as constants from '../tfjs/constants';
import type { Tensor, GraphModel } from '../tfjs/types';
import { env } from '../util/env';
import { now } from '../util/util';
@ -120,7 +121,7 @@ export class HandPipeline {
const rotationMatrix = util.buildRotationMatrix(-angle, palmCenter);
const newBox = useFreshBox ? this.getBoxForPalmLandmarks(currentBox.palmLandmarks, rotationMatrix) : currentBox;
const croppedInput = util.cutBoxFromImageAndResize(newBox, rotatedImage, [this.inputSize, this.inputSize]);
const handImage = tf.div(croppedInput, 255);
const handImage = tf.div(croppedInput, constants.tf255);
tf.dispose(croppedInput);
tf.dispose(rotatedImage);
const [confidenceT, keypoints] = this.handPoseModel.execute(handImage) as Array<Tensor>;

View File

@ -15,6 +15,7 @@ import type { Config } from '../config';
import { env } from '../util/env';
import * as fingerPose from './fingerpose';
import { fakeOps } from '../tfjs/backend';
import * as constants from '../tfjs/constants';
const models: [GraphModel | null, GraphModel | null] = [null, null];
const modelOutputNodes = ['StatefulPartitionedCall/Postprocessor/Slice', 'StatefulPartitionedCall/Postprocessor/ExpandDims_1'];
@ -154,8 +155,7 @@ async function detectFingers(input: Tensor, h: HandDetectResult, config: Config)
if (input && models[1] && config.hand.landmarks && h.score > (config.hand.minConfidence || 0)) {
const t: Record<string, Tensor> = {};
t.crop = tf.image.cropAndResize(input, [h.boxCrop], [0], [inputSize[1][0], inputSize[1][1]], 'bilinear');
t.cast = tf.cast(t.crop, 'float32');
t.div = tf.div(t.cast, 255);
t.div = tf.div(t.crop, constants.tf255);
[t.score, t.keypoints] = models[1].execute(t.div, ['Identity_1', 'Identity']) as Tensor[];
const rawScore = (await t.score.data())[0];
const score = (100 - Math.trunc(100 / (1 + Math.exp(rawScore)))) / 100; // reverse sigmoid value

View File

@ -6,6 +6,7 @@
import { log, join, now } from '../util/util';
import * as tf from '../../dist/tfjs.esm.js';
import * as constants from '../tfjs/constants';
import { labels } from './labels';
import type { ObjectResult, Box } from '../result';
import type { GraphModel, Tensor } from '../tfjs/types';
@ -117,7 +118,7 @@ export async function predict(image: Tensor, config: Config): Promise<ObjectResu
return new Promise(async (resolve) => {
const outputSize = [image.shape[2], image.shape[1]];
const resize = tf.image.resizeBilinear(image, [model.inputSize, model.inputSize], false);
const norm = tf.div(resize, 255);
const norm = tf.div(resize, constants.tf255);
const transpose = norm.transpose([0, 3, 1, 2]);
tf.dispose(norm);
tf.dispose(resize);

View File

@ -9,6 +9,7 @@
import { log, join } from '../util/util';
import * as tf from '../../dist/tfjs.esm.js';
import * as image from '../image/image';
import * as constants from '../tfjs/constants';
import type { GraphModel, Tensor } from '../tfjs/types';
import type { Config } from '../config';
import { env } from '../util/env';
@ -39,7 +40,7 @@ export async function process(input: Input, background: Input | undefined, confi
t.resize = tf.image.resizeBilinear(inputImage.tensor, [model.inputs[0].shape ? model.inputs[0].shape[1] : 0, model.inputs[0].shape ? model.inputs[0].shape[2] : 0], false);
tf.dispose(inputImage.tensor);
t.norm = tf.div(t.resize, 255);
t.norm = tf.div(t.resize, constants.tf255);
t.res = model.execute(t.norm) as Tensor;
t.squeeze = tf.squeeze(t.res, 0); // meet.shape:[1,256,256,1], selfie.shape:[1,144,256,2]

9
src/tfjs/constants.ts Normal file
View File

@ -0,0 +1,9 @@
import * as tf from '../../dist/tfjs.esm.js';
import type { Tensor } from './types';
/** Builds a float32 scalar tensor; used only for the shared constants below. */
const float32Scalar = (value: number): Tensor => tf.scalar(value, 'float32');

// Shared tensors created once when this module is loaded, so call sites can pass
// an existing tensor to tf ops instead of a plain number literal.
/** Scalar 255. */
export const tf255: Tensor = float32Scalar(255);
/** Scalar 1. */
export const tf1: Tensor = float32Scalar(1);
/** Scalar 2. */
export const tf2: Tensor = float32Scalar(2);
/** Scalar 0.5. */
export const tf05: Tensor = float32Scalar(0.5);
/** Scalar 127.5 (note: the value is 127.5, not 127 as the name might suggest). */
export const tf127: Tensor = float32Scalar(127.5);
/** Factors for red/green/blue colors when converting to grayscale. */
export const rgb: Tensor = tf.tensor1d([0.2989, 0.5870, 0.1140], 'float32');