// human/src/segmentation/segmentation.ts

/**
* Image segmentation for body detection model
*
* Based on:
* - [**MediaPipe Meet**](https://drive.google.com/file/d/1lnP1bRi9CSqQQXUHa13159vLELYDgDu0/preview)
* - [**MediaPipe Selfie**](https://drive.google.com/file/d/1dCfozqknMa068vVsO2j_1FgZkW_e3VWv/preview)
*/
import { log, join } from '../util/util';
import * as tf from '../../dist/tfjs.esm.js';
import * as image from '../image/image';
import type { GraphModel, Tensor } from '../tfjs/types';
import type { Config } from '../config';
import { env } from '../util/env';
import type { Input } from '../exports';
let model: GraphModel;
let busy = false;
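
/** Loads the segmentation graph model; the model is cached after the first call and reloaded when `env.initial` is set */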
export async function load(config: Config): Promise<GraphModel> {
if (!model || env.initial) {
model = await tf.loadGraphModel(join(config.modelBasePath, config.segmentation.modelPath || '')) as unknown as GraphModel;
if (!model || !model['modelUrl']) log('load model failed:', config.segmentation.modelPath);
else if (config.debug) log('load model:', model['modelUrl']);
} else if (config.debug) log('cached model:', model['modelUrl']);
return model;
}
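
/**
 * Runs segmentation on an input image and composites the result
 *
 * @param input image to segment (tensor, canvas, image or video element, etc.)
 * @param background optional background image; when present, the segmented foreground is merged over it
 * @param config global configuration object with `config.segmentation` options
 * @returns raw alpha mask data, the composited canvas (merged with the background when one is given), and the alpha mask canvas
 *
 * Usage sketch (illustrative only; assumes a browser environment and an already populated `Config` object named `config`):
 * @example
 * const video = document.getElementById('webcam') as HTMLVideoElement;
 * const result = await process(video, undefined, config);
 * if (result.canvas) document.body.appendChild(result.canvas as HTMLCanvasElement);
 */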
export async function process(input: Input, background: Input | undefined, config: Config)
: Promise<{ data: Array<number>, canvas: HTMLCanvasElement | OffscreenCanvas | null, alpha: HTMLCanvasElement | OffscreenCanvas | null }> {
if (busy) return { data: [], canvas: null, alpha: null };
busy = true;
if (!model) await load(config);
const inputImage = image.process(input, config);
const width = inputImage.canvas?.width || 0;
const height = inputImage.canvas?.height || 0;
if (!inputImage.tensor) { busy = false; return { data: [], canvas: null, alpha: null }; }
const t: Record<string, Tensor> = {};
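
// tensor pipeline: resize the input to the model resolution, normalize to 0..1, run inference,
// squeeze the batch dimension, then branch on the number of output channels (meet vs selfie model)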
t.resize = tf.image.resizeBilinear(inputImage.tensor, [model.inputs[0].shape ? model.inputs[0].shape[1] : 0, model.inputs[0].shape ? model.inputs[0].shape[2] : 0], false);
tf.dispose(inputImage.tensor);
t.norm = tf.div(t.resize, 255);
t.res = model.predict(t.norm) as Tensor;
t.squeeze = tf.squeeze(t.res, 0); // meet.shape:[1,256,256,1], selfie.shape:[1,144,256,2]
if (t.squeeze.shape[2] === 2) {
t.softmax = tf.softmax(t.squeeze); // model meet has two channels for fg and bg
[t.bg, t.fg] = tf.unstack(t.softmax, 2);
t.expand = tf.expandDims(t.fg, 2);
t.pad = tf.expandDims(t.expand, 0);
t.crop = tf.image.cropAndResize(t.pad, [[0, 0, 0.5, 0.5]], [0], [width, height]);
// running softmax before unstack creates a 2x2 output grid, so we only take the upper-left quadrant
// otherwise run softmax after unstack and use standard resize
// resizeOutput = tf.image.resizeBilinear(expand, [input.tensor?.shape[1], input.tensor?.shape[2]]);
t.data = tf.squeeze(t.crop, 0);
} else {
t.data = tf.image.resizeBilinear(t.squeeze, [height, width]); // model selfie has a single channel that we can use directly
}
const data = Array.from(await t.data.data()) as number[];
if (env.node && !env.Canvas && (typeof ImageData === 'undefined')) {
if (config.debug) log('canvas support missing');
Object.keys(t).forEach((tensor) => tf.dispose(t[tensor]));
busy = false;
return { data, canvas: null, alpha: null }; // running in nodejs without canvas support, so return the raw alpha data as-is
}
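
// browser path: render the alpha mask to its own canvas, composite it over the original input
// using the 'darken' blend mode (with optional blur), then copy the mask into the result's alpha channel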
const alphaCanvas = image.canvas(width, height);
await tf.browser.toPixels(t.data, alphaCanvas);
const alphaCtx = alphaCanvas.getContext('2d') as CanvasRenderingContext2D;
if (config.segmentation.blur && config.segmentation.blur > 0) alphaCtx.filter = `blur(${config.segmentation.blur}px)`; // use css filter for blurring; could be done manually with a gaussian blur instead
const alphaData = alphaCtx.getImageData(0, 0, width, height);
const compositeCanvas = image.canvas(width, height);
const compositeCtx = compositeCanvas.getContext('2d') as CanvasRenderingContext2D;
if (inputImage.canvas) compositeCtx.drawImage(inputImage.canvas, 0, 0);
compositeCtx.globalCompositeOperation = 'darken'; // https://developer.mozilla.org/en-US/docs/Web/API/CanvasRenderingContext2D/globalCompositeOperation // best options are: darken, color-burn, multiply
if (config.segmentation.blur && config.segmentation.blur > 0) compositeCtx.filter = `blur(${config.segmentation.blur}px)`; // use css filter for blurring; could be done manually with a gaussian blur instead
compositeCtx.drawImage(alphaCanvas, 0, 0);
compositeCtx.globalCompositeOperation = 'source-over'; // reset composite operation
compositeCtx.filter = 'none'; // reset css filter
const compositeData = compositeCtx.getImageData(0, 0, width, height);
for (let i = 0; i < width * height; i++) compositeData.data[4 * i + 3] = alphaData.data[4 * i + 0]; // copy the alpha mask value (red channel of the grayscale mask) into the alpha channel of the composite canvas
compositeCtx.putImageData(compositeData, 0, 0);
let mergedCanvas: HTMLCanvasElement | OffscreenCanvas | null = null;
if (background && compositeCanvas) { // draw background with segmentation as overlay if background is present
mergedCanvas = image.canvas(width, height);
const bgImage = image.process(background, config);
tf.dispose(bgImage.tensor);
const ctxMerge = mergedCanvas.getContext('2d') as CanvasRenderingContext2D;
ctxMerge.drawImage(bgImage.canvas as HTMLCanvasElement, 0, 0, mergedCanvas.width, mergedCanvas.height);
ctxMerge.drawImage(compositeCanvas, 0, 0);
}
Object.keys(t).forEach((tensor) => tf.dispose(t[tensor]));
busy = false;
return { data, canvas: mergedCanvas || compositeCanvas, alpha: alphaCanvas };
}