// human/src/segmentation/segmentation.ts
/**
 * Image segmentation for body detection model
 *
 * Based on:
 * - [**MediaPipe Meet**](https://drive.google.com/file/d/1lnP1bRi9CSqQQXUHa13159vLELYDgDu0/preview)
 * - [**MediaPipe Selfie**](https://drive.google.com/file/d/1dCfozqknMa068vVsO2j_1FgZkW_e3VWv/preview)
 */
import { log } from '../util/util';
import * as tf from '../../dist/tfjs.esm.js';
import { loadModel } from '../tfjs/load';
import * as image from '../image/image';
import { constants } from '../tfjs/constants';
import type { GraphModel, Tensor } from '../tfjs/types';
import type { Config } from '../config';
import { env } from '../util/env';
import type { Input, AnyCanvas } from '../exports';
/** Segmentation model instance cached by `load`; unassigned until the first load completes */
let model: GraphModel;
/** Set while `process` is running; a concurrent call returns an empty result instead of running */
let busy = false;
/**
 * Load the segmentation model, or reuse the cached instance.
 *
 * A fresh load from `config.segmentation.modelPath` happens when no model is cached yet
 * or when `env.initial` is truthy; otherwise the cached model is returned as-is.
 *
 * @param config - library configuration; reads `segmentation.modelPath` and `debug`
 * @returns the loaded (or cached) graph model
 */
export async function load(config: Config): Promise<GraphModel> {
  if (!model || env.initial) model = await loadModel(config.segmentation.modelPath);
  else if (config.debug) log('cached model:', model['modelUrl']);
  return model;
}
/**
 * Run body segmentation on an input image.
 *
 * The model is loaded on first use via `load`. Its output is turned into a single-channel
 * mask resized to the input image dimensions. The mask is returned as a flat array and,
 * when canvas support is available, also rendered to canvases.
 *
 * Only one call runs at a time: while a previous call is still in flight, this returns an
 * empty result immediately instead of queueing. The busy flag and all intermediate tensors
 * are released on every exit path, including errors.
 *
 * @param input - image to segment
 * @param background - optional background image (see NOTE(review) near the end of the body)
 * @param config - library configuration; reads `segmentation.blur` and `debug` here and is
 *   passed on to `load` and `image.process`
 * @returns
 * - `data`: mask values as a flat array; `[]` when the call was skipped or the input produced no tensor
 * - `canvas`: the input image with the mask applied as its alpha channel, or `null` without canvas support
 * - `alpha`: the mask rendered by `tf.browser.toPixels` (its red channel is used as alpha for `canvas`),
 *   or `null` without canvas support
 */
