From 2f6524722e98c8da362194dbf466f3f6522fb309 Mon Sep 17 00:00:00 2001 From: Vladimir Mandic Date: Tue, 9 Nov 2021 19:39:18 -0500 Subject: [PATCH] auto tensor shape and channels handling --- src/image/image.ts | 34 +++++++++++++++++++------- test/test-main.js | 59 +++++++++++++++++++++++++++++++++++----------- 2 files changed, 71 insertions(+), 22 deletions(-) diff --git a/src/image/image.ts b/src/image/image.ts index ba149600..d36ef726 100644 --- a/src/image/image.ts +++ b/src/image/image.ts @@ -77,15 +77,33 @@ export async function process(input: Input, config: Config, getTensor: boolean = ) { throw new Error('input type is not recognized'); } - if (input instanceof tf.Tensor) { - // if input is tensor, use as-is - if ((input)['isDisposedInternal']) { - throw new Error('input tensor is disposed'); - } else if (!(input as Tensor).shape || (input as Tensor).shape.length !== 4 || (input as Tensor).shape[0] !== 1 || (input as Tensor).shape[3] !== 3) { - throw new Error('input tensor shape must be [1, height, width, 3] and instead was' + (input['shape'] ? input['shape'].toString() : 'unknown')); - } else { - return { tensor: tf.clone(input), canvas: (config.filter.return ? outCanvas : null) }; + if (input instanceof tf.Tensor) { // if input is tensor use as-is without filters but correct shape as needed + let tensor: Tensor | null = null; + if ((input as Tensor)['isDisposedInternal']) throw new Error('input tensor is disposed'); + if (!(input as Tensor)['shape']) throw new Error('input tensor has no shape'); + if ((input as Tensor).shape.length === 3) { // [height, width, 3 || 4] + if ((input as Tensor).shape[2] === 3) { // [height, width, 3] so add batch + tensor = tf.expandDims(input, 0); + } else if ((input as Tensor).shape[2] === 4) { // [height, width, 4] so strip alpha and add batch + const rgb = tf.slice3d(input, [0, 0, 0], [-1, -1, 3]); + tensor = tf.expandDims(rgb, 0); + tf.dispose(rgb); + } + } else if ((input as Tensor).shape.length === 4) { // [1, width, height, 3 || 4] + if ((input as Tensor).shape[3] === 3) { // [1, width, height, 3] just clone + tensor = tf.clone(input); + } else if ((input as Tensor).shape[3] === 4) { // [1, width, height, 4] so strip alpha + tensor = tf.slice4d(input, [0, 0, 0, 0], [-1, -1, -1, 3]); + } } + // at the end shape must be [1, height, width, 3] + if (tensor == null || tensor.shape.length !== 4 || tensor.shape[0] !== 1 || tensor.shape[3] !== 3) throw new Error(`could not process input tensor with shape: ${input['shape']}`); + if ((tensor as Tensor).dtype === 'int32') { + const cast = tf.cast(tensor, 'float32'); + tf.dispose(tensor); + tensor = cast; + } + return { tensor, canvas: (config.filter.return ? outCanvas : null) }; } else { // check if resizing will be needed if (typeof input['readyState'] !== 'undefined' && input['readyState'] <= 2) { diff --git a/test/test-main.js b/test/test-main.js index ad7023e5..fba8a471 100644 --- a/test/test-main.js +++ b/test/test-main.js @@ -27,7 +27,7 @@ async function testHTTP() { }); } -async function getImage(human, input) { +async function getImage(human, input, options = { channels: 3, expand: true, cast: true }) { let img; try { img = await canvasJS.loadImage(input); @@ -39,18 +39,15 @@ async function getImage(human, input) { const ctx = canvas.getContext('2d'); ctx.drawImage(img, 0, 0, img.width, img.height); const imageData = ctx.getImageData(0, 0, canvas.width, canvas.height); - const res = human.tf.tidy(() => { - const tensor = human.tf.tensor(Array.from(imageData.data), [canvas.height, canvas.width, 4], 'float32'); // create rgba image tensor from flat array - const channels = human.tf.split(tensor, 4, 2); // split rgba to channels - const rgb = human.tf.stack([channels[0], channels[1], channels[2]], 2); // stack channels back to rgb - const reshape = human.tf.reshape(rgb, [1, canvas.height, canvas.width, 3]); // move extra dim from the end of tensor and use it as batch number instead - return reshape; - }); - const sum = human.tf.sum(res); - if (res && res.shape[0] === 1 && res.shape[3] === 3) log('state', 'passed: load image:', input, res.shape, { checksum: sum.dataSync()[0] }); - else log('error', 'failed: load image:', input, res); + + const data = human.tf.tensor(Array.from(imageData.data), [canvas.height, canvas.width, 4], options.cast ? 'float32' : 'int32'); // create rgba image tensor from flat array + const channels = options.channels === 3 ? human.tf.slice3d(data, [0, 0, 0], [-1, -1, 3]) : human.tf.clone(data); // optionally strip alpha channel + const tensor = options.expand ? human.tf.expandDims(channels, 0) : human.tf.clone(channels); // optionally add batch num dimension + human.tf.dispose([data, channels]); + const sum = human.tf.sum(tensor); + log('state', 'passed: load image:', input, tensor.shape, { checksum: sum.dataSync()[0] }); human.tf.dispose(sum); - return res; + return tensor; } function printResults(detect) { @@ -96,8 +93,6 @@ async function testWarmup(human, title) { } if (warmup) { log('state', 'passed: warmup:', config.warmup, title); - // const count = human.tf.engine().state.numTensors; - // if (count - tensors > 0) log('warn', 'failed: memory', config.warmup, title, 'tensors:', count - tensors); printResults(warmup); } else { log('error', 'failed: warmup:', config.warmup, title); @@ -176,6 +171,41 @@ async function verifyDetails(human) { } } +async function testTensorShapes(human, input) { + await human.load(config); + const numTensors = human.tf.engine().state.numTensors; + let res; + let tensor; + + tensor = await getImage(human, input, { channels: 4, expand: true, cast: true }); + res = await human.detect(tensor, config); + verify(res.face.length === 1 && res.face[0].gender === 'female', 'tensor shape:', tensor.shape, 'dtype:', tensor.dtype); + human.tf.dispose(tensor); + + tensor = await getImage(human, input, { channels: 4, expand: false, cast: true }); + res = await human.detect(tensor, config); + verify(res.face.length === 1 && res.face[0].gender === 'female', 'tensor shape:', tensor.shape, 'dtype:', tensor.dtype); + human.tf.dispose(tensor); + + tensor = await getImage(human, input, { channels: 3, expand: true, cast: true }); + res = await human.detect(tensor, config); + verify(res.face.length === 1 && res.face[0].gender === 'female', 'tensor shape:', tensor.shape, 'dtype:', tensor.dtype); + human.tf.dispose(tensor); + + tensor = await getImage(human, input, { channels: 3, expand: false, cast: true }); + res = await human.detect(tensor, config); + verify(res.face.length === 1 && res.face[0].gender === 'female', 'tensor shape:', tensor.shape, 'dtype:', tensor.dtype); + human.tf.dispose(tensor); + + tensor = await getImage(human, input, { channels: 4, expand: true, cast: false }); + res = await human.detect(tensor, config); + verify(res.face.length === 1 && res.face[0].gender === 'female', 'tensor shape:', tensor.shape, 'dtype:', tensor.dtype); + human.tf.dispose(tensor); + + const leak = human.tf.engine().state.numTensors - numTensors; + if (leak !== 0) log('error', 'failed: memory leak', leak); +} + async function verifyCompare(human) { log('info', 'test: input compare'); const t1 = await getImage(human, 'samples/in/ai-face.jpg'); @@ -253,6 +283,7 @@ async function test(Human, inputConfig) { }); await verifyDetails(human); + await testTensorShapes(human, 'samples/in/ai-body.jpg'); // test default config async log('info', 'test default');