auto tensor shape and channels handling

pull/356/head
Vladimir Mandic 2021-11-09 19:39:18 -05:00
parent 1729a989af
commit adb358fe98
2 changed files with 71 additions and 22 deletions

View File

@ -77,15 +77,33 @@ export async function process(input: Input, config: Config, getTensor: boolean =
) {
throw new Error('input type is not recognized');
}
if (input instanceof tf.Tensor) {
// if input is tensor, use as-is
if ((input)['isDisposedInternal']) {
throw new Error('input tensor is disposed');
} else if (!(input as Tensor).shape || (input as Tensor).shape.length !== 4 || (input as Tensor).shape[0] !== 1 || (input as Tensor).shape[3] !== 3) {
throw new Error('input tensor shape must be [1, height, width, 3] and instead was' + (input['shape'] ? input['shape'].toString() : 'unknown'));
} else {
return { tensor: tf.clone(input), canvas: (config.filter.return ? outCanvas : null) };
if (input instanceof tf.Tensor) { // if input is tensor use as-is without filters but correct shape as needed
let tensor: Tensor | null = null;
if ((input as Tensor)['isDisposedInternal']) throw new Error('input tensor is disposed');
if (!(input as Tensor)['shape']) throw new Error('input tensor has no shape');
if ((input as Tensor).shape.length === 3) { // [height, width, 3 || 4]
if ((input as Tensor).shape[2] === 3) { // [height, width, 3] so add batch
tensor = tf.expandDims(input, 0);
} else if ((input as Tensor).shape[2] === 4) { // [height, width, 4] so strip alpha and add batch
const rgb = tf.slice3d(input, [0, 0, 0], [-1, -1, 3]);
tensor = tf.expandDims(rgb, 0);
tf.dispose(rgb);
}
} else if ((input as Tensor).shape.length === 4) { // [1, width, height, 3 || 4]
if ((input as Tensor).shape[3] === 3) { // [1, width, height, 3] just clone
tensor = tf.clone(input);
} else if ((input as Tensor).shape[3] === 4) { // [1, width, height, 4] so strip alpha
tensor = tf.slice4d(input, [0, 0, 0, 0], [-1, -1, -1, 3]);
}
}
// at the end shape must be [1, height, width, 3]
if (tensor == null || tensor.shape.length !== 4 || tensor.shape[0] !== 1 || tensor.shape[3] !== 3) throw new Error(`could not process input tensor with shape: ${input['shape']}`);
if ((tensor as Tensor).dtype === 'int32') {
const cast = tf.cast(tensor, 'float32');
tf.dispose(tensor);
tensor = cast;
}
return { tensor, canvas: (config.filter.return ? outCanvas : null) };
} else {
// check if resizing will be needed
if (typeof input['readyState'] !== 'undefined' && input['readyState'] <= 2) {

View File

@ -27,7 +27,7 @@ async function testHTTP() {
});
}
async function getImage(human, input) {
async function getImage(human, input, options = { channels: 3, expand: true, cast: true }) {
let img;
try {
img = await canvasJS.loadImage(input);
@ -39,18 +39,15 @@ async function getImage(human, input) {
const ctx = canvas.getContext('2d');
ctx.drawImage(img, 0, 0, img.width, img.height);
const imageData = ctx.getImageData(0, 0, canvas.width, canvas.height);
const res = human.tf.tidy(() => {
const tensor = human.tf.tensor(Array.from(imageData.data), [canvas.height, canvas.width, 4], 'float32'); // create rgba image tensor from flat array
const channels = human.tf.split(tensor, 4, 2); // split rgba to channels
const rgb = human.tf.stack([channels[0], channels[1], channels[2]], 2); // stack channels back to rgb
const reshape = human.tf.reshape(rgb, [1, canvas.height, canvas.width, 3]); // move extra dim from the end of tensor and use it as batch number instead
return reshape;
});
const sum = human.tf.sum(res);
if (res && res.shape[0] === 1 && res.shape[3] === 3) log('state', 'passed: load image:', input, res.shape, { checksum: sum.dataSync()[0] });
else log('error', 'failed: load image:', input, res);
const data = human.tf.tensor(Array.from(imageData.data), [canvas.height, canvas.width, 4], options.cast ? 'float32' : 'int32'); // create rgba image tensor from flat array
const channels = options.channels === 3 ? human.tf.slice3d(data, [0, 0, 0], [-1, -1, 3]) : human.tf.clone(data); // optionally strip alpha channel
const tensor = options.expand ? human.tf.expandDims(channels, 0) : human.tf.clone(channels); // optionally add batch num dimension
human.tf.dispose([data, channels]);
const sum = human.tf.sum(tensor);
log('state', 'passed: load image:', input, tensor.shape, { checksum: sum.dataSync()[0] });
human.tf.dispose(sum);
return res;
return tensor;
}
function printResults(detect) {
@ -96,8 +93,6 @@ async function testWarmup(human, title) {
}
if (warmup) {
log('state', 'passed: warmup:', config.warmup, title);
// const count = human.tf.engine().state.numTensors;
// if (count - tensors > 0) log('warn', 'failed: memory', config.warmup, title, 'tensors:', count - tensors);
printResults(warmup);
} else {
log('error', 'failed: warmup:', config.warmup, title);
@ -176,6 +171,41 @@ async function verifyDetails(human) {
}
}
async function testTensorShapes(human, input) {
await human.load(config);
const numTensors = human.tf.engine().state.numTensors;
let res;
let tensor;
tensor = await getImage(human, input, { channels: 4, expand: true, cast: true });
res = await human.detect(tensor, config);
verify(res.face.length === 1 && res.face[0].gender === 'female', 'tensor shape:', tensor.shape, 'dtype:', tensor.dtype);
human.tf.dispose(tensor);
tensor = await getImage(human, input, { channels: 4, expand: false, cast: true });
res = await human.detect(tensor, config);
verify(res.face.length === 1 && res.face[0].gender === 'female', 'tensor shape:', tensor.shape, 'dtype:', tensor.dtype);
human.tf.dispose(tensor);
tensor = await getImage(human, input, { channels: 3, expand: true, cast: true });
res = await human.detect(tensor, config);
verify(res.face.length === 1 && res.face[0].gender === 'female', 'tensor shape:', tensor.shape, 'dtype:', tensor.dtype);
human.tf.dispose(tensor);
tensor = await getImage(human, input, { channels: 3, expand: false, cast: true });
res = await human.detect(tensor, config);
verify(res.face.length === 1 && res.face[0].gender === 'female', 'tensor shape:', tensor.shape, 'dtype:', tensor.dtype);
human.tf.dispose(tensor);
tensor = await getImage(human, input, { channels: 4, expand: true, cast: false });
res = await human.detect(tensor, config);
verify(res.face.length === 1 && res.face[0].gender === 'female', 'tensor shape:', tensor.shape, 'dtype:', tensor.dtype);
human.tf.dispose(tensor);
const leak = human.tf.engine().state.numTensors - numTensors;
if (leak !== 0) log('error', 'failed: memory leak', leak);
}
async function verifyCompare(human) {
log('info', 'test: input compare');
const t1 = await getImage(human, 'samples/in/ai-face.jpg');
@ -253,6 +283,7 @@ async function test(Human, inputConfig) {
});
await verifyDetails(human);
await testTensorShapes(human, 'samples/in/ai-body.jpg');
// test default config async
log('info', 'test default');