mirror of https://github.com/vladmandic/human
auto tensor shape and channels handling
parent
1729a989af
commit
adb358fe98
|
@ -77,15 +77,33 @@ export async function process(input: Input, config: Config, getTensor: boolean =
|
|||
) {
|
||||
throw new Error('input type is not recognized');
|
||||
}
|
||||
if (input instanceof tf.Tensor) {
|
||||
// if input is tensor, use as-is
|
||||
if ((input)['isDisposedInternal']) {
|
||||
throw new Error('input tensor is disposed');
|
||||
} else if (!(input as Tensor).shape || (input as Tensor).shape.length !== 4 || (input as Tensor).shape[0] !== 1 || (input as Tensor).shape[3] !== 3) {
|
||||
throw new Error('input tensor shape must be [1, height, width, 3] and instead was' + (input['shape'] ? input['shape'].toString() : 'unknown'));
|
||||
} else {
|
||||
return { tensor: tf.clone(input), canvas: (config.filter.return ? outCanvas : null) };
|
||||
if (input instanceof tf.Tensor) { // if input is tensor use as-is without filters but correct shape as needed
|
||||
let tensor: Tensor | null = null;
|
||||
if ((input as Tensor)['isDisposedInternal']) throw new Error('input tensor is disposed');
|
||||
if (!(input as Tensor)['shape']) throw new Error('input tensor has no shape');
|
||||
if ((input as Tensor).shape.length === 3) { // [height, width, 3 || 4]
|
||||
if ((input as Tensor).shape[2] === 3) { // [height, width, 3] so add batch
|
||||
tensor = tf.expandDims(input, 0);
|
||||
} else if ((input as Tensor).shape[2] === 4) { // [height, width, 4] so strip alpha and add batch
|
||||
const rgb = tf.slice3d(input, [0, 0, 0], [-1, -1, 3]);
|
||||
tensor = tf.expandDims(rgb, 0);
|
||||
tf.dispose(rgb);
|
||||
}
|
||||
} else if ((input as Tensor).shape.length === 4) { // [1, width, height, 3 || 4]
|
||||
if ((input as Tensor).shape[3] === 3) { // [1, width, height, 3] just clone
|
||||
tensor = tf.clone(input);
|
||||
} else if ((input as Tensor).shape[3] === 4) { // [1, width, height, 4] so strip alpha
|
||||
tensor = tf.slice4d(input, [0, 0, 0, 0], [-1, -1, -1, 3]);
|
||||
}
|
||||
}
|
||||
// at the end shape must be [1, height, width, 3]
|
||||
if (tensor == null || tensor.shape.length !== 4 || tensor.shape[0] !== 1 || tensor.shape[3] !== 3) throw new Error(`could not process input tensor with shape: ${input['shape']}`);
|
||||
if ((tensor as Tensor).dtype === 'int32') {
|
||||
const cast = tf.cast(tensor, 'float32');
|
||||
tf.dispose(tensor);
|
||||
tensor = cast;
|
||||
}
|
||||
return { tensor, canvas: (config.filter.return ? outCanvas : null) };
|
||||
} else {
|
||||
// check if resizing will be needed
|
||||
if (typeof input['readyState'] !== 'undefined' && input['readyState'] <= 2) {
|
||||
|
|
|
@ -27,7 +27,7 @@ async function testHTTP() {
|
|||
});
|
||||
}
|
||||
|
||||
async function getImage(human, input) {
|
||||
async function getImage(human, input, options = { channels: 3, expand: true, cast: true }) {
|
||||
let img;
|
||||
try {
|
||||
img = await canvasJS.loadImage(input);
|
||||
|
@ -39,18 +39,15 @@ async function getImage(human, input) {
|
|||
const ctx = canvas.getContext('2d');
|
||||
ctx.drawImage(img, 0, 0, img.width, img.height);
|
||||
const imageData = ctx.getImageData(0, 0, canvas.width, canvas.height);
|
||||
const res = human.tf.tidy(() => {
|
||||
const tensor = human.tf.tensor(Array.from(imageData.data), [canvas.height, canvas.width, 4], 'float32'); // create rgba image tensor from flat array
|
||||
const channels = human.tf.split(tensor, 4, 2); // split rgba to channels
|
||||
const rgb = human.tf.stack([channels[0], channels[1], channels[2]], 2); // stack channels back to rgb
|
||||
const reshape = human.tf.reshape(rgb, [1, canvas.height, canvas.width, 3]); // move extra dim from the end of tensor and use it as batch number instead
|
||||
return reshape;
|
||||
});
|
||||
const sum = human.tf.sum(res);
|
||||
if (res && res.shape[0] === 1 && res.shape[3] === 3) log('state', 'passed: load image:', input, res.shape, { checksum: sum.dataSync()[0] });
|
||||
else log('error', 'failed: load image:', input, res);
|
||||
|
||||
const data = human.tf.tensor(Array.from(imageData.data), [canvas.height, canvas.width, 4], options.cast ? 'float32' : 'int32'); // create rgba image tensor from flat array
|
||||
const channels = options.channels === 3 ? human.tf.slice3d(data, [0, 0, 0], [-1, -1, 3]) : human.tf.clone(data); // optionally strip alpha channel
|
||||
const tensor = options.expand ? human.tf.expandDims(channels, 0) : human.tf.clone(channels); // optionally add batch num dimension
|
||||
human.tf.dispose([data, channels]);
|
||||
const sum = human.tf.sum(tensor);
|
||||
log('state', 'passed: load image:', input, tensor.shape, { checksum: sum.dataSync()[0] });
|
||||
human.tf.dispose(sum);
|
||||
return res;
|
||||
return tensor;
|
||||
}
|
||||
|
||||
function printResults(detect) {
|
||||
|
@ -96,8 +93,6 @@ async function testWarmup(human, title) {
|
|||
}
|
||||
if (warmup) {
|
||||
log('state', 'passed: warmup:', config.warmup, title);
|
||||
// const count = human.tf.engine().state.numTensors;
|
||||
// if (count - tensors > 0) log('warn', 'failed: memory', config.warmup, title, 'tensors:', count - tensors);
|
||||
printResults(warmup);
|
||||
} else {
|
||||
log('error', 'failed: warmup:', config.warmup, title);
|
||||
|
@ -176,6 +171,41 @@ async function verifyDetails(human) {
|
|||
}
|
||||
}
|
||||
|
||||
// Runs detection on one image supplied as tensors built with different getImage options
// and logs an error if the tensor count after all runs differs from the count before.
// human: library instance under test; input: image path forwarded to getImage.
async function testTensorShapes(human, input) {
  await human.load(config);
  const tensorsBefore = human.tf.engine().state.numTensors; // baseline taken after load

  // each entry is handed to getImage as its options argument, in this exact order
  const variants = [
    { channels: 4, expand: true, cast: true },
    { channels: 4, expand: false, cast: true },
    { channels: 3, expand: true, cast: true },
    { channels: 3, expand: false, cast: true },
    { channels: 4, expand: true, cast: false },
  ];

  for (const options of variants) {
    const tensor = await getImage(human, input, options);
    const res = await human.detect(tensor, config);
    // every variant is expected to yield exactly one face reported as female
    const ok = res.face.length === 1 && res.face[0].gender === 'female';
    verify(ok, 'tensor shape:', tensor.shape, 'dtype:', tensor.dtype);
    human.tf.dispose(tensor); // release the input before the next variant is created
  }

  const leak = human.tf.engine().state.numTensors - tensorsBefore;
  if (leak !== 0) log('error', 'failed: memory leak', leak);
}
|
||||
|
||||
async function verifyCompare(human) {
|
||||
log('info', 'test: input compare');
|
||||
const t1 = await getImage(human, 'samples/in/ai-face.jpg');
|
||||
|
@ -253,6 +283,7 @@ async function test(Human, inputConfig) {
|
|||
});
|
||||
|
||||
await verifyDetails(human);
|
||||
await testTensorShapes(human, 'samples/in/ai-body.jpg');
|
||||
|
||||
// test default config async
|
||||
log('info', 'test default');
|
||||
|
|
Loading…
Reference in New Issue