face-api/build/dom/NetInput.js


"use strict";
Object.defineProperty(exports, "__esModule", { value: true });
exports.NetInput = void 0;
const tf = require("@tensorflow/tfjs-core");
const env_1 = require("../env");
const padToSquare_1 = require("../ops/padToSquare");
const utils_1 = require("../utils");
const createCanvas_1 = require("./createCanvas");
const imageToSquare_1 = require("./imageToSquare");
class NetInput {
    constructor(inputs, treatAsBatchInput = false) {
        this._imageTensors = [];
        this._canvases = [];
        this._treatAsBatchInput = false;
        this._inputDimensions = [];
        if (!Array.isArray(inputs)) {
            throw new Error(`NetInput.constructor - expected inputs to be an Array of TResolvedNetInput or to be instanceof tf.Tensor4D, instead have ${inputs}`);
        }
        this._treatAsBatchInput = treatAsBatchInput;
        this._batchSize = inputs.length;
        inputs.forEach((input, idx) => {
            // a 3D tensor is stored as a single [height, width, channels] image
            if (utils_1.isTensor3D(input)) {
                this._imageTensors[idx] = input;
                this._inputDimensions[idx] = input.shape;
                return;
            }
            // a 4D tensor is only accepted if it holds exactly one image
            if (utils_1.isTensor4D(input)) {
                const batchSize = input.shape[0];
                if (batchSize !== 1) {
                    throw new Error(`NetInput - tf.Tensor4D with batchSize ${batchSize} passed, but not supported in input array`);
                }
                this._imageTensors[idx] = input;
                this._inputDimensions[idx] = input.shape.slice(1);
                return;
            }
            // anything else (img, video, other media) is drawn onto a canvas first
            const canvas = input instanceof env_1.env.getEnv().Canvas ? input : createCanvas_1.createCanvasFromMedia(input);
            this._canvases[idx] = canvas;
            this._inputDimensions[idx] = [canvas.height, canvas.width, 3];
        });
    }
    get imageTensors() {
        return this._imageTensors;
    }
    get canvases() {
        return this._canvases;
    }
    get isBatchInput() {
        return this.batchSize > 1 || this._treatAsBatchInput;
    }
    get batchSize() {
        return this._batchSize;
    }
    get inputDimensions() {
        return this._inputDimensions;
    }
    get inputSize() {
        return this._inputSize;
    }
    get reshapedInputDimensions() {
        return utils_1.range(this.batchSize, 0, 1).map((_, batchIdx) => this.getReshapedInputDimensions(batchIdx));
    }
    getInput(batchIdx) {
        return this.canvases[batchIdx] || this.imageTensors[batchIdx];
    }
    getInputDimensions(batchIdx) {
        return this._inputDimensions[batchIdx];
    }
    getInputHeight(batchIdx) {
        return this._inputDimensions[batchIdx][0];
    }
    getInputWidth(batchIdx) {
        return this._inputDimensions[batchIdx][1];
    }
    getReshapedInputDimensions(batchIdx) {
        if (typeof this.inputSize !== 'number') {
            throw new Error('getReshapedInputDimensions - inputSize not set, toBatchTensor has not been called yet');
        }
        const width = this.getInputWidth(batchIdx);
        const height = this.getInputHeight(batchIdx);
        return utils_1.computeReshapedDimensions({ width, height }, this.inputSize);
    }
    /**
     * Create a batch tensor from all input canvases and tensors
     * with size [batchSize, inputSize, inputSize, 3].
     *
     * @param inputSize Height and width of the tensor.
     * @param isCenterInputs (optional, default: true) If true, an equal amount of padding is added
     * on both sides of the minor dimension of the image.
     * @returns The batch tensor.
     */
    toBatchTensor(inputSize, isCenterInputs = true) {
        this._inputSize = inputSize;
        return tf.tidy(() => {
            const inputTensors = utils_1.range(this.batchSize, 0, 1).map(batchIdx => {
                const input = this.getInput(batchIdx);
                if (input instanceof tf.Tensor) {
                    // @ts-ignore: error TS2344: Type 'Rank.R4' does not satisfy the constraint 'Tensor<Rank>'.
                    let imgTensor = utils_1.isTensor4D(input) ? input : input.expandDims();
                    // @ts-ignore: error TS2344: Type 'Rank.R4' does not satisfy the constraint 'Tensor<Rank>'.
                    imgTensor = padToSquare_1.padToSquare(imgTensor, isCenterInputs);
                    // resize only if the padded tensor does not already match the target size
                    if (imgTensor.shape[1] !== inputSize || imgTensor.shape[2] !== inputSize) {
                        imgTensor = tf.image.resizeBilinear(imgTensor, [inputSize, inputSize]);
                    }
                    return imgTensor.as3D(inputSize, inputSize, 3);
                }
                if (input instanceof env_1.env.getEnv().Canvas) {
                    // canvases are squared off to inputSize on the 2D canvas before being read into a tensor
                    return tf.browser.fromPixels(imageToSquare_1.imageToSquare(input, inputSize, isCenterInputs));
                }
                throw new Error(`toBatchTensor - at batchIdx ${batchIdx}, expected input to be instanceof tf.Tensor or instanceof HTMLCanvasElement, instead have ${input}`);
            });
            // stack the per-image [inputSize, inputSize, 3] tensors into a single float batch tensor
            const batchTensor = tf.stack(inputTensors.map(t => t.toFloat())).as4D(this.batchSize, inputSize, inputSize, 3);
            return batchTensor;
        });
    }
}
exports.NetInput = NetInput;
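/*
 * Usage sketch (not part of the original build output): a minimal, hedged example of
 * driving this class directly in a browser environment. `videoEl` is an assumed
 * HTMLVideoElement and 416 is just an illustrative input size; in the library,
 * NetInput instances are usually produced by a helper rather than constructed by hand.
 *
 *   const netInput = new NetInput([videoEl]);        // the media element is drawn onto a canvas internally
 *   const batchTensor = netInput.toBatchTensor(416); // float tensor of shape [1, 416, 416, 3]
 *   // ... run batchTensor through a model ...
 *   batchTensor.dispose();                           // tf.tidy keeps the returned tensor, so the caller disposes it
 */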
//# sourceMappingURL=NetInput.js.map