face-api/build/dom/toNetInput.js

44 lines
2.1 KiB
JavaScript
Raw Normal View History

2020-08-20 02:10:42 +02:00
import { isTensor3D, isTensor4D } from '../utils';
import { awaitMediaLoaded } from './awaitMediaLoaded';
import { isMediaElement } from './isMediaElement';
import { NetInput } from './NetInput';
import { resolveInput } from './resolveInput';
2020-08-18 14:04:33 +02:00
/**
 * Validates the given input(s) to make sure they are valid net inputs, and waits for all
 * media elements among them to finish loading.
 *
 * @param inputs Either an existing NetInput (returned unchanged), a single input value
 * (media element, tf.Tensor3D, tf.Tensor4D with batch size 1, or an element id string),
 * or a non-empty array of such values.
 * @returns A NetInput instance, which can be passed into one of the neural networks.
 * @throws {Error} If an empty array is passed, if an entry does not resolve to a media
 * element / tf.Tensor3D / tf.Tensor4D (with a dedicated message for unresolvable element id
 * strings), or if a tf.Tensor4D entry has a batch size other than 1.
 */
export async function toNetInput(inputs) {
  if (inputs instanceof NetInput) {
    return inputs;
  }

  // Whether the caller passed an array is forwarded to NetInput below (presumably as its
  // batch-input flag — TODO confirm against NetInput); a single value is wrapped so both
  // cases share one validation path.
  const isArrayInput = Array.isArray(inputs);
  const inputArgArray = isArrayInput ? inputs : [inputs];
  if (inputArgArray.length === 0) {
    throw new Error('toNetInput - empty array passed as input');
  }

  // Only array inputs get a positional hint in error messages.
  const getIdxHint = (idx) => (isArrayInput ? ` at input index ${idx}:` : '');

  // Wrapped in an arrow so resolveInput receives only the argument, not map's index/array.
  const inputArray = inputArgArray.map((arg) => resolveInput(arg));

  for (const [i, input] of inputArray.entries()) {
    if (!isMediaElement(input) && !isTensor3D(input) && !isTensor4D(input)) {
      // A string that did not resolve to a usable element gets a more specific message.
      if (typeof inputArgArray[i] === 'string') {
        throw new Error(`toNetInput -${getIdxHint(i)} string passed, but could not resolve HTMLElement for element id ${inputArgArray[i]}`);
      }
      throw new Error(`toNetInput -${getIdxHint(i)} expected media to be of type HTMLImageElement | HTMLVideoElement | HTMLCanvasElement | tf.Tensor3D | tf.Tensor4D, or to be an element id`);
    }

    // A tf.Tensor4D entry is only accepted when it holds exactly one item (batch size 1).
    if (isTensor4D(input)) {
      const batchSize = input.shape[0];
      if (batchSize !== 1) {
        throw new Error(`toNetInput -${getIdxHint(i)} tf.Tensor4D with batchSize ${batchSize} passed, but not supported in input array`);
      }
    }
  }

  // Wait for every media element to finish loading; tensors need no waiting.
  await Promise.all(
    inputArray
      .filter((input) => isMediaElement(input))
      .map((element) => awaitMediaLoaded(element)),
  );

  return new NetInput(inputArray, isArrayInput);
}
//# sourceMappingURL=toNetInput.js.map