export async function process(input: Input, background: Input | undefined, config: Config)
: Promise<{ data: number[] | Tensor, canvas: AnyCanvas | null, alpha: AnyCanvas | null }> {
  if (busy) return { data: [], canvas: null, alpha: null }; // another call owns the model right now: skip, do not touch `busy`
  busy = true;
  const t: Record<string, Tensor> = {};
  try { // the finally-block clears `busy` and frees tensors on every return path and on exceptions
    if (!model) await load(config);
    const inputImage = await image.process(input, config);
    if (!inputImage.tensor) return { data: [], canvas: null, alpha: null };
    const height = inputImage.tensor.shape[1] || 0;
    const width = inputImage.tensor.shape[2] || 0;

    const modelShape = model.inputs[0].shape; // indices 1 and 2 are used as model input height and width
    t.resize = tf.image.resizeBilinear(inputImage.tensor, [modelShape?.[1] ?? 0, modelShape?.[2] ?? 0], false);
    tf.dispose(inputImage.tensor);
    t.norm = tf.div(t.resize, constants.tf255); // scale pixel values from 0..255 to 0..1
    t.res = model.execute(t.norm) as Tensor;
    t.squeeze = tf.squeeze(t.res, 0); // drop the batch dimension

    if (t.squeeze.shape[2] === 2) {
      // two-channel output: softmax across channels, then keep the second channel (named fg) as the mask
      t.softmax = tf.softmax(t.squeeze);
      [t.bg, t.fg] = tf.unstack(t.softmax, 2);
      t.expand = tf.expandDims(t.fg, 2);
      t.pad = tf.expandDims(t.expand, 0);
      // box [y1, x1, y2, x2] = [0, 0, 0.5, 0.5] keeps only the upper-left quadrant
      // (original note: running softmax before unstack creates a 2x2 matrix, so only that quadrant is used;
      // alternatively run softmax after unstack and use a standard resizeBilinear)
      // cropSize is [cropHeight, cropWidth], matching the [height, width] order used in the other branch
      t.crop = tf.image.cropAndResize(t.pad, [[0, 0, 0.5, 0.5]], [0], [height, width]);
      t.data = tf.squeeze(t.crop, 0);
    } else {
      t.data = tf.image.resizeBilinear(t.squeeze, [height, width]); // single-channel output is used directly as the mask
    }
    const data = Array.from(await t.data.data());

    if (env.node && !env.Canvas && (typeof ImageData === 'undefined')) {
      if (config.debug) log('canvas support missing');
      return { data, canvas: null, alpha: null }; // running in nodejs so return alpha array as-is
    }

    const alphaCanvas = image.canvas(width, height);
    if (tf.browser) await tf.browser.toPixels(t.data, alphaCanvas);
    const alphaCtx = alphaCanvas.getContext('2d') as CanvasRenderingContext2D;
    // NOTE(review): a canvas filter only applies to later draw calls on this context; it does not alter the
    // pixels toPixels already wrote, so alphaData below is read unblurred
    if (config.segmentation.blur && config.segmentation.blur > 0) alphaCtx.filter = `blur(${config.segmentation.blur}px)`; // use css filter for bluring, can be done with gaussian blur manually instead
    const alphaData = alphaCtx.getImageData(0, 0, width, height);

    const compositeCanvas = image.canvas(width, height);
    const compositeCtx = compositeCanvas.getContext('2d') as CanvasRenderingContext2D;
    if (inputImage.canvas) compositeCtx.drawImage(inputImage.canvas, 0, 0);
    compositeCtx.globalCompositeOperation = 'darken'; // https://developer.mozilla.org/en-US/docs/Web/API/CanvasRenderingContext2D/globalCompositeOperation // best options are: darken, color-burn, multiply
    if (config.segmentation.blur && config.segmentation.blur > 0) compositeCtx.filter = `blur(${config.segmentation.blur}px)`; // use css filter for bluring, can be done with gaussian blur manually instead
    compositeCtx.drawImage(alphaCanvas, 0, 0);
    compositeCtx.globalCompositeOperation = 'source-over'; // reset composite operation
    compositeCtx.filter = 'none'; // reset css filter
    const compositeData = compositeCtx.getImageData(0, 0, width, height);
    for (let i = 0; i < width * height; i++) compositeData.data[4 * i + 3] = alphaData.data[4 * i + 0]; // copy mask (red channel of alpha canvas) into composite alpha
    compositeCtx.putImageData(compositeData, 0, 0);

    // NOTE(review): the background is drawn into mergedCanvas, but mergedCanvas is not returned
    // (callers get compositeCanvas). Decide whether to return it or drop this work.
    if (background) {
      const mergedCanvas = image.canvas(width, height);
      const bgImage = await image.process(background, config);
      tf.dispose(bgImage.tensor);
      const ctxMerge = mergedCanvas.getContext('2d') as CanvasRenderingContext2D;
      if (bgImage.canvas) ctxMerge.drawImage(bgImage.canvas as HTMLCanvasElement, 0, 0, mergedCanvas.width, mergedCanvas.height);
      ctxMerge.drawImage(compositeCanvas, 0, 0);
    }

    return { data, canvas: compositeCanvas, alpha: alphaCanvas };
  } finally {
    Object.keys(t).forEach((tensor) => tf.dispose(t[tensor]));
    busy = false;
  }
}