// File: human/src/segmentation/segmentation.ts (90 lines, 3.8 KiB, TypeScript)
// Snapshot taken from the repository history view; first blame entry: 2021-06-04 19:51:01 +02:00
/**
 * Segmentation Module
 * Loads a segmentation graph model and composites its predicted mask over the input canvas.
 */
import { log, join } from '../helpers';
import * as tf from '../../dist/tfjs.esm.js';
import { GraphModel, Tensor } from '../tfjs/types';
import { Config } from '../config';
// import * as blur from './blur';
// Result of predict(): true when input.canvas was replaced by the segmented composite, false when segmentation was skipped.
export type Segmentation = boolean;
// Graph model instance assigned by load() and reused by predict() on every call.
let model: GraphModel;
// let blurKernel;
/**
 * Loads the segmentation graph model on first use and caches it in the module-level `model`.
 *
 * Later calls skip loading and return the cached instance. A load that resolves without a
 * model URL is only logged (not thrown), and the resolved value is still cached and returned.
 *
 * @param config - `modelBasePath` and `segmentation.modelPath` form the model location; `debug` enables extra logging
 * @returns the cached or freshly loaded graph model
 */
export async function load(config: Config): Promise<GraphModel> {
  if (model) {
    if (config.debug) log('cached model:', model['modelUrl']);
    return model;
  }
  // @ts-ignore type mismatch on GraphModel
  model = await tf.loadGraphModel(join(config.modelBasePath, config.segmentation.modelPath));
  const modelUrl = model?.['modelUrl'];
  if (modelUrl) {
    if (config.debug) log('load model:', modelUrl);
  } else {
    log('load model failed:', config.segmentation.modelPath);
  }
  // if (!blurKernel) blurKernel = blur.getGaussianKernel(5, 1, 1);
  return model;
}
/**
 * Creates a canvas of the given size, preferring OffscreenCanvas when the runtime provides it.
 */
function createCanvas(width: number, height: number): OffscreenCanvas | HTMLCanvasElement {
  const canvas = (typeof OffscreenCanvas !== 'undefined') ? new OffscreenCanvas(width, height) : document.createElement('canvas');
  canvas.width = width;
  canvas.height = height;
  return canvas;
}

/**
 * Runs the segmentation model on `input.tensor` and composites the predicted mask over `input.canvas`.
 *
 * On success `input.canvas` is replaced by a new canvas. The new canvas holds the original image
 * with the blurred mask drawn on top using the `darken` composite operation.
 *
 * @param input - `tensor` is fed to the model; `canvas` holds the image used for compositing.
 *   NOTE(review): presumably `tensor` is [1, height, width, channels] with 0..255 values
 *   (it is divided by 255 and dims 1/2 are used as the mask size) — confirm against callers.
 * @param config - `segmentation.enabled` gates the whole call
 * @returns `false` when segmentation is disabled, inputs are missing, the model is not loaded
 *   or a 2d context cannot be obtained; otherwise `true`
 */
export async function predict(input: { tensor: Tensor | null, canvas: OffscreenCanvas | HTMLCanvasElement }, config: Config): Promise<Segmentation> {
  if (!config.segmentation.enabled || !input.tensor || !input.canvas) return false;
  if (!model || !model.inputs[0].shape) return false;

  // capture narrowed values up front so they stay narrowed inside the tidy closure
  const tensor = input.tensor;
  const source = input.canvas;
  const { width, height } = source;
  const modelSize: [number, number] = [model.inputs[0].shape[1], model.inputs[0].shape[2]];
  const outputSize: [number, number] = [tensor.shape[1], tensor.shape[2]];

  // every intermediate tensor is released by tidy (also on exceptions); only the returned mask survives
  const mask = tf.tidy(() => {
    const resizeInput = tf.image.resizeBilinear(tensor, modelSize, false);
    const norm = resizeInput.div(255);
    const res = model.predict(norm) as Tensor;
    const squeeze = tf.squeeze(res, 0);
    if (squeeze.shape[2] === 2) {
      // two-channel output: softmax across channels, then keep channel 1 (foreground) as the mask
      // NOTE(review): earlier comments disagreed on which model (meet/selfie) emits which layout — confirm
      const softmax = squeeze.softmax();
      const [, fg] = tf.unstack(softmax, 2);
      const pad = fg.expandDims(2).expandDims(0);
      // original author: running softmax before unstack yields a 2x2 tiling, so only the upper-left
      // quadrant (box [0, 0, 0.5, 0.5]) is kept and resized to the input size.
      // NOTE(review): explanation unverified; alternative is softmax after unstack plus a plain resizeBilinear
      const crop = tf.image.cropAndResize(pad, [[0, 0, 0.5, 0.5]], [0], outputSize);
      return crop.squeeze(0);
    }
    // single-channel output is used directly as the mask
    return tf.image.resizeBilinear(squeeze, outputSize);
  });

  const overlay = createCanvas(width, height);
  try {
    // when tf.browser is unavailable the overlay stays blank and the composite below leaves the image unchanged
    if (tf.browser) await tf.browser.toPixels(mask, overlay);
  } finally {
    tf.dispose(mask);
  }

  // draw into a fresh canvas: the input canvas may already hold a webgl context, so a 2d context on it could fail
  const original = createCanvas(width, height);
  const ctx = original.getContext('2d') as CanvasRenderingContext2D | null;
  if (!ctx) return false;
  ctx.drawImage(source, 0, 0);
  // https://developer.mozilla.org/en-US/docs/Web/API/CanvasRenderingContext2D/globalCompositeOperation
  // best options are: darken, color-burn, multiply
  ctx.globalCompositeOperation = 'darken';
  ctx.filter = 'blur(8px)'; // css filter softens mask edges; a manual gaussian blur could be used instead
  ctx.drawImage(overlay, 0, 0);
  // restore the canvas defaults: this context stays attached to the returned canvas, so later drawing must not inherit them
  ctx.globalCompositeOperation = 'source-over';
  ctx.filter = 'none';
  input.canvas = original;
  return true;
}