// human/dist/human.js
//
// 73182 lines
// 2.9 MiB

var Human = (() => {
var __defProp = Object.defineProperty;
// Sets a non-enumerable, read-only `__esModule: true` flag on `target` and returns it
// (Object.defineProperty returns its first argument).
var __markAsModule = function(target) {
  return __defProp(target, "__esModule", {value: true});
};
// Wraps a CommonJS-style module body so it runs lazily, at most once, on the first call
// of the returned loader; every call returns the current `module.exports`.
var __commonJS = function(callback, module) {
  return function() {
    if (module) {
      return module.exports;
    }
    module = {exports: {}};
    callback(module.exports, module);
    return module.exports;
  };
};
// Flags `target` as an ES module and exposes every key of `all` (inherited enumerable
// keys included, as with for...in) as an enumerable getter backed by the matching thunk.
var __export = function(target, all) {
  __markAsModule(target);
  for (var key in all) {
    __defProp(target, key, {get: all[key], enumerable: true});
  }
};
// empty:/home/vlado/dev/human/node_modules/node-fetch/browser.js
// Stub modules whose bodies are empty (see the `empty:` markers), so each loader
// yields an empty `exports` object.
var require_browser = __commonJS(function() {
});
// empty:util
var require_util = __commonJS(function() {
});
// empty:crypto
var require_crypto = __commonJS(function() {
});
// node_modules/@tensorflow/tfjs-core/dist/tf-core.node.js
var require_tf_core_node = __commonJS((exports) => {
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
// Strict-mode directive for this module body. It only takes effect while it is the first
// statement of the enclosing function, so nothing may be inserted before it.
"use strict";
// Sets a non-enumerable, read-only `__esModule: true` flag on `exports`.
Object.defineProperty(exports, "__esModule", {value: true});
/*! *****************************************************************************
Copyright (c) Microsoft Corporation. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use
this file except in compliance with the License. You may obtain a copy of the
License at http://www.apache.org/licenses/LICENSE-2.0
THIS CODE IS PROVIDED ON AN *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
MERCHANTABLITY OR NON-INFRINGEMENT.
See the Apache Version 2.0 License for specific language governing permissions
and limitations under the License.
***************************************************************************** */
/**
 * Makes static members of base constructor `b` visible on derived constructor `d`.
 * On its first call it replaces itself with the best strategy available in this runtime:
 * `Object.setPrototypeOf`, then `__proto__` assignment, then copying own properties.
 */
var extendStatics = function(d, b) {
  extendStatics = Object.setPrototypeOf || ({__proto__: []} instanceof Array && function(d2, b2) {
    d2.__proto__ = b2;
  }) || function(d2, b2) {
    for (var p in b2) {
      // Borrow hasOwnProperty from Object.prototype: `b2` may shadow it or have a null prototype.
      if (Object.prototype.hasOwnProperty.call(b2, p)) {
        d2[p] = b2[p];
      }
    }
  };
  return extendStatics(d, b);
};
/**
 * Sets up inheritance of constructor `d` from base `b` (`null` allowed), as emitted for
 * down-leveled `class d extends b`. Statics come from `extendStatics`. `d.prototype`
 * becomes `Object.create(null)` for a null base. Otherwise it is a new object whose
 * prototype is `b.prototype` and whose `constructor` is `d`.
 * @throws {TypeError} if `b` is neither a function nor null (matching native `class extends`).
 */
function __extends(d, b) {
  if (typeof b !== "function" && b !== null) {
    throw new TypeError("Class extends value " + String(b) + " is not a constructor or null");
  }
  extendStatics(d, b);
  function __() {
    this.constructor = d;
  }
  d.prototype = b === null ? Object.create(b) : (__.prototype = b.prototype, new __());
}
/**
 * Runs generator function `generator` (with `this` = `thisArg`, args = `_arguments`) as an
 * async function. Each yielded value is adopted through a new `P` promise. Its fulfilment
 * is sent back with `next` and its rejection is thrown back in with `throw`. The returned
 * promise settles with the generator's return value or its first uncaught error.
 * @param P promise constructor to use; the global Promise when falsy.
 */
function __awaiter(thisArg, _arguments, P, generator) {
  if (!P) {
    P = Promise;
  }
  return new P(function(resolve, reject) {
    var iterator;
    // Settle when the generator is done; otherwise wait on the yielded value, then resume.
    function advance(result) {
      if (result.done) {
        resolve(result.value);
        return;
      }
      new P(function(settle) {
        settle(result.value);
      }).then(onFulfilled, onRejected);
    }
    function onFulfilled(value) {
      try {
        advance(iterator.next(value));
      } catch (err) {
        reject(err);
      }
    }
    function onRejected(reason) {
      try {
        advance(iterator["throw"](reason));
      } catch (err) {
        reject(err);
      }
    }
    iterator = generator.apply(thisArg, _arguments || []);
    advance(iterator.next());
  });
}
/**
 * tslib `__generator` helper: emulates a generator object for down-leveled code.
 * `body` is called repeatedly with the state object `_` and returns an instruction
 * `[opcode, operand]`. Opcodes as handled below:
 *   0 next, 1 throw, 2 return (the values sent by verb(0|1|2)),
 *   3 jump to a label, 4 yield a value, 5 delegate to an inner iterator (yield*),
 *   6 exception raised by `body` (created in the catch below), 7 end of a finally block.
 * `_.trys` entries are read as t[0]..t[3] below; they appear to be
 * [tryLabel, catchLabel, finallyLabel, endLabel] (inferred from usage — confirm against tslib).
 * @param thisArg `this` for every `body` call.
 * @param body the compiled state-machine body.
 * @returns an iterator object with next/throw/return (and Symbol.iterator when available).
 */
function __generator(thisArg, body) {
// `_` is the state handed to `body`: the resume `label`, `sent()` (returns the operand of the
// pending instruction, rethrowing it when that instruction is a throw, i.e. bit 0 set),
// the stack of active try regions `trys`, and `ops` saved while a finally block runs.
// `f` = "executing" flag, `y` = delegated iterator, `t` = scratch / pending instruction,
// `g` = the generator object returned to callers.
var _ = {label: 0, sent: function() {
if (t[0] & 1)
throw t[1];
return t[1];
}, trys: [], ops: []}, f, y, t, g;
// next/throw/return map to opcodes 0/1/2; the object is iterable when Symbol exists.
return g = {next: verb(0), throw: verb(1), return: verb(2)}, typeof Symbol === "function" && (g[Symbol.iterator] = function() {
return this;
}), g;
// Builds the method for opcode `n`: resume the state machine with `[n, v]`.
function verb(n) {
return function(v) {
return step2([n, v]);
};
}
// Runs the state machine until it yields (op 4), delegates a result, completes, or throws.
function step2(op2) {
// Re-entrancy guard: resuming the generator from inside `body` is an error.
if (f)
throw new TypeError("Generator is already executing.");
// `_` is set to 0 once the generator has finished, which ends this loop.
while (_)
try {
// While delegating (yield*), forward the instruction to the inner iterator `y`;
// if the inner iterator is not done, hand its result straight back to the caller.
if (f = 1, y && (t = op2[0] & 2 ? y["return"] : op2[0] ? y["throw"] || ((t = y["return"]) && t.call(y), 0) : y.next) && !(t = t.call(y, op2[1])).done)
return t;
// Delegation ended: resume the outer body with the inner result as a next (0) or return (2).
if (y = 0, t)
op2 = [op2[0] & 2, t.value];
switch (op2[0]) {
// next / throw: remember the instruction so `_.sent()` can read it, then run `body`.
case 0:
case 1:
t = op2;
break;
// yield: step past the yield point and suspend with the yielded value.
case 4:
_.label++;
return {value: op2[1], done: false};
// yield*: start delegating to the iterator in op2[1].
case 5:
_.label++;
y = op2[1];
op2 = [0];
continue;
// end of finally: resume the instruction that was pending when the finally was entered.
case 7:
op2 = _.ops.pop();
_.trys.pop();
continue;
// return (2), jump (3), exception (6): unwind through the innermost try region `t`.
default:
// No enclosing try region and a return/exception: the generator is finished.
if (!(t = _.trys, t = t.length > 0 && t[t.length - 1]) && (op2[0] === 6 || op2[0] === 2)) {
_ = 0;
continue;
}
// Jump whose target lies inside the current region (or there is no region): go there.
if (op2[0] === 3 && (!t || op2[1] > t[0] && op2[1] < t[3])) {
_.label = op2[1];
break;
}
// Exception raised before the catch label: enter the catch block.
if (op2[0] === 6 && _.label < t[1]) {
_.label = t[1];
t = op2;
break;
}
// Not yet in the finally block: run it, saving the pending instruction for op 7.
if (t && _.label < t[2]) {
_.label = t[2];
_.ops.push(op2);
break;
}
// Region exhausted: drop its saved op (if it had a finally) and the region itself,
// then retry the same instruction against the enclosing region.
if (t[2])
_.ops.pop();
_.trys.pop();
continue;
}
// Run the body from `_.label`; it returns the next instruction.
op2 = body.call(thisArg, _);
} catch (e) {
// An exception escaping `body` becomes opcode 6 and cancels any delegation.
op2 = [6, e];
y = 0;
} finally {
f = t = 0;
}
// Finished (now or earlier): throw (1) and exception (6) are rethrown; otherwise report
// completion, carrying the operand for return (2) and undefined for next (0).
if (op2[0] & 5)
throw op2[1];
return {value: op2[0] ? op2[1] : void 0, done: true};
}
}
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
// Tolerances returned by KernelBackend.prototype.epsilon(): EPSILON_FLOAT32 when the
// backend reports 32-bit float precision, the coarser EPSILON_FLOAT16 otherwise.
var EPSILON_FLOAT16 = 0.0001;
var EPSILON_FLOAT32 = 1e-7;
/**
 * Per-backend store mapping tensor data ids (objects, since they key a WeakMap) to
 * backend-specific values. `dataIdsCount` tracks how many ids were stored and not yet
 * deleted. Entries reclaimed by garbage collection are not reflected in the count.
 */
var DataStorage = function() {
  /**
   * @param backend2 backend that owns this storage; passed to `dataMover.moveData`.
   * @param dataMover object whose `moveData(backend, dataId)` is called by `get` when
   *   `dataId` is not stored here yet.
   */
  function DataStorage2(backend2, dataMover) {
    this.backend = backend2;
    this.dataMover = dataMover;
    this.data = new WeakMap();
    this.dataIdsCount = 0;
  }
  /** Returns the value for `dataId`, first asking the data mover to bring it here if absent. */
  DataStorage2.prototype.get = function(dataId) {
    if (!this.data.has(dataId)) {
      this.dataMover.moveData(this.backend, dataId);
    }
    return this.data.get(dataId);
  };
  /**
   * Stores `value` under `dataId`. Overwriting an existing id does not change the count.
   * The count is only bumped after the WeakMap accepted the key.
   */
  DataStorage2.prototype.set = function(dataId, value) {
    var isNewId = !this.data.has(dataId);
    this.data.set(dataId, value);
    if (isNewId) {
      this.dataIdsCount++;
    }
  };
  /** True when a value is currently stored for `dataId`. */
  DataStorage2.prototype.has = function(dataId) {
    return this.data.has(dataId);
  };
  /** Removes `dataId`; returns whether an entry existed. The count drops only on removal. */
  DataStorage2.prototype.delete = function(dataId) {
    var removed = this.data.delete(dataId);
    if (removed) {
      this.dataIdsCount--;
    }
    return removed;
  };
  /** Number of ids currently stored through this object. */
  DataStorage2.prototype.numDataIds = function() {
    return this.dataIdsCount;
  };
  return DataStorage2;
}();
/**
 * Abstract base class for kernel backends. Every method except `epsilon` is a stub that
 * throws via notYetImplemented(<method name>); concrete backends override the kernels
 * they support. Methods are installed on the prototype from the table below, in the
 * same order as the hand-written stubs they replace.
 */
var KernelBackend = function() {
  function KernelBackend2() {
  }
  var proto = KernelBackend2.prototype;
  // Stubs for these kernels take one options object and read these fields (in this
  // order) before failing, so calling them without an argument raises a TypeError.
  var optionFields = {
    fusedBatchMatMul: ["a", "b", "transposeA", "transposeB", "bias", "activation", "preluActivationWeights"],
    fusedConv2d: ["input", "filter", "convInfo", "bias", "activation", "preluActivationWeights"],
    fusedDepthwiseConv2D: ["input", "filter", "convInfo", "bias", "activation", "preluActivationWeights"]
  };
  // Prototype methods in definition order; "epsilon" is the only concrete one.
  var methodNames = [
    "time", "read", "readSync", "numDataIds", "disposeData", "write", "move", "memory",
    "floatPrecision", "epsilon", "batchMatMul", "fusedBatchMatMul", "slice", "stridedSlice",
    "unstack", "reverse", "concat", "neg", "add", "addN", "subtract", "multiply", "realDivide",
    "floorDiv", "sum", "prod", "unsortedSegmentSum", "argMin", "argMax", "equal", "notEqual",
    "less", "lessEqual", "greater", "greaterEqual", "logicalNot", "logicalAnd", "logicalOr",
    "where", "select", "topk", "min", "minimum", "mod", "max", "maximum", "all", "any",
    "squaredDifference", "ceil", "floor", "round", "sign", "isNaN", "isInf", "isFinite", "pow",
    "exp", "expm1", "softmax", "log", "log1p", "sqrt", "rsqrt", "square", "reciprocal", "relu",
    "relu6", "prelu", "elu", "eluDer", "selu", "int", "clip", "abs", "complexAbs", "sigmoid",
    "softplus", "sin", "cos", "tan", "asin", "acos", "atan", "atan2", "sinh", "cosh", "tanh",
    "asinh", "acosh", "atanh", "erf", "step", "fusedConv2d", "conv2d", "conv2dDerInput",
    "conv2dDerFilter", "fusedDepthwiseConv2D", "depthwiseConv2D", "depthwiseConv2DDerInput",
    "depthwiseConv2DDerFilter", "conv3d", "conv3dDerInput", "conv3dDerFilter", "maxPool",
    "maxPoolBackprop", "avgPool", "avgPoolBackprop", "avgPool3d", "avgPool3dBackprop", "maxPool3d",
    "maxPool3dBackprop", "reshape", "cast", "tile", "pad", "transpose", "gather", "gatherND",
    "scatterND", "batchToSpaceND", "spaceToBatchND", "resizeBilinear", "resizeBilinearBackprop",
    "resizeNearestNeighbor", "resizeNearestNeighborBackprop", "batchNorm",
    "localResponseNormalization4D", "LRNGrad", "multinomial", "oneHot", "cumsum",
    "nonMaxSuppression", "fft", "ifft", "complex", "real", "imag", "cropAndResize", "depthToSpace",
    "split", "sparseToDense", "diag", "fill", "onesLike", "zerosLike", "linspace", "dispose"
  ];
  // Builds the throwing stub for `methodName`.
  function makeStub(methodName) {
    var fields = Object.prototype.hasOwnProperty.call(optionFields, methodName) ? optionFields[methodName] : null;
    if (fields === null) {
      return function() {
        return notYetImplemented(methodName);
      };
    }
    return function(_a) {
      for (var k = 0; k < fields.length; k++) {
        void _a[fields[k]];
      }
      return notYetImplemented(methodName);
    };
  }
  methodNames.forEach(function(methodName) {
    proto[methodName] = methodName === "epsilon" ? function() {
      // Tolerance matching the backend's reported float precision.
      return this.floatPrecision() === 32 ? EPSILON_FLOAT32 : EPSILON_FLOAT16;
    } : makeStub(methodName);
  });
  return KernelBackend2;
}();
/**
 * Throws the standard "kernel missing" error for `kernelName`; never returns normally.
 * @throws {Error} always.
 */
function notYetImplemented(kernelName) {
  var message = "'" + kernelName + "' not yet implemented or not found in the registry. " +
    "This kernel may not be supported by the tfjs backend you have chosen";
  throw new Error(message);
}
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Shuffles `array` in place (Fisher–Yates), walking from the last slot to the first.
 * Calls Math.random() once per element. Works on any index-assignable array-like,
 * typed arrays included.
 */
function shuffle(array) {
  for (var remaining = array.length; remaining > 0; ) {
    var pick = Math.random() * remaining | 0;
    remaining -= 1;
    var last = array[remaining];
    array[remaining] = array[pick];
    array[pick] = last;
  }
}
/** Limits `x` to [min2, max2] (assuming min2 <= max2); NaN propagates. */
function clamp(min2, x, max2) {
  var upperBounded = Math.min(x, max2);
  return Math.max(min2, upperBounded);
}
/** Returns `val` when `val % 2 === 0`, otherwise `val + 1`. */
function nearestLargerEven(val) {
  if (val % 2 !== 0) {
    return val + 1;
  }
  return val;
}
/** Adds the entries of array-like `arr` from first to last, starting from 0. */
function sum(arr) {
  var total = 0;
  var i = 0;
  while (i < arr.length) {
    total += arr[i];
    i++;
  }
  return total;
}
/** Random value between `a` and `b`, linearly interpolated with one Math.random() draw. */
function randUniform(a, b) {
  var weight = Math.random();
  var fromB = b * weight;
  var fromA = (1 - weight) * a;
  return fromB + fromA;
}
/**
 * Squared Euclidean distance over the first `a.length` entries of `a` and `b`.
 * Entries are coerced with Number.
 */
function distSquared(a, b) {
  var acc = 0;
  var k = 0;
  while (k < a.length) {
    var gap = Number(a[k]) - Number(b[k]);
    acc += gap * gap;
    k += 1;
  }
  return acc;
}
/**
 * Throws an Error when `expr` is falsy. `msg` is either the message string or a
 * zero-argument function producing it. The function is only called on failure, so
 * costly messages are built lazily.
 */
function assert(expr, msg) {
  if (expr) {
    return;
  }
  var message = typeof msg === "string" ? msg : msg();
  throw new Error(message);
}
/**
 * Asserts that `shapeA` and `shapeB` are element-wise identical (via arraysEqual).
 * @param errorMessagePrefix text placed before the failure message; "" when undefined.
 */
function assertShapesMatch(shapeA, shapeB, errorMessagePrefix) {
  var prefix = errorMessagePrefix === void 0 ? "" : errorMessagePrefix;
  var shapesAgree = arraysEqual(shapeA, shapeB);
  assert(shapesAgree, function() {
    return prefix + (" Shapes " + shapeA + " and " + shapeB + " must match");
  });
}
/** Asserts that the tensor-constructor input `a` is neither null nor undefined. */
function assertNonNull(a) {
  var present = a != null;
  assert(present, function() {
    return "The input to the tensor constructor must be a non-null value.";
  });
}
/**
 * Depth-first flattening of nested arrays into `result`. Typed arrays are descended into
 * as well unless `skipTypedArray` is truthy, in which case they are kept as single leaves.
 * @param result accumulator; a new array when null or undefined.
 * @param skipTypedArray treated as false when undefined.
 * @returns the accumulator.
 */
function flatten(arr, result, skipTypedArray) {
  var out = result == null ? [] : result;
  var keepTypedWhole = skipTypedArray === void 0 ? false : skipTypedArray;
  var descend = Array.isArray(arr) || isTypedArray(arr) && !keepTypedWhole;
  if (!descend) {
    out.push(arr);
    return out;
  }
  for (var i = 0; i < arr.length; ++i) {
    flatten(arr[i], out, keepTypedWhole);
  }
  return out;
}
/** Element count for `shape` (product of its dims); the scalar shape [] has size 1. */
function sizeFromShape(shape) {
  var rank = shape.length;
  if (rank === 0) {
    return 1;
  }
  var product = shape[0];
  for (var dim = 1; dim < rank; dim++) {
    product *= shape[dim];
  }
  return product;
}
/** True when `shape` has rank 0. */
function isScalarShape(shape) {
  var rank = shape.length;
  return rank === 0;
}
/**
 * Shallow element-wise equality using strict comparison. The same reference is always
 * equal. Otherwise a null/undefined argument or a length mismatch gives false.
 */
function arraysEqual(n1, n2) {
  if (n1 === n2) {
    return true;
  }
  if (n1 == null || n2 == null || n1.length !== n2.length) {
    return false;
  }
  for (var idx = 0; idx < n1.length; idx++) {
    if (n1[idx] !== n2[idx]) {
      return false;
    }
  }
  return true;
}
/** True when `a % 1 === 0` (false for NaN and ±Infinity). */
function isInt(a) {
  var fractional = a % 1;
  return fractional === 0;
}
/**
 * Hyperbolic tangent. Uses Math.tanh when it exists. Otherwise it maps ±Infinity to ±1
 * and evaluates (e^{2x} - 1) / (e^{2x} + 1).
 */
function tanh(x) {
  if (Math.tanh != null) {
    return Math.tanh(x);
  }
  switch (x) {
    case Infinity:
      return 1;
    case -Infinity:
      return -1;
    default: {
      var e2x = Math.exp(2 * x);
      return (e2x - 1) / (e2x + 1);
    }
  }
}
/** Near-square 2-D shape [ceil(sqrt(size)), ceil(size / thatWidth)] holding at least `size` cells. */
function sizeToSquarishShape(size) {
  var side = Math.ceil(Math.sqrt(size));
  var rows = Math.ceil(size / side);
  return [side, rows];
}
/** Returns a Uint32Array containing 0..n-1, shuffled in place by `shuffle`. */
function createShuffledIndices(n) {
  var indices = new Uint32Array(n);
  indices.forEach(function(_, position) {
    indices[position] = position;
  });
  shuffle(indices);
  return indices;
}
/** Pads string `a` with trailing spaces to `size` characters; longer strings are returned unchanged. */
function rightPad(a, size) {
  return size <= a.length ? a : a + " ".repeat(size - a.length);
}
/**
 * Polls `checkFn` until it returns truthy. After each failed attempt the delay comes from
 * `delayFn(attemptCount)` (0 ms by default) and is computed before the limit check; the
 * next try is scheduled with setTimeout. When `maxCounter` is given, the promise rejects
 * (with no reason) once that many attempts have failed.
 * @returns Promise resolved (with no value) when `checkFn` succeeds.
 */
function repeatedTry(checkFn, delayFn, maxCounter) {
  var backoff = delayFn === void 0 ? function(counter) {
    return 0;
  } : delayFn;
  return new Promise(function(resolve, reject) {
    var attempts = 0;
    function attempt() {
      if (checkFn()) {
        resolve();
        return;
      }
      attempts += 1;
      var wait = backoff(attempts);
      var exhausted = maxCounter != null && attempts >= maxCounter;
      if (exhausted) {
        reject();
        return;
      }
      setTimeout(attempt, wait);
    }
    attempt();
  });
}
/**
 * Resolves a shape that may contain one -1 ("infer this dim") so its product equals `size`.
 * A shape without -1 is returned as the same array, after a product check when size > 0.
 * @returns `shape` itself, or a copy with the -1 replaced by size / (product of the other dims).
 * @throws Error on more than one -1, any other negative dim, a zero known product with a
 *   -1 present, or a size not divisible by the known product.
 */
function inferFromImplicitShape(shape, size) {
  var knownProduct = 1;
  var inferredAt = -1;
  for (var i = 0; i < shape.length; ++i) {
    var dim = shape[i];
    if (dim >= 0) {
      knownProduct *= dim;
      continue;
    }
    if (dim === -1) {
      if (inferredAt !== -1) {
        throw Error("Shapes can only have 1 implicit size. " + ("Found -1 at dim " + inferredAt + " and dim " + i));
      }
      inferredAt = i;
      continue;
    }
    if (dim < 0) {
      throw Error("Shapes can not be < 0. Found " + dim + " at dim " + i);
    }
  }
  if (inferredAt === -1) {
    if (size > 0 && size !== knownProduct) {
      throw Error("Size(" + size + ") must match the product of shape " + shape);
    }
    return shape;
  }
  if (knownProduct === 0) {
    throw Error("Cannot infer the missing size in [" + shape + "] when there are 0 elements");
  }
  if (size % knownProduct !== 0) {
    throw Error("The implicit shape can't be a fractional number. " + ("Got " + size + " / " + knownProduct));
  }
  var resolved = shape.slice();
  resolved[inferredAt] = size / knownProduct;
  return resolved;
}
/**
 * Normalizes an axis argument for a tensor of shape `shape`. null/undefined selects every
 * axis; a single number or an array is accepted; negative axes count from the end.
 * @returns array of non-negative axis indices.
 * @throws (via assert) when an axis lies outside [-rank, rank) or is not an integer.
 */
function parseAxisParam(axis, shape) {
  var rank = shape.length;
  var axes;
  if (axis == null) {
    axes = shape.map(function(s, i) {
      return i;
    });
  } else {
    axes = [].concat(axis);
  }
  var inRange = function(ax) {
    return ax >= -rank && ax < rank;
  };
  assert(axes.every(inRange), function() {
    return "All values in axis param must be in range [-" + rank + ", " + rank + ") but " + ("got axis " + axes);
  });
  assert(axes.every(function(ax) {
    return isInt(ax);
  }), function() {
    return "All values in axis param must be integers but " + ("got axis " + axes);
  });
  return axes.map(function(ax) {
    return ax < 0 ? rank + ax : ax;
  });
}
/**
 * Computes the shape left after removing size-1 dimensions.
 * When `axis` is null, undefined or [], every size-1 dim is dropped. Otherwise only the
 * listed axes (normalized by parseAxisParam) are dropped, and each must have size 1.
 * @returns {{newShape: number[], keptDims: number[]}} remaining dims and their original indices.
 * @throws Error if a listed axis does not have size 1.
 */
function squeezeShape(shape, axis) {
  var newShape = [];
  var keptDims = [];
  var isEmptyArray = axis != null && Array.isArray(axis) && axis.length === 0;
  // Sort numerically. The default sort is lexicographic and leaves [10, 2] as [10, 2],
  // which breaks the in-order pointer walk below for tensors of rank > 10.
  var axes = axis == null || isEmptyArray ? null : parseAxisParam(axis, shape).sort(function(a, b) {
    return a - b;
  });
  var j = 0;
  for (var i = 0; i < shape.length; ++i) {
    if (axes != null) {
      if (axes[j] === i && shape[i] !== 1) {
        throw new Error("Can't squeeze axis " + i + " since its dim '" + shape[i] + "' is not 1");
      }
      // A size-1 dim that is not being squeezed is kept here; the push below only
      // handles dims whose size is not 1.
      if ((axes[j] == null || axes[j] > i) && shape[i] === 1) {
        newShape.push(shape[i]);
        keptDims.push(i);
      }
      if (axes[j] <= i) {
        j++;
      }
    }
    if (shape[i] !== 1) {
      newShape.push(shape[i]);
      keptDims.push(i);
    }
  }
  return {newShape, keptDims};
}
/**
 * Allocates a zero-filled typed array of `size` elements for a numeric dtype:
 * "float32" (or null/undefined) gives Float32Array, "int32" gives Int32Array,
 * and "bool" gives Uint8Array.
 * @throws Error for any other dtype.
 */
function getTypedArrayFromDType(dtype, size) {
  switch (dtype) {
    case void 0:
    case null:
    case "float32":
      return new Float32Array(size);
    case "int32":
      return new Int32Array(size);
    case "bool":
      return new Uint8Array(size);
    default:
      throw new Error("Unknown data type " + dtype);
  }
}
/**
 * Like getTypedArrayFromDType, but "string" is also accepted and yields a plain
 * Array of `size` empty slots.
 * @throws Error for any other unknown dtype.
 */
function getArrayFromDType(dtype, size) {
  if (dtype == null || dtype === "float32") {
    return new Float32Array(size);
  }
  if (dtype === "int32") {
    return new Int32Array(size);
  }
  if (dtype === "bool") {
    return new Uint8Array(size);
  }
  if (dtype === "string") {
    return new Array(size);
  }
  throw new Error("Unknown data type " + dtype);
}
/**
 * Throws on the first entry of `vals` that is NaN or ±Infinity, checked with the coercing
 * global isNaN/isFinite. `dtype` is only used in the message.
 */
function checkConversionForErrors(vals, dtype) {
  for (var idx = 0; idx < vals.length; idx++) {
    var value = vals[idx];
    var invalid = isNaN(value) || !isFinite(value);
    if (invalid) {
      throw Error("A tensor of type " + dtype + " being uploaded contains " + value + ".");
    }
  }
}
/** True for the supported dtypes: bool, complex64, float32, int32 and string. */
function isValidDtype(dtype) {
  switch (dtype) {
    case "bool":
    case "complex64":
    case "float32":
    case "int32":
    case "string":
      return true;
    default:
      return false;
  }
}
/**
 * Whether casting from `oldType` to `newType` can lose information:
 * to complex64 never; to float32 only from complex64; to int32 from float32 or
 * complex64; to bool from anything but bool. Any other target counts as lossy.
 */
function hasEncodingLoss(oldType, newType) {
  switch (newType) {
    case "complex64":
      return false;
    case "float32":
      return oldType === "complex64";
    case "int32":
      return oldType === "float32" || oldType === "complex64";
    case "bool":
      return oldType !== "bool";
    default:
      return true;
  }
}
/** True for a Float32Array, Int32Array or Uint8Array; other typed arrays give false. */
function isTypedArray(a) {
  var supported = [Float32Array, Int32Array, Uint8Array];
  return supported.some(function(Ctor) {
    return a instanceof Ctor;
  });
}
/**
 * Bytes per element for a numeric dtype: 4 for float32/int32, 8 for complex64, 1 for bool.
 * @throws Error for any other dtype (string included).
 */
function bytesPerElement(dtype) {
  switch (dtype) {
    case "float32":
    case "int32":
      return 4;
    case "complex64":
      return 8;
    case "bool":
      return 1;
    default:
      throw new Error("Unknown dtype " + dtype);
  }
}
/**
 * Sum of `.length` over the entries of `arr`, or 0 when `arr` is null/undefined.
 * Entries are presumably encoded strings (byte arrays) — confirm at call sites.
 */
function bytesFromStringArray(arr) {
  if (arr == null) {
    return 0;
  }
  var total = 0;
  arr.forEach(function(entry) {
    total += entry.length;
  });
  return total;
}
/** True for string primitives and String wrapper objects. */
function isString(value) {
  if (typeof value === "string") {
    return true;
  }
  return value instanceof String;
}
/** True only for boolean primitives. */
function isBoolean(value) {
  var kind = typeof value;
  return kind === "boolean";
}
/** True only for number primitives (NaN included). */
function isNumber(value) {
  var kind = typeof value;
  return kind === "number";
}
/**
 * Guesses a dtype from `values`. Nested arrays are inspected through their first element.
 * Float32Array → "float32", Int32Array/Uint8Array → "int32", number → "float32",
 * string → "string", boolean → "bool"; anything else (an empty array too) → "float32".
 */
function inferDtype(values) {
  if (Array.isArray(values)) {
    return inferDtype(values[0]);
  }
  // Checked in order; the first matching rule decides.
  var rules = [
    [function(v) {
      return v instanceof Float32Array;
    }, "float32"],
    [function(v) {
      return v instanceof Int32Array || v instanceof Uint8Array;
    }, "int32"],
    [isNumber, "float32"],
    [isString, "string"],
    [isBoolean, "bool"]
  ];
  for (var r = 0; r < rules.length; r++) {
    if (rules[r][0](values)) {
      return rules[r][1];
    }
  }
  return "float32";
}
/** Duck-typed function check: `f` is truthy and has truthy `constructor`, `call` and `apply`. */
function isFunction(f) {
  if (!f) {
    return false;
  }
  return Boolean(f.constructor && f.call && f.apply);
}
/** Smallest divisor of `size` in [start, size), or `size` itself when none exists. */
function nearestDivisor(size, start) {
  var candidate = start;
  while (candidate < size) {
    if (size % candidate === 0) {
      return candidate;
    }
    ++candidate;
  }
  return size;
}
/**
 * Row-major strides for `shape`, without the innermost stride: result[i] is the product
 * of shape[i+1..]. Shapes of rank < 2 give [].
 */
function computeStrides(shape) {
  var rank = shape.length;
  if (rank < 2) {
    return [];
  }
  var strides = new Array(rank - 1);
  var running = shape[rank - 1];
  for (var axis = rank - 2; axis >= 0; axis--) {
    strides[axis] = running;
    running *= shape[axis];
  }
  return strides;
}
/**
 * Builds a nested JS array with dimensions `shape`, reading flat `a` in row-major order
 * starting at `offset`. Expects shape.length >= 1.
 */
function createNestedArray(offset, shape, a) {
  var outer = shape[0];
  var nested = [];
  if (shape.length === 1) {
    for (var i = 0; i < outer; i++) {
      nested[i] = a[offset + i];
    }
    return nested;
  }
  var inner = shape.slice(1);
  var blockSize = inner.reduce(function(acc, c) {
    return acc * c;
  });
  for (var j = 0; j < outer; j++) {
    nested[j] = createNestedArray(offset + j * blockSize, inner, a);
  }
  return nested;
}
/**
 * Converts flat array `a` into a nested array shaped like `shape`.
 * Rank 0 returns the single value a[0]; a shape with 0 elements returns [].
 * @throws Error when the element count of `shape` differs from a.length.
 */
function toNestedArray(shape, a) {
  if (shape.length === 0) {
    return a[0];
  }
  var elementCount = shape.reduce(function(product, dim) {
    return product * dim;
  });
  switch (elementCount) {
    case 0:
      return [];
    case a.length:
      return createNestedArray(0, shape, a);
    default:
      throw new Error("[" + shape + "] does not match the input size " + a.length + ".");
  }
}
/** Typed array of `size` ones, allocated by makeZerosTypedArray for `dtype` and then filled. */
function makeOnesTypedArray(size, dtype) {
  var ones = makeZerosTypedArray(size, dtype);
  ones.fill(1);
  return ones;
}
function makeZerosTypedArray(size, dtype) {
if (dtype == null || dtype === "float32" || dtype === "complex64") {
return new Float32Array(size);
} else if (dtype === "int32") {
return new Int32Array(size);
} else if (dtype === "bool") {
return new Uint8Array(size);
} else {
throw new Error("Unknown data type " + dtype);
}
}
function makeZerosNestedTypedArray(shape, dtype) {
var size = shape.reduce(function(prev, curr) {
return prev * curr;
}, 1);
if (dtype == null || dtype === "float32") {
return toNestedArray(shape, new Float32Array(size));
} else if (dtype === "int32") {
return toNestedArray(shape, new Int32Array(size));
} else if (dtype === "bool") {
return toNestedArray(shape, new Uint8Array(size));
} else {
throw new Error("Unknown data type " + dtype);
}
}
/**
 * Asserts that every dimension in `shape` is a non-negative integer.
 *
 * Zero-sized dimensions are allowed. Failure is reported through `assert`,
 * which is defined elsewhere in this file; the message is built lazily.
 */
function assertNonNegativeIntegerDimensions(shape) {
  shape.forEach(function(dimSize) {
    assert(Number.isInteger(dimSize) && dimSize >= 0, function() {
      // The condition admits 0, so the message says "non-negative",
      // not "positive".
      return "Tensor must have a shape comprised of non-negative integers but got " + ("shape [" + shape + "].");
    });
  });
}
/**
 * Converts multi-dimensional coordinates to a flat row-major offset.
 *
 * Rank 0 always maps to 0 and rank 1 maps to `locs[0]`. Higher ranks start
 * from the innermost coordinate and add `strides[k] * locs[k]` for each
 * outer axis.
 */
function locToIndex(locs, rank, strides) {
  if (rank === 0) {
    return 0;
  }
  if (rank === 1) {
    return locs[0];
  }
  const lastAxis = locs.length - 1;
  let flat = locs[lastAxis];
  for (let axis = 0; axis < lastAxis; ++axis) {
    flat += strides[axis] * locs[axis];
  }
  return flat;
}
/**
 * Inverse of locToIndex: splits a flat row-major offset into coordinates.
 */
function indexToLoc(index, rank, strides) {
  if (rank === 0) {
    return [];
  }
  if (rank === 1) {
    return [index];
  }
  const coords = new Array(rank);
  const lastAxis = coords.length - 1;
  let remainder = index;
  for (let axis = 0; axis < lastAxis; ++axis) {
    coords[axis] = Math.floor(remainder / strides[axis]);
    remainder -= coords[axis] * strides[axis];
  }
  coords[lastAxis] = remainder;
  return coords;
}
/**
 * Thenable check: true when `object` is non-null and its `then` property is
 * a function.
 *
 * Always returns a real boolean. The previous form could return the falsy
 * input itself (`null`, `0`, `""`, `undefined`).
 */
function isPromise(object) {
  return object != null && typeof object.then === "function";
}
/**
* @license
* Copyright 2017 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var TENSORFLOWJS_FLAGS_PREFIX = "tfjsflags";
/**
 * Registry and cache of named feature flags.
 *
 * A flag is registered with an evaluation function and an optional hook
 * that runs on `set`. Its value is computed on first `get`/`getAsync` and
 * cached in `this.flags`.
 *
 * Values can be overridden from the page URL through the `tfjsflags` query
 * parameter, e.g. `?tfjsflags=DEBUG:true,SOME_FLAG:3`. Such overrides are
 * applied when the matching flag is registered.
 */
var Environment = function() {
  /**
   * @param global2 Global namespace object. When it exposes
   *     `location.search`, URL flag overrides are parsed from it.
   */
  function Environment2(global2) {
    this.global = global2;
    this.flags = {};
    this.flagRegistry = {};
    this.urlFlags = {};
    this.populateURLFlags();
  }
  /**
   * Installs the platform implementation, warning if one was already set.
   * @param platformName Name used in diagnostics.
   * @param platform Platform object (consumed via `env().platform`).
   */
  Environment2.prototype.setPlatform = function(platformName, platform) {
    if (this.platform != null) {
      // Report the incoming platform by name. Interpolating the platform
      // object itself would typically print "[object Object]".
      console.warn("Platform " + this.platformName + " has already been set. " + ("Overwriting the platform with " + platformName + "."));
    }
    this.platformName = platformName;
    this.platform = platform;
  };
  /**
   * Registers a flag. If a URL override for `flagName` was parsed at
   * construction/reset time, it is applied immediately via `set`.
   */
  Environment2.prototype.registerFlag = function(flagName, evaluationFn, setHook) {
    this.flagRegistry[flagName] = {evaluationFn, setHook};
    if (this.urlFlags[flagName] != null) {
      var flagValue = this.urlFlags[flagName];
      console.warn("Setting feature override from URL " + flagName + ": " + flagValue + ".");
      this.set(flagName, flagValue);
    }
  };
  /**
   * Returns a Promise for the flag value, awaiting an asynchronous evaluation
   * function if necessary. The resolved value is cached.
   */
  Environment2.prototype.getAsync = function(flagName) {
    return __awaiter(this, void 0, void 0, function() {
      var _a, _b;
      return __generator(this, function(_c) {
        switch (_c.label) {
          case 0:
            if (flagName in this.flags) {
              return [2, this.flags[flagName]];
            }
            _a = this.flags;
            _b = flagName;
            return [4, this.evaluateFlag(flagName)];
          case 1:
            _a[_b] = _c.sent();
            return [2, this.flags[flagName]];
        }
      });
    });
  };
  /**
   * Returns the (cached) flag value, evaluating it on first access.
   * @throws Error if the evaluation function returns a Promise; use getAsync
   *     for such flags.
   */
  Environment2.prototype.get = function(flagName) {
    if (flagName in this.flags) {
      return this.flags[flagName];
    }
    var flagValue = this.evaluateFlag(flagName);
    if (isPromise(flagValue)) {
      throw new Error("Flag " + flagName + " cannot be synchronously evaluated. Please use getAsync() instead.");
    }
    this.flags[flagName] = flagValue;
    return this.flags[flagName];
  };
  /** Alias of `get`; no numeric conversion is performed. */
  Environment2.prototype.getNumber = function(flagName) {
    return this.get(flagName);
  };
  /** Alias of `get`; no boolean conversion is performed. */
  Environment2.prototype.getBool = function(flagName) {
    return this.get(flagName);
  };
  /** Returns the live cache of evaluated/overridden flag values. */
  Environment2.prototype.getFlags = function() {
    return this.flags;
  };
  // `features` is a read-only alias of `flags`.
  Object.defineProperty(Environment2.prototype, "features", {
    get: function() {
      return this.flags;
    },
    enumerable: true,
    configurable: true
  });
  /**
   * Overrides a registered flag's value and runs its set-hook, if any.
   * @throws Error if `flagName` was never registered.
   */
  Environment2.prototype.set = function(flagName, value) {
    if (this.flagRegistry[flagName] == null) {
      throw new Error("Cannot set flag " + flagName + " as it has not been registered.");
    }
    this.flags[flagName] = value;
    if (this.flagRegistry[flagName].setHook != null) {
      this.flagRegistry[flagName].setHook(value);
    }
  };
  /**
   * Invokes the registered evaluation function, without caching.
   * @throws Error if `flagName` was never registered.
   */
  Environment2.prototype.evaluateFlag = function(flagName) {
    if (this.flagRegistry[flagName] == null) {
      throw new Error("Cannot evaluate flag '" + flagName + "': no evaluation function found.");
    }
    return this.flagRegistry[flagName].evaluationFn();
  };
  /**
   * Replaces the whole cache with a shallow copy of `flags`. Set-hooks are
   * not run.
   */
  Environment2.prototype.setFlags = function(flags) {
    this.flags = Object.assign({}, flags);
  };
  /**
   * Clears cached values and URL overrides, then re-reads the URL. The
   * flag registry itself is kept.
   */
  Environment2.prototype.reset = function() {
    this.flags = {};
    this.urlFlags = {};
    this.populateURLFlags();
  };
  /**
   * Parses `tfjsflags=key:value,key:value` from `global.location.search`
   * into `this.urlFlags`. Does nothing when no location/search is available.
   */
  Environment2.prototype.populateURLFlags = function() {
    var _this = this;
    if (typeof this.global === "undefined" || typeof this.global.location === "undefined" || typeof this.global.location.search === "undefined") {
      return;
    }
    var urlParams = getQueryParams(this.global.location.search);
    if (TENSORFLOWJS_FLAGS_PREFIX in urlParams) {
      var keyValues = urlParams[TENSORFLOWJS_FLAGS_PREFIX].split(",");
      keyValues.forEach(function(keyValue) {
        var _a = keyValue.split(":"), key = _a[0], value = _a[1];
        _this.urlFlags[key] = parseValue(key, value);
      });
    }
  };
  return Environment2;
}();
/**
 * Parses a URL query string such as `location.search` into a plain object.
 *
 * Names and values are percent-decoded. A name without `=value` maps to "",
 * and a later duplicate name overwrites an earlier one.
 * @throws URIError when a name or value is not valid percent-encoding.
 */
function getQueryParams(queryString) {
  const params = {};
  // Each match is "?name" or "&name", optionally followed by "=value".
  const pattern = /[?&]([^=?&]+)(?:=([^&]*))?/g;
  let match;
  while ((match = pattern.exec(queryString)) !== null) {
    decodeParam(params, match[1], match[2]);
  }
  return params;
}
/**
 * Stores the percent-decoded `name`/`value` pair into `params`. A missing
 * value is stored as "".
 */
function decodeParam(params, name, value) {
  params[decodeURIComponent(name)] = decodeURIComponent(value || "");
}
/**
 * Converts a URL flag value to a boolean or number (case-insensitive).
 *
 * @param flagName Flag name, used only in the error message.
 * @param value Raw text after "key:". It is undefined when the entry had no
 *     ":value" part.
 * @returns `true`/`false` for "true"/"false", or the number when the text
 *     round-trips exactly through Number.
 * @throws Error when the value is missing or not parseable.
 */
function parseValue(flagName, value) {
  // Without this guard, an entry like "tfjsflags=DEBUG" would crash with a
  // TypeError on `value.toLowerCase` instead of a descriptive error.
  if (value == null) {
    throw new Error("Could not parse value flag value " + value + " for flag " + flagName + ".");
  }
  const normalized = value.toLowerCase();
  if (normalized === "true" || normalized === "false") {
    return normalized === "true";
  }
  if ("" + +normalized === normalized) {
    return +normalized;
  }
  throw new Error("Could not parse value flag value " + normalized + " for flag " + flagName + ".");
}
/**
 * Returns the currently installed global Environment (null until
 * setEnvironmentGlobal is called).
 */
function env() {
  const activeEnvironment = exports.ENV;
  return activeEnvironment;
}
exports.ENV = null;
/**
 * Installs `environment` as the object returned by env().
 */
function setEnvironmentGlobal(environment) {
  exports.ENV = environment;
}
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var globalNameSpace;
/**
 * Resolves and memoizes the host's global object.
 *
 * Candidates are probed in this order: `window`, `global`, `process`,
 * `self`.
 * @throws Error when none of them exists.
 */
function getGlobalNamespace() {
  if (globalNameSpace != null) {
    return globalNameSpace;
  }
  if (typeof window !== "undefined") {
    globalNameSpace = window;
  } else if (typeof global !== "undefined") {
    globalNameSpace = global;
  } else if (typeof process !== "undefined") {
    globalNameSpace = process;
  } else if (typeof self !== "undefined") {
    globalNameSpace = self;
  } else {
    throw new Error("Could not find a global object");
  }
  return globalNameSpace;
}
/**
 * Returns the `_tfGlobals` Map stored on the global object, creating it on
 * first use.
 */
function getGlobalMap() {
  const ns = getGlobalNamespace();
  return ns._tfGlobals != null ? ns._tfGlobals : ns._tfGlobals = new Map();
}
/**
 * Returns the process-wide singleton stored under `key`.
 *
 * `init` is called only when the key is absent; its result is stored and
 * then returned.
 */
function getGlobal(key, init) {
  const registry = getGlobalMap();
  if (!registry.has(key)) {
    registry.set(key, init());
  }
  return registry.get(key);
}
// Kernel (op) name constants. Every binding holds a string equal to its own
// identifier. NOTE(review): these appear to be the `kernelName` values used
// with the kernel/gradient registries (makeKey / registerKernel /
// registerGradient); their actual uses are outside this chunk, so confirm
// before relying on that. Kept as plain `var` declarations so the bindings
// stay visible to the rest of this bundled scope.
var Abs = "Abs";
var Acos = "Acos";
var Acosh = "Acosh";
var Add = "Add";
var AddN = "AddN";
var All = "All";
var Any = "Any";
var ArgMax = "ArgMax";
var ArgMin = "ArgMin";
var Asin = "Asin";
var Asinh = "Asinh";
var Atan = "Atan";
var Atanh = "Atanh";
var Atan2 = "Atan2";
var AvgPool = "AvgPool";
var AvgPoolBackprop = "AvgPoolBackprop";
var AvgPool3D = "AvgPool3D";
var AvgPool3DBackprop = "AvgPool3DBackprop";
var BatchMatMul = "BatchMatMul";
var BatchToSpaceND = "BatchToSpaceND";
var BroadcastTo = "BroadcastTo";
var Cast = "Cast";
var Ceil = "Ceil";
var ClipByValue = "ClipByValue";
var Complex = "Complex";
var Concat = "Concat";
var Conv2D = "Conv2D";
var Conv2DBackpropFilter = "Conv2DBackpropFilter";
var Conv2DBackpropInput = "Conv2DBackpropInput";
var Conv3D = "Conv3D";
var Conv3DBackpropFilterV2 = "Conv3DBackpropFilterV2";
var Conv3DBackpropInputV2 = "Conv3DBackpropInputV2";
var Cos = "Cos";
var Cosh = "Cosh";
var Cumsum = "Cumsum";
var CropAndResize = "CropAndResize";
var DepthToSpace = "DepthToSpace";
var DepthwiseConv2dNative = "DepthwiseConv2dNative";
var DepthwiseConv2dNativeBackpropFilter = "DepthwiseConv2dNativeBackpropFilter";
var DepthwiseConv2dNativeBackpropInput = "DepthwiseConv2dNativeBackpropInput";
var Diag = "Diag";
var Dilation2D = "Dilation2D";
var Dilation2DBackpropInput = "Dilation2DBackpropInput";
var Dilation2DBackpropFilter = "Dilation2DBackpropFilter";
var Div = "Div";
var Elu = "Elu";
var EluGrad = "EluGrad";
var Erf = "Erf";
var Equal = "Equal";
var Exp = "Exp";
var Expm1 = "Expm1";
var FFT = "FFT";
var Fill = "Fill";
var FlipLeftRight = "FlipLeftRight";
var Floor = "Floor";
var FloorDiv = "FloorDiv";
var FusedBatchNorm = "FusedBatchNorm";
var GatherV2 = "GatherV2";
var GatherNd = "GatherNd";
var Greater = "Greater";
var GreaterEqual = "GreaterEqual";
var Identity = "Identity";
var IFFT = "IFFT";
var Imag = "Imag";
var IsFinite = "IsFinite";
var IsInf = "IsInf";
var IsNan = "IsNan";
var Less = "Less";
var LessEqual = "LessEqual";
var LinSpace = "LinSpace";
var Log = "Log";
var Log1p = "Log1p";
var LogicalAnd = "LogicalAnd";
var LogicalNot = "LogicalNot";
var LogicalOr = "LogicalOr";
var LogSoftmax = "LogSoftmax";
var LRN = "LRN";
var LRNBackprop = "LRNBackprop";
var Max = "Max";
var Maximum = "Maximum";
var MaxPool = "MaxPool";
var MaxPoolBackprop = "MaxPoolBackprop";
var MaxPool3D = "MaxPool3D";
var MaxPool3DBackprop = "MaxPool3DBackprop";
var MaxPoolWithArgmax = "MaxPoolWithArgmax";
var Mean = "Mean";
var Min = "Min";
var Minimum = "Minimum";
var MirrorPad = "MirrorPad";
var Mod = "Mod";
var Multiply = "Multiply";
var Negate = "Negate";
var NotEqual = "NotEqual";
var NonMaxSuppressionV3 = "NonMaxSuppressionV3";
var NonMaxSuppressionV4 = "NonMaxSuppressionV4";
var NonMaxSuppressionV5 = "NonMaxSuppressionV5";
var OnesLike = "OnesLike";
var OneHot = "OneHot";
var PadV2 = "PadV2";
var Pool = "Pool";
var Pow = "Pow";
var Prelu = "Prelu";
var Prod = "Prod";
var Range = "Range";
var Real = "Real";
var Reciprocal = "Reciprocal";
var Relu = "Relu";
var Reshape = "Reshape";
var ResizeNearestNeighbor = "ResizeNearestNeighbor";
var ResizeNearestNeighborGrad = "ResizeNearestNeighborGrad";
var ResizeBilinear = "ResizeBilinear";
var ResizeBilinearGrad = "ResizeBilinearGrad";
var Relu6 = "Relu6";
var Reverse = "Reverse";
var Round = "Round";
var Rsqrt = "Rsqrt";
var ScatterNd = "ScatterNd";
var SelectV2 = "SelectV2";
var Selu = "Selu";
var Slice = "Slice";
var Sin = "Sin";
var Sinh = "Sinh";
var Sign = "Sign";
var Sigmoid = "Sigmoid";
var Softplus = "Softplus";
var Sqrt = "Sqrt";
var Sum = "Sum";
var SpaceToBatchND = "SpaceToBatchND";
var SplitV = "SplitV";
var Softmax = "Softmax";
var SquaredDifference = "SquaredDifference";
var Square = "Square";
var Sub = "Sub";
var SparseToDense = "SparseToDense";
var StridedSlice = "StridedSlice";
var Tan = "Tan";
var Tanh = "Tanh";
var Tile = "Tile";
var TopK = "TopK";
var Transpose = "Transpose";
var Unique = "Unique";
var Unpack = "Unpack";
var UnsortedSegmentSum = "UnsortedSegmentSum";
var ZerosLike = "ZerosLike";
var Step = "Step";
var FromPixels = "FromPixels";
var RotateWithOffset = "RotateWithOffset";
var _FusedMatMul = "_FusedMatMul";
var FusedConv2D = "FusedConv2D";
var FusedDepthwiseConv2D = "FusedDepthwiseConv2D";
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
// Process-wide registries shared through getGlobal, so every copy of this
// module sees the same maps.
// kernelRegistry: "<backend>_<kernel>" -> kernel config.
var kernelRegistry = getGlobal("kernelRegistry", () => new Map());
// gradRegistry: kernel name -> gradient config.
var gradRegistry = getGlobal("gradRegistry", () => new Map());
/**
 * Looks up the kernel config registered for (`kernelName`, `backendName`);
 * returns undefined when none exists.
 */
function getKernel(kernelName, backendName) {
  return kernelRegistry.get(makeKey(kernelName, backendName));
}
/**
 * Looks up the gradient config registered for `kernelName`; returns
 * undefined when none exists.
 */
function getGradient(kernelName) {
  const gradConfig = gradRegistry.get(kernelName);
  return gradConfig;
}
/**
 * Returns the configs of every kernel registered for `backendName`, in
 * registry (Map insertion) order.
 *
 * Matching uses each config's own `backendName` instead of splitting the
 * registry key on "_". Keys are built as `backendName + "_" + kernelName`,
 * so splitting truncates a backend name that contains "_" (e.g.
 * "my_backend"). That made such a backend miss its own kernels and leak
 * them to a backend named by the prefix ("my").
 */
function getKernelsForBackend(backendName) {
  const result = [];
  kernelRegistry.forEach((config) => {
    if (config.backendName === backendName) {
      result.push(config);
    }
  });
  return result;
}
/**
 * Registers (or replaces, with a console warning) the kernel config for its
 * (`kernelName`, `backendName`) pair.
 */
function registerKernel(config) {
  const { kernelName, backendName } = config;
  const key = makeKey(kernelName, backendName);
  if (kernelRegistry.has(key)) {
    console.warn("The kernel '" + kernelName + "' for backend " + ("'" + backendName + "' is already registered"));
  }
  kernelRegistry.set(key, config);
}
/**
 * Registers the gradient config for `config.kernelName`, replacing any
 * previous one. The override is reported only when the DEBUG flag is on.
 */
function registerGradient(config) {
  const { kernelName } = config;
  if (gradRegistry.has(kernelName) && env().getBool("DEBUG")) {
    console.warn("Overriding the gradient for '" + kernelName + "'");
  }
  gradRegistry.set(kernelName, config);
}
/**
 * Removes the kernel registered for (`kernelName`, `backendName`).
 * @throws Error if no such kernel is registered.
 */
function unregisterKernel(kernelName, backendName) {
  const key = makeKey(kernelName, backendName);
  if (!kernelRegistry.has(key)) {
    throw new Error("The kernel '" + kernelName + "' for backend " + ("'" + backendName + "' is not registered"));
  }
  kernelRegistry.delete(key);
}
/**
 * Removes the gradient registered for `kernelName`.
 *
 * Gradients are keyed by kernel name only. The error message therefore no
 * longer says "for backend", which previously appeared with no backend
 * name after it.
 * @throws Error if no gradient is registered for `kernelName`.
 */
function unregisterGradient(kernelName) {
  if (!gradRegistry.has(kernelName)) {
    throw new Error("The gradient '" + kernelName + "' is not registered");
  }
  gradRegistry.delete(kernelName);
}
/**
 * Re-registers every kernel of `registeredBackendName` under
 * `newBackendName`. Each registered copy is a shallow clone with only
 * `backendName` changed.
 */
function copyRegisteredKernels(registeredBackendName, newBackendName) {
  for (const kernelConfig of getKernelsForBackend(registeredBackendName)) {
    registerKernel(Object.assign({}, kernelConfig, {backendName: newBackendName}));
  }
}
/**
 * Builds the kernel-registry key "<backendName>_<kernelName>".
 */
function makeKey(kernelName, backendName) {
  const separator = "_";
  return backendName + separator + kernelName;
}
/**
* @license
* Copyright 2017 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Wraps a single value in backing storage for `dtype`: encoded bytes for
 * "string", otherwise a one-element typed array from toTypedArray.
 */
function createScalarValue(value, dtype) {
  return dtype === "string" ? encodeString(value) : toTypedArray([value], dtype);
}
/**
 * True when `a` already has the typed-array class used for `dtype`:
 * Float32Array for float32, Int32Array for int32, Uint8Array for bool.
 */
function noConversionNeeded(a, dtype) {
  switch (dtype) {
    case "float32":
      return a instanceof Float32Array;
    case "int32":
      return a instanceof Int32Array;
    case "bool":
      return a instanceof Uint8Array;
    default:
      return false;
  }
}
/**
 * Converts (possibly nested) numeric data to the typed array used for
 * `dtype`.
 *
 * Nested arrays are flattened first. Input that already has the right
 * typed-array class is returned as-is, not copied. For "bool", every value
 * whose rounded value is non-zero becomes 1. With the DEBUG flag on, values
 * are validated by checkConversionForErrors.
 * @throws Error for dtype "string" or an unknown dtype.
 */
function toTypedArray(a, dtype) {
  if (dtype === "string") {
    throw new Error("Cannot convert a string[] to a TypedArray");
  }
  const values = Array.isArray(a) ? flatten(a) : a;
  if (env().getBool("DEBUG")) {
    checkConversionForErrors(values, dtype);
  }
  if (noConversionNeeded(values, dtype)) {
    return values;
  }
  if (dtype == null || dtype === "float32" || dtype === "complex64") {
    return new Float32Array(values);
  }
  if (dtype === "int32") {
    return new Int32Array(values);
  }
  if (dtype === "bool") {
    const bools = new Uint8Array(values.length);
    for (let k = 0; k < bools.length; ++k) {
      if (Math.round(values[k]) !== 0) {
        bools[k] = 1;
      }
    }
    return bools;
  }
  throw new Error("Unknown data type " + dtype);
}
/** Current time as reported by the active platform's `now()`. */
function now() {
  const { platform } = env();
  return platform.now();
}
/** Delegates to the active platform's `fetch` implementation. */
function fetch$1(path, requestInits) {
  const { platform } = env();
  return platform.fetch(path, requestInits);
}
/**
 * Encodes `s` to bytes via the active platform. A missing or falsy
 * `encoding` falls back to "utf-8".
 */
function encodeString(s, encoding) {
  const resolvedEncoding = encoding || "utf-8";
  return env().platform.encode(s, resolvedEncoding);
}
/**
 * Decodes `bytes` to a string via the active platform. A missing or falsy
 * `encoding` falls back to "utf-8".
 */
function decodeString(bytes, encoding) {
  const resolvedEncoding = encoding || "utf-8";
  return env().platform.decode(bytes, resolvedEncoding);
}
// Null-prototype namespace object exposing the general-purpose helpers
// under their own names (`fetch` maps to the local `fetch$1`). Keys are
// added in the same order as before, so enumeration order is unchanged.
var util = Object.assign(Object.create(null), {
  createScalarValue,
  toTypedArray,
  now,
  fetch: fetch$1,
  encodeString,
  decodeString,
  shuffle,
  clamp,
  nearestLargerEven,
  sum,
  randUniform,
  distSquared,
  assert,
  assertShapesMatch,
  assertNonNull,
  flatten,
  sizeFromShape,
  isScalarShape,
  arraysEqual,
  isInt,
  tanh,
  sizeToSquarishShape,
  createShuffledIndices,
  rightPad,
  repeatedTry,
  inferFromImplicitShape,
  parseAxisParam,
  squeezeShape,
  getTypedArrayFromDType,
  getArrayFromDType,
  checkConversionForErrors,
  isValidDtype,
  hasEncodingLoss,
  isTypedArray,
  bytesPerElement,
  bytesFromStringArray,
  isString,
  isBoolean,
  isNumber,
  inferDtype,
  isFunction,
  nearestDivisor,
  computeStrides,
  toNestedArray,
  makeOnesTypedArray,
  makeZerosTypedArray,
  makeZerosNestedTypedArray,
  assertNonNegativeIntegerDimensions,
  locToIndex,
  indexToLoc,
  isPromise
});
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var Profiler = function() {
  /**
   * @param backendTimer Object whose `time(fn)` runs `fn` and returns a
   *     Promise of timing info (`kernelMs`, optional `getExtraProfileInfo`).
   * @param logger Sink for logKernelProfile; defaults to a new Logger.
   */
  function Profiler2(backendTimer, logger) {
    this.backendTimer = backendTimer;
    this.logger = logger == null ? new Logger() : logger;
  }
  /**
   * Runs `f` under the backend timer and returns a profile record.
   *
   * For every output tensor, its data is read asynchronously and checked for
   * NaN/Infinity (float32 only) by checkComputationForErrors. The record's
   * `timeMs` and `extraInfo` fields are Promises derived from the timer.
   */
  Profiler2.prototype.profileKernel = function(kernelName, inputs, f) {
    let outputs;
    const timer = this.backendTimer.time(() => {
      outputs = f();
    });
    for (let i = 0; i < outputs.length; i++) {
      const output = outputs[i];
      output.data().then((tensorVals) => {
        checkComputationForErrors(tensorVals, output.dtype, kernelName);
      });
    }
    return {
      kernelName,
      outputs,
      inputs,
      timeMs: timer.then((timing) => timing.kernelMs),
      extraInfo: timer.then((timing) => timing.getExtraProfileInfo != null ? timing.getExtraProfileInfo() : "")
    };
  };
  /**
   * Once each output's data and the timing promises resolve, forwards them
   * to the logger, one call per output tensor.
   */
  Profiler2.prototype.logKernelProfile = function(kernelProfile) {
    const { kernelName, outputs, timeMs, inputs, extraInfo } = kernelProfile;
    outputs.forEach((result) => {
      Promise.all([result.data(), timeMs, extraInfo]).then(([vals, resolvedTimeMs, resolvedExtraInfo]) => {
        this.logger.logKernelProfile(kernelName, result, vals, resolvedTimeMs, inputs, resolvedExtraInfo);
      });
    });
  };
  return Profiler2;
}();
/**
 * Scans float32 results for the first NaN or +/-Infinity and warns about it.
 *
 * @returns true when a non-finite value was found (after a console.warn).
 *     Returns false otherwise, and always false for non-float32 dtypes.
 */
function checkComputationForErrors(vals, dtype, kernelName) {
  if (dtype !== "float32") {
    return false;
  }
  for (let idx = 0; idx < vals.length; idx++) {
    const value = vals[idx];
    // isFinite is false for NaN as well as for +/-Infinity.
    if (!isFinite(value)) {
      console.warn("Found " + value + " in the result of '" + kernelName + "'");
      return true;
    }
  }
  return false;
}
var Logger = function() {
  function Logger2() {
  }
  /**
   * Prints one %c-styled console line for a kernel execution. The line holds
   * the kernel name, elapsed time, output rank/shape, output size, input
   * shapes and extra info.
   *
   * @param timeMs Milliseconds as a number, or an object whose `error`
   *     field is printed instead.
   * @param inputs Map of input name to tensor. Null entries are skipped, and
   *     inputs without a shape fall back to the result's shape.
   */
  Logger2.prototype.logKernelProfile = function(name, result, vals, timeMs, inputs, extraInfo) {
    const timeText = typeof timeMs === "number" ? rightPad(timeMs + "ms", 9) : timeMs["error"];
    const paddedName = rightPad(name, 25);
    const { rank, size } = result;
    const shapeText = rightPad(result.shape.toString(), 14);
    let inputShapesDescription = "";
    for (const inputName in inputs) {
      const input = inputs[inputName];
      if (input == null) {
        continue;
      }
      const inputShape = input.shape || result.shape;
      const inputRank = inputShape.length;
      inputShapesDescription += inputName + ": " + inputRank + "D " + (inputRank > 0 ? inputShape : "") + " ";
    }
    console.log("%c" + paddedName + " %c" + timeText + " %c" + rank + "D " + shapeText + " %c" + size + " %c" + inputShapesDescription + " %c" + extraInfo, "font-weight:bold", "color:red", "color:blue", "color: orange", "color: green", "color: steelblue");
  };
  return Logger2;
}();
/**
* @license
* Copyright 2017 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Returns the tape nodes that lie on some path from a tensor in `xs` to `y`.
 *
 * 1. A forward pass marks every tensor and node reachable from `xs`.
 * 2. A backward pass marks every tensor and node from which `y` is reachable.
 * 3. Nodes marked by both passes are kept.
 *
 * @param tape Recorded nodes, each with `id`, `inputs` (name -> tensor) and
 *     `outputs` (tensor array).
 * @param xs Tensors to differentiate with respect to.
 * @param y Target tensor.
 * @returns Shallow copies of the kept nodes. Their `inputs` keep only
 *     entries reachable from `xs`; `outputs` is the original node's array.
 */
function getFilteredNodesXToY(tape, xs, y) {
  const tensorsFromX = {};
  const nodesFromX = {};
  for (let i = 0; i < xs.length; i++) {
    tensorsFromX[xs[i].id] = true;
  }
  // Forward pass. One input reachable from xs makes all of the node's
  // outputs reachable. (A redundant inner loop over `xs` that re-tested the
  // same j-independent condition has been removed.)
  for (let i = 0; i < tape.length; i++) {
    const node = tape[i];
    const nodeInputs = node.inputs;
    for (const inputName in nodeInputs) {
      const input = nodeInputs[inputName];
      if (tensorsFromX[input.id]) {
        node.outputs.forEach((output) => {
          tensorsFromX[output.id] = true;
        });
        nodesFromX[node.id] = true;
        break;
      }
    }
  }
  // Backward pass. If any output leads to y, all of the node's inputs do.
  const tensorsLeadToY = {};
  tensorsLeadToY[y.id] = true;
  const nodesToY = {};
  for (let i = tape.length - 1; i >= 0; i--) {
    const node = tape[i];
    const nodeInputs = node.inputs;
    for (let j = 0; j < node.outputs.length; j++) {
      if (tensorsLeadToY[node.outputs[j].id]) {
        for (const inputName in nodeInputs) {
          tensorsLeadToY[nodeInputs[inputName].id] = true;
          nodesToY[node.id] = true;
        }
        break;
      }
    }
  }
  // Keep nodes on both sides and drop inputs not reachable from xs.
  const filteredTape = [];
  for (let i = 0; i < tape.length; i++) {
    const node = tape[i];
    if (nodesFromX[node.id] && nodesToY[node.id]) {
      const prunedInputs = {};
      for (const inputName in node.inputs) {
        const nodeInput = node.inputs[inputName];
        if (tensorsFromX[nodeInput.id]) {
          prunedInputs[inputName] = nodeInput;
        }
      }
      const prunedNode = Object.assign({}, node);
      prunedNode.inputs = prunedInputs;
      prunedNode.outputs = node.outputs;
      filteredTape.push(prunedNode);
    }
  }
  return filteredTape;
}
/**
 * Reverse-mode accumulation over `filteredTape`, processed last node first.
 *
 * For each node, the accumulated gradients of its outputs (null where none
 * exists yet) are passed to `node.gradient`. That call must return an
 * object mapping each input name to a thunk producing the input's gradient.
 * Each thunk runs inside `tidy2`, and the result must be float32 and match
 * the input's shape. Gradients reaching the same tensor are combined with
 * `add2`, and the superseded partial sum is disposed.
 *
 * @param tensorAccumulatedGradientMap Tensor id -> accumulated gradient;
 *     mutated in place.
 * @throws Error when a node has no gradient function, omits an input's
 *     gradient, or yields a gradient of the wrong dtype or shape.
 */
function backpropagateGradients(tensorAccumulatedGradientMap, filteredTape, tidy2, add2) {
  const gradMap = tensorAccumulatedGradientMap;
  for (let nodeIndex = filteredTape.length - 1; nodeIndex >= 0; nodeIndex--) {
    const node = filteredTape[nodeIndex];
    const dys = [];
    node.outputs.forEach((output) => {
      const accumulated = gradMap[output.id];
      dys.push(accumulated != null ? accumulated : null);
    });
    if (node.gradient == null) {
      throw new Error("Cannot compute gradient: gradient function not found " + ("for " + node.kernelName + "."));
    }
    const inputGradients = node.gradient(dys);
    for (const inputName in node.inputs) {
      if (!(inputName in inputGradients)) {
        throw new Error("Cannot backprop through input " + inputName + ". " + ("Available gradients found: " + Object.keys(inputGradients) + "."));
      }
      const dx = tidy2(() => inputGradients[inputName]());
      if (dx.dtype !== "float32") {
        throw new Error("Error in gradient for op " + node.kernelName + ". The gradient of input " + (inputName + " must have 'float32' dtype, but has '" + dx.dtype + "'"));
      }
      const x = node.inputs[inputName];
      if (!arraysEqual(dx.shape, x.shape)) {
        throw new Error("Error in gradient for op " + node.kernelName + ". The gradient of input " + ("'" + inputName + "' has shape '" + dx.shape + "', which does not match ") + ("the shape of the input '" + x.shape + "'"));
      }
      const previous = gradMap[x.id];
      if (previous == null) {
        gradMap[x.id] = dx;
      } else {
        gradMap[x.id] = add2(previous, dx);
        previous.dispose();
      }
    }
  }
}
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
// Dimensions longer than this are abbreviated with "...".
var FORMAT_LIMIT_NUM_VALS = 20;
// How many leading and trailing entries an abbreviated dimension keeps.
var FORMAT_NUM_FIRST_LAST_VALS = 3;
// Digits passed to toFixed when formatting numeric values.
var FORMAT_NUM_SIG_DIGITS = 7;
/**
 * Renders tensor values as text. The output is a "Tensor" header line, then
 * (when `verbose`) dtype/rank/shape lines, then the formatted values.
 */
function tensorToString(vals, shape, dtype, verbose) {
  const strides = computeStrides(shape);
  const padPerCol = computeMaxSizePerColumn(vals, shape, dtype, strides);
  const valsLines = subTensorToString(vals, shape, dtype, strides, padPerCol);
  const header = verbose ? [" dtype: " + dtype, " rank: " + shape.length, " shape: [" + shape + "]", " values:"] : [];
  const body = valsLines.map((line) => " " + line).join("\n");
  return ["Tensor"].concat(header, [body]).join("\n");
}
/**
 * For rank >= 2, computes the widest unpadded string of each innermost
 * column so that rows align. For rank <= 1 every entry stays 0.
 *
 * complex64 values are paired into (real, imag) tuples before measuring.
 */
function computeMaxSizePerColumn(vals, shape, dtype, strides) {
  const total = sizeFromShape(shape);
  const numCols = strides[strides.length - 1];
  const padPerCol = new Array(numCols).fill(0);
  const entries = dtype === "complex64" ? createComplexTuples(vals) : vals;
  if (shape.length <= 1) {
    return padPerCol;
  }
  for (let row = 0; row < total / numCols; row++) {
    const rowStart = row * numCols;
    for (let col = 0; col < numCols; col++) {
      const width = valToString(entries[rowStart + col], 0, dtype).length;
      padPerCol[col] = Math.max(padPerCol[col], width);
    }
  }
  return padPerCol;
}
/**
 * Formats one value and right-pads it to `pad2` characters.
 * - complex tuple [re, im]: "re + imj"
 * - string: wrapped in single quotes
 * - bool dtype: "true"/"false"
 * - number: rounded via toFixed(FORMAT_NUM_SIG_DIGITS), trailing zeros
 *   dropped by parseFloat
 */
function valToString(val, pad2, dtype) {
  const formatNumber = (n) => parseFloat(n.toFixed(FORMAT_NUM_SIG_DIGITS));
  let text;
  if (Array.isArray(val)) {
    text = formatNumber(val[0]) + " + " + (formatNumber(val[1]) + "j");
  } else if (isString(val)) {
    text = "'" + val + "'";
  } else if (dtype === "bool") {
    text = boolNumToString(val);
  } else {
    text = formatNumber(val).toString();
  }
  return rightPad(text, pad2);
}
/** Maps 0 to "false" and any other value to "true". */
function boolNumToString(v) {
  if (v !== 0) {
    return "true";
  }
  return "false";
}
/**
 * Recursively renders the flat `vals` of a (sub)tensor of shape `shape` as
 * an array of text lines.
 *
 * @param vals Flat values. For complex64 they are consecutive (real, imag)
 *     pairs, hence two storage slots per element.
 * @param shape Shape of this (sub)tensor.
 * @param strides computeStrides(shape), in elements (not storage slots).
 * @param padPerCol Per-innermost-column widths used to right-pad values.
 * @param isLast When false, a ",\n" separator (plus one extra "\n" per rank
 *     above 2) is appended after the closing bracket. Defaults to true.
 * @returns Lines without the outer indentation added by tensorToString.
 */
function subTensorToString(vals, shape, dtype, strides, padPerCol, isLast) {
if (isLast === void 0) {
isLast = true;
}
// complex64 stores two numbers (real, imag) per element.
var storagePerElement = dtype === "complex64" ? 2 : 1;
var size = shape[0];
var rank = shape.length;
// Scalar: a single unpadded line.
if (rank === 0) {
if (dtype === "complex64") {
var complexTuple = createComplexTuples(vals);
return [valToString(complexTuple[0], 0, dtype)];
}
if (dtype === "bool") {
return [boolNumToString(vals[0])];
}
return [vals[0].toString()];
}
// Vector: one bracketed line, abbreviated to the first/last
// FORMAT_NUM_FIRST_LAST_VALS entries when longer than FORMAT_LIMIT_NUM_VALS.
if (rank === 1) {
if (size > FORMAT_LIMIT_NUM_VALS) {
var firstValsSize = FORMAT_NUM_FIRST_LAST_VALS * storagePerElement;
var firstVals = Array.from(vals.slice(0, firstValsSize));
var lastVals = Array.from(vals.slice((size - FORMAT_NUM_FIRST_LAST_VALS) * storagePerElement, size * storagePerElement));
if (dtype === "complex64") {
firstVals = createComplexTuples(firstVals);
lastVals = createComplexTuples(lastVals);
}
// Trailing entries are padded with the widths of their real column indices.
return [
"[" + firstVals.map(function(x, i2) {
return valToString(x, padPerCol[i2], dtype);
}).join(", ") + ", ..., " + lastVals.map(function(x, i2) {
return valToString(x, padPerCol[size - FORMAT_NUM_FIRST_LAST_VALS + i2], dtype);
}).join(", ") + "]"
];
}
var displayVals = dtype === "complex64" ? createComplexTuples(vals) : Array.from(vals);
return [
"[" + displayVals.map(function(x, i2) {
return valToString(x, padPerCol[i2], dtype);
}).join(", ") + "]"
];
}
// Rank >= 2: recurse into each slice along the outer axis. Each slice
// spans `stride` storage slots of `vals`.
var subshape = shape.slice(1);
var substrides = strides.slice(1);
var stride = strides[0] * storagePerElement;
var lines = [];
if (size > FORMAT_LIMIT_NUM_VALS) {
// Abbreviate the outer axis: first slices, a "..." line, last slices.
for (var i = 0; i < FORMAT_NUM_FIRST_LAST_VALS; i++) {
var start = i * stride;
var end = start + stride;
lines.push.apply(lines, subTensorToString(vals.slice(start, end), subshape, dtype, substrides, padPerCol, false));
}
lines.push("...");
for (var i = size - FORMAT_NUM_FIRST_LAST_VALS; i < size; i++) {
var start = i * stride;
var end = start + stride;
lines.push.apply(lines, subTensorToString(vals.slice(start, end), subshape, dtype, substrides, padPerCol, i === size - 1));
}
} else {
for (var i = 0; i < size; i++) {
var start = i * stride;
var end = start + stride;
lines.push.apply(lines, subTensorToString(vals.slice(start, end), subshape, dtype, substrides, padPerCol, i === size - 1));
}
}
// Wrap the collected lines in brackets. For rank 2 the inner rows are
// comma-separated; higher ranks get separators from the isLast=false
// newline suffix of each sub-block instead.
var sep = rank === 2 ? "," : "";
lines[0] = "[" + lines[0] + sep;
for (var i = 1; i < lines.length - 1; i++) {
lines[i] = " " + lines[i] + sep;
}
var newLineSep = ",\n";
for (var i = 2; i < rank; i++) {
newLineSep += "\n";
}
lines[lines.length - 1] = " " + lines[lines.length - 1] + "]" + (isLast ? "" : newLineSep);
return lines;
}
/**
 * Groups a flat sequence into consecutive [vals[2k], vals[2k + 1]] pairs,
 * i.e. (real, imag) tuples for complex64 data. An odd trailing element is
 * paired with undefined.
 */
function createComplexTuples(vals) {
  const pairCount = Math.ceil(vals.length / 2);
  return Array.from({ length: pairCount }, (_, k) => [vals[2 * k], vals[2 * k + 1]]);
}
/**
* @license
* Copyright 2017 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var TensorBuffer = function() {
  /**
   * Mutable, host-side container of tensor values with coordinate
   * accessors.
   *
   * @param shape Dimensions (copied).
   * @param dtype Element dtype; "complex64" is rejected.
   * @param values Optional storage, used as-is. Its length must equal the
   *     product of `shape`. When omitted, storage comes from
   *     getArrayFromDType.
   * @throws Error for complex64, or via `assert` on a length mismatch.
   */
  function TensorBuffer2(shape, dtype, values) {
    this.dtype = dtype;
    this.shape = shape.slice();
    this.size = sizeFromShape(shape);
    if (values != null) {
      const providedLength = values.length;
      assert(providedLength === this.size, () => "Length of values '" + providedLength + "' does not match the size " + ("inferred by the shape '" + this.size + "'."));
    }
    if (dtype === "complex64") {
      throw new Error("complex64 dtype TensorBuffers are not supported. Please create a TensorBuffer for the real and imaginary parts separately and call tf.complex(real, imag).");
    }
    this.values = values || getArrayFromDType(dtype, this.size);
    this.strides = computeStrides(shape);
  }
  /**
   * Writes `value` at the given coordinates. No coordinates means [0].
   * The number of coordinates must equal the rank (checked via `assert`).
   */
  TensorBuffer2.prototype.set = function(value, ...locs) {
    const coords = locs.length === 0 ? [0] : locs;
    assert(coords.length === this.rank, () => "The number of provided coordinates (" + coords.length + ") must " + ("match the rank (" + this.rank + ")"));
    this.values[this.locToIndex(coords)] = value;
  };
  /**
   * Reads the value at the given coordinates. No coordinates means [0].
   * @throws Error when a coordinate is negative or not below its dimension.
   */
  TensorBuffer2.prototype.get = function(...locs) {
    const coords = locs.length === 0 ? [0] : locs;
    coords.forEach((loc, axis) => {
      if (loc < 0 || loc >= this.shape[axis]) {
        const msg = "Requested out of range element at " + coords + ". " + (" Buffer shape=" + this.shape);
        throw new Error(msg);
      }
    });
    let flatIndex = coords[coords.length - 1];
    for (let axis = 0; axis < coords.length - 1; ++axis) {
      flatIndex += this.strides[axis] * coords[axis];
    }
    return this.values[flatIndex];
  };
  /** Row-major flat offset of `locs` (0 for rank 0, `locs[0]` for rank 1). */
  TensorBuffer2.prototype.locToIndex = function(locs) {
    if (this.rank === 0) {
      return 0;
    }
    if (this.rank === 1) {
      return locs[0];
    }
    const lastAxis = locs.length - 1;
    let flat = locs[lastAxis];
    for (let axis = 0; axis < lastAxis; ++axis) {
      flat += this.strides[axis] * locs[axis];
    }
    return flat;
  };
  /** Inverse of locToIndex: flat row-major offset to coordinates. */
  TensorBuffer2.prototype.indexToLoc = function(index) {
    if (this.rank === 0) {
      return [];
    }
    if (this.rank === 1) {
      return [index];
    }
    const coords = new Array(this.shape.length);
    const lastAxis = coords.length - 1;
    let remainder = index;
    for (let axis = 0; axis < lastAxis; ++axis) {
      coords[axis] = Math.floor(remainder / this.strides[axis]);
      remainder -= coords[axis] * this.strides[axis];
    }
    coords[lastAxis] = remainder;
    return coords;
  };
  // `rank` is derived from the (copied) shape.
  Object.defineProperty(TensorBuffer2.prototype, "rank", {
    get: function() {
      return this.shape.length;
    },
    enumerable: true,
    configurable: true
  });
  /** Materializes the buffer through the installed tensor tracker. */
  TensorBuffer2.prototype.toTensor = function() {
    return trackerFn().makeTensor(this.values, this.shape, this.dtype);
  };
  return TensorBuffer2;
}();
// Late-bound hooks: the tensor tracker (engine accessor) and the op handler
// are installed at runtime so Tensor methods can call into them.
var opHandler = null;
var trackerFn = null;
/** Installs the function that returns the tensor tracker used by Tensor methods. */
function setTensorTracker(tracker) {
  trackerFn = tracker;
}
/** Installs the object that implements op-style Tensor helpers (buffer, print, clone, cast). */
function setOpHandler(ops) {
  opHandler = ops;
}
// Tensor: a handle holding metadata (shape, dtype, size, strides, dataId, id).
// Values are read through the tracker returned by trackerFn() (installed via
// setTensorTracker). Op-style helpers are delegated to opHandler (installed
// via setOpHandler).
var Tensor = function() {
/**
 * @param shape Dimensions; a copy is stored.
 * @param dtype Element dtype; falls back to "float32" when falsy.
 * @param dataId Key identifying the backing data in the tracker/backend.
 * @param id Tensor id assigned by the caller (the engine passes nextTensorId()).
 */
function Tensor2(shape, dtype, dataId, id) {
this.kept = false;
this.isDisposedInternal = false;
this.shape = shape.slice();
this.dtype = dtype || "float32";
this.size = sizeFromShape(shape);
this.strides = computeStrides(shape);
this.dataId = dataId;
this.id = id;
// "0".."4" for ranks below 5, otherwise "higher".
this.rankType = this.rank < 5 ? this.rank.toString() : "higher";
}
// Rank is derived from the stored shape.
Object.defineProperty(Tensor2.prototype, "rank", {
get: function() {
return this.shape.length;
},
enumerable: true,
configurable: true
});
// Async: reads the values via data(), then builds a buffer through opHandler.buffer.
Tensor2.prototype.buffer = function() {
return __awaiter(this, void 0, void 0, function() {
var vals;
return __generator(this, function(_a) {
switch (_a.label) {
case 0:
return [4, this.data()];
case 1:
vals = _a.sent();
return [2, opHandler.buffer(this.shape, this.dtype, vals)];
}
});
});
};
// Sync variant of buffer(), backed by dataSync().
Tensor2.prototype.bufferSync = function() {
return opHandler.buffer(this.shape, this.dtype, this.dataSync());
};
// Async: reads the values and nests them according to the shape (toNestedArray).
Tensor2.prototype.array = function() {
return __awaiter(this, void 0, void 0, function() {
var vals;
return __generator(this, function(_a) {
switch (_a.label) {
case 0:
return [4, this.data()];
case 1:
vals = _a.sent();
return [2, toNestedArray(this.shape, vals)];
}
});
});
};
// Sync variant of array(), backed by dataSync().
Tensor2.prototype.arraySync = function() {
return toNestedArray(this.shape, this.dataSync());
};
/**
 * Asynchronously reads the values via trackerFn().read(dataId).
 * For "string" tensors each stored element is decoded with decodeString.
 * @throws Error if the tensor is disposed, or if string decoding fails
 *   (the message points callers to bytes()).
 */
Tensor2.prototype.data = function() {
return __awaiter(this, void 0, void 0, function() {
var data, bytes;
return __generator(this, function(_a) {
switch (_a.label) {
case 0:
this.throwIfDisposed();
data = trackerFn().read(this.dataId);
// Non-string dtypes skip decoding and return the read result directly (label 2).
if (!(this.dtype === "string"))
return [3, 2];
return [4, data];
case 1:
bytes = _a.sent();
try {
return [2, bytes.map(function(b) {
return decodeString(b);
})];
} catch (_b) {
throw new Error("Failed to decode the string bytes into utf-8. To get the original bytes, call tensor.bytes().");
}
// Unreachable: both paths of the try above return or throw (transpiler artifact).
_a.label = 2;
case 2:
return [2, data];
}
});
});
};
/**
 * Synchronous read via trackerFn().readSync(dataId); decodes "string" tensors
 * the same way as data().
 * @throws Error if disposed or if string decoding fails.
 */
Tensor2.prototype.dataSync = function() {
this.throwIfDisposed();
var data = trackerFn().readSync(this.dataId);
if (this.dtype === "string") {
try {
return data.map(function(b) {
return decodeString(b);
});
} catch (_a) {
throw new Error("Failed to decode the string bytes into utf-8. To get the original bytes, call tensor.bytes().");
}
}
return data;
};
/**
 * Async raw read without string decoding. For "string" tensors the stored
 * elements are returned as-is. Otherwise a Uint8Array view over the
 * underlying ArrayBuffer of the read values is returned.
 */
Tensor2.prototype.bytes = function() {
return __awaiter(this, void 0, void 0, function() {
var data;
return __generator(this, function(_a) {
switch (_a.label) {
case 0:
this.throwIfDisposed();
return [4, trackerFn().read(this.dataId)];
case 1:
data = _a.sent();
if (this.dtype === "string") {
return [2, data];
} else {
return [2, new Uint8Array(data.buffer)];
}
}
});
});
};
// Releases this tensor through the tracker. Repeated calls are no-ops.
Tensor2.prototype.dispose = function() {
if (this.isDisposed) {
return;
}
trackerFn().disposeTensor(this);
this.isDisposedInternal = true;
};
Object.defineProperty(Tensor2.prototype, "isDisposed", {
get: function() {
return this.isDisposedInternal;
},
enumerable: true,
configurable: true
});
// Guard used by read/op methods; throws once dispose() has run.
Tensor2.prototype.throwIfDisposed = function() {
if (this.isDisposed) {
throw new Error("Tensor is disposed.");
}
};
// Delegates printing to opHandler.print; verbose defaults to false.
Tensor2.prototype.print = function(verbose) {
if (verbose === void 0) {
verbose = false;
}
return opHandler.print(this, verbose);
};
Tensor2.prototype.clone = function() {
this.throwIfDisposed();
return opHandler.clone(this);
};
// Formats the synchronously read values with tensorToString.
Tensor2.prototype.toString = function(verbose) {
if (verbose === void 0) {
verbose = false;
}
var vals = this.dataSync();
return tensorToString(vals, this.shape, this.dtype, verbose);
};
Tensor2.prototype.cast = function(dtype) {
this.throwIfDisposed();
return opHandler.cast(this, dtype);
};
// Wraps this tensor in a Variable via the tracker; trainable defaults to true.
Tensor2.prototype.variable = function(trainable, name, dtype) {
if (trainable === void 0) {
trainable = true;
}
this.throwIfDisposed();
return trackerFn().makeVariable(this, trainable, name, dtype);
};
return Tensor2;
}();
// Duck-typed instanceof: any truthy object exposing data, dataSync and
// throwIfDisposed counts as a Tensor. Presumably this lets instances from
// another copy of the class pass the check (TODO confirm intent).
Object.defineProperty(Tensor, Symbol.hasInstance, {
value: function(instance) {
return !!instance && instance.data != null && instance.dataSync != null && instance.throwIfDisposed != null;
}
});
// Variable: a Tensor subclass whose data can be swapped via assign().
var Variable = function(_super) {
  __extends(Variable2, _super);
  /**
   * @param initialValue Tensor whose shape, dtype and dataId are adopted.
   * @param trainable Stored on the instance as-is.
   * @param name Variable name.
   * @param tensorId Id forwarded to the Tensor constructor.
   */
  function Variable2(initialValue, trainable, name, tensorId) {
    const self = _super.call(this, initialValue.shape, initialValue.dtype, initialValue.dataId, tensorId) || this;
    self.trainable = trainable;
    self.name = name;
    return self;
  }
  /**
   * Points this variable at newValue's data. The old data reference is
   * released through the tracker and the new one is ref-counted.
   * @throws Error when dtype or shape differ from the current value.
   */
  Variable2.prototype.assign = function(newValue) {
    const sameDtype = newValue.dtype === this.dtype;
    if (!sameDtype) {
      throw new Error("dtype of the new value (" + newValue.dtype + ") and previous value (" + this.dtype + ") must match");
    }
    const sameShape = arraysEqual(newValue.shape, this.shape);
    if (!sameShape) {
      throw new Error("shape of the new value (" + newValue.shape + ") and previous value (" + this.shape + ") must match");
    }
    trackerFn().disposeTensor(this);
    this.dataId = newValue.dataId;
    trackerFn().incRef(this, null);
  };
  /** Unregisters and releases the variable through the tracker. */
  Variable2.prototype.dispose = function() {
    trackerFn().disposeVariable(this);
    this.isDisposedInternal = true;
  };
  return Variable2;
}(Tensor);
// Duck-typed instanceof: a Tensor-like object that also has a callable assign.
Object.defineProperty(Variable, Symbol.hasInstance, {
  value: function(instance) {
    if (!(instance instanceof Tensor)) {
      return false;
    }
    return instance.assign != null && instance.assign instanceof Function;
  }
});
/**
* @license
* Copyright 2017 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
// String enum of tensor ranks "R0".."R6", published on the module exports.
(function(Rank) {
  const rankNames = ["R0", "R1", "R2", "R3", "R4", "R5", "R6"];
  for (const rankName of rankNames) {
    Rank[rankName] = rankName;
  }
})(exports.Rank || (exports.Rank = {}));
// Dtype promotion tables: upcastTypeMap[a][b] is the dtype that a value of
// dtype `a` combined with dtype `b` is promoted to.
var UpcastInt32AndMap = {
  float32: "float32",
  int32: "int32",
  bool: "int32",
  complex64: "complex64"
};
var UpcastBoolAndMap = {
  float32: "float32",
  int32: "int32",
  bool: "bool",
  complex64: "complex64"
};
var UpcastFloat32AndMap = {
  float32: "float32",
  int32: "float32",
  bool: "float32",
  complex64: "complex64"
};
var UpcastComplex64AndMap = {
  float32: "complex64",
  int32: "complex64",
  bool: "complex64",
  complex64: "complex64"
};
var upcastTypeMap = {
  float32: UpcastFloat32AndMap,
  int32: UpcastInt32AndMap,
  bool: UpcastBoolAndMap,
  complex64: UpcastComplex64AndMap
};
/**
 * Returns the common dtype for two dtypes.
 * "string" only combines with "string".
 * @throws Error when exactly one of the dtypes is "string".
 */
function upcastType(typeA, typeB) {
  const involvesString = typeA === "string" || typeB === "string";
  if (!involvesString) {
    return upcastTypeMap[typeA][typeB];
  }
  // At least one side is "string", so equality means both are.
  if (typeA === typeB) {
    return "string";
  }
  throw new Error("Can not upcast " + typeA + " with " + typeB);
}
/** Output dtype of a sum reduction: promote against int32 (so bool sums become int32). */
function sumOutType(type) {
  return upcastType(type, "int32");
}
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Returns [a, b] unchanged when their dtypes agree. Otherwise it returns both
 * cast to the dtype chosen by upcastType.
 */
function makeTypesMatch(a, b) {
  if (a.dtype !== b.dtype) {
    const commonDtype = upcastType(a.dtype, b.dtype);
    return [a.cast(commonDtype), b.cast(commonDtype)];
  }
  return [a, b];
}
/** Asserts that two tensors share a dtype; the message is built lazily. */
function assertTypesMatch(a, b) {
  const dtypesMatch = a.dtype === b.dtype;
  assert(dtypesMatch, () => "The dtypes of the first(" + a.dtype + ") and second(" + b.dtype + ") input must match");
}
/** True when some element of tensorList has the same id as tensor2. */
function isTensorInList(tensor2, tensorList) {
  const hasSameId = (candidate) => candidate.id === tensor2.id;
  return tensorList.some(hasSameId);
}
/**
 * Collects every Tensor reachable from `result` (arrays/objects, walked via
 * for...in). Each nested value is visited at most once, which also breaks
 * reference cycles.
 */
function getTensorsInContainer(result) {
  const found = [];
  walkTensorContainer(result, found, new Set());
  return found;
}
/** Recursive helper for getTensorsInContainer; appends tensors to `list`. */
function walkTensorContainer(container, list, seen) {
  if (container == null) {
    return;
  }
  if (container instanceof Tensor) {
    list.push(container);
    return;
  }
  if (!isIterable(container)) {
    return;
  }
  for (const key in container) {
    const child = container[key];
    if (seen.has(child)) {
      continue;
    }
    seen.add(child);
    walkTensorContainer(child, list, seen);
  }
}
/** True for arrays and any value whose typeof is "object". */
function isIterable(obj) {
  return typeof obj === "object" || Array.isArray(obj);
}
// Prototype-less namespace object grouping the tensor utility helpers.
var tensor_util = Object.assign(Object.create(null), {
  makeTypesMatch,
  assertTypesMatch,
  isTensorInList,
  getTensorsInContainer
});
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var EngineState = function() {
  /**
   * Mutable engine bookkeeping: registered variables, tensor/byte/buffer
   * counters, tape and scope stacks, per-dataId info (a WeakMap), and the
   * active profiling record.
   */
  function EngineState2() {
    const self = this;
    self.registeredVariables = {};
    self.nextTapeNodeId = 0;
    self.numBytes = 0;
    self.numTensors = 0;
    self.numStringTensors = 0;
    self.numDataBuffers = 0;
    self.gradientDepth = 0;
    self.kernelDepth = 0;
    self.scopeStack = [];
    self.numDataMovesStack = [];
    self.nextScopeId = 0;
    self.tensorInfo = new WeakMap();
    self.profiling = false;
    const emptyProfile = {newBytes: 0, newTensors: 0, peakBytes: 0, kernels: [], result: null};
    self.activeProfile = emptyProfile;
  }
  /** Calls dispose() on every registered variable. */
  EngineState2.prototype.dispose = function() {
    const variables = this.registeredVariables;
    for (const variableName in variables) {
      variables[variableName].dispose();
    }
  };
  return EngineState2;
}();
var Engine = function() {
/**
 * Engine constructor.
 * @param ENV2 Environment object; its getBool() is queried for flags like "IS_TEST"/"DEBUG".
 */
function Engine2(ENV2) {
  const self = this;
  self.ENV = ENV2;
  // backendName -> instantiated backend
  self.registry = {};
  // backendName -> {factory, priority}
  self.registryFactory = {};
  // Incremented per async backend init so stale results can be discarded.
  self.pendingBackendInitId = 0;
  self.state = new EngineState();
}
/**
 * Resolves once a backend is usable.
 * - If an async backend init is pending, waits for it (the resolved value is
 *   discarded; setBackend is not called here).
 * - If a backend instance already exists, resolves immediately.
 * - Otherwise tries registered backends in descending priority order. The
 *   first whose initialization succeeds is activated via setBackend.
 * @throws Error when every backend fails to initialize.
 *
 * Transpiled with tslib __generator (assumed standard tslib semantics):
 * [2, v] returns v, [3, L] jumps to case L, [4, v] awaits v and resumes at
 * the next case with the value available via _a.sent().
 */
Engine2.prototype.ready = function() {
return __awaiter(this, void 0, void 0, function() {
var sortedBackends, i, backendName, success;
return __generator(this, function(_a) {
switch (_a.label) {
case 0:
if (this.pendingBackendInit != null) {
return [2, this.pendingBackendInit.then(function() {
})];
}
if (this.backendInstance != null) {
return [2];
}
sortedBackends = this.getSortedBackends();
i = 0;
_a.label = 1;
// Loop head: try sortedBackends[i].
case 1:
if (!(i < sortedBackends.length))
return [3, 5];
backendName = sortedBackends[i];
return [4, this.initializeBackend(backendName).success];
case 2:
success = _a.sent();
// Failed: advance to the next candidate (case 4).
if (!success)
return [3, 4];
return [4, this.setBackend(backendName)];
case 3:
_a.sent();
return [2];
case 4:
i++;
return [3, 1];
case 5:
throw new Error("Could not initialize any backends, all backend initializations failed.");
}
});
});
};
// Lazily resolved active backend. It throws while an async init is pending,
// and otherwise initializes and activates the best synchronous backend on
// first access.
Object.defineProperty(Engine2.prototype, "backend", {
  get: function() {
    if (this.pendingBackendInit != null) {
      throw new Error("Backend '" + this.backendName + "' has not yet been initialized. Make sure to await tf.ready() or await tf.setBackend() before calling other methods");
    }
    if (this.backendInstance != null) {
      return this.backendInstance;
    }
    const best = this.initializeBackendsAndReturnBest();
    const bestName = best.name;
    if (best.asyncInit) {
      throw new Error("The highest priority backend '" + bestName + "' has not yet been initialized. Make sure to await tf.ready() or await tf.setBackend() before calling other methods");
    }
    // For a synchronously registered backend, setBackend assigns
    // backendInstance before returning its promise.
    this.setBackend(bestName);
    return this.backendInstance;
  },
  enumerable: true,
  configurable: true
});
/** Names of all registered backend factories. */
Engine2.prototype.backendNames = function() {
  const factories = this.registryFactory;
  return Object.keys(factories);
};
/**
 * Returns the instantiated backend with this name. A registered but not yet
 * instantiated backend is initialized on demand. Returns null for unknown
 * names or when initialization is asynchronous.
 */
Engine2.prototype.findBackend = function(backendName) {
  if (backendName in this.registry) {
    return this.registry[backendName];
  }
  if (!(backendName in this.registryFactory)) {
    return null;
  }
  if (this.initializeBackend(backendName).asyncInit) {
    return null;
  }
  return this.registry[backendName];
};
/** Registered factory function for backendName, or null. */
Engine2.prototype.findBackendFactory = function(backendName) {
  if (backendName in this.registryFactory) {
    return this.registryFactory[backendName].factory;
  }
  return null;
};
/**
 * Registers a backend factory with a priority (default 1).
 * @returns false (with a warning) if the name is taken, true otherwise.
 */
Engine2.prototype.registerBackend = function(backendName, factory, priority) {
  if (backendName in this.registryFactory) {
    console.warn(backendName + " backend was already registered. Reusing existing backend factory.");
    return false;
  }
  const effectivePriority = priority === void 0 ? 1 : priority;
  this.registryFactory[backendName] = {factory, priority: effectivePriority};
  return true;
};
/**
 * Makes backendName the active backend, initializing it first if needed.
 * Resolves to true on success, false if initialization failed.
 * Note: this.backendName is assigned before initialization, so it still
 * names the failed backend when false is returned.
 * After success it runs setupRegisteredKernels() and creates a new Profiler.
 * @throws Error if no factory is registered under backendName.
 * (tslib __generator opcodes: [2, v] return, [3, L] jump, [4, v] await.)
 */
Engine2.prototype.setBackend = function(backendName) {
return __awaiter(this, void 0, void 0, function() {
var _a, success, asyncInit, result, _b;
return __generator(this, function(_c) {
switch (_c.label) {
case 0:
if (this.registryFactory[backendName] == null) {
throw new Error("Backend name '" + backendName + "' not found in registry");
}
this.backendName = backendName;
// Already instantiated: go straight to activation (case 4).
if (!(this.registry[backendName] == null))
return [3, 4];
this.backendInstance = null;
_a = this.initializeBackend(backendName), success = _a.success, asyncInit = _a.asyncInit;
// Synchronous init: success is already a boolean (case 2).
if (!asyncInit)
return [3, 2];
return [4, success];
case 1:
_b = _c.sent();
return [3, 3];
case 2:
_b = success;
_c.label = 3;
case 3:
result = _b;
if (!result) {
return [2, false];
}
_c.label = 4;
case 4:
this.backendInstance = this.registry[backendName];
this.setupRegisteredKernels();
this.profiler = new Profiler(this.backendInstance);
return [2, true];
}
});
});
};
/** Runs each registered kernel's setupFunc (when present) against the active backend instance. */
Engine2.prototype.setupRegisteredKernels = function() {
  getKernelsForBackend(this.backendName).forEach((kernel) => {
    if (kernel.setupFunc == null) {
      return;
    }
    kernel.setupFunc(this.backendInstance);
  });
};
/** Runs each registered kernel's disposeFunc (when present) against the named backend instance. */
Engine2.prototype.disposeRegisteredKernels = function(backendName) {
  getKernelsForBackend(backendName).forEach((kernel) => {
    if (kernel.disposeFunc == null) {
      return;
    }
    kernel.disposeFunc(this.registry[backendName]);
  });
};
/**
 * Instantiates backendName from its registered factory.
 * @returns {success, asyncInit}:
 *   - sync factory: the instance is stored in this.registry; success is true.
 *   - async factory (returns a thenable that is not a KernelBackend): success
 *     is a promise resolving to true/false, and asyncInit is true. While
 *     pending, this.pendingBackendInit holds that promise.
 *   - the factory throws synchronously: a warning is logged and success is false.
 * @throws Error if no factory is registered under backendName.
 */
Engine2.prototype.initializeBackend = function(backendName) {
var _this = this;
var registryFactoryEntry = this.registryFactory[backendName];
if (registryFactoryEntry == null) {
throw new Error("Cannot initialize backend " + backendName + ", no registration found.");
}
try {
var backend2 = registryFactoryEntry.factory();
if (backend2 && !(backend2 instanceof KernelBackend) && typeof backend2.then === "function") {
// Each async init gets a fresh id. A newer init (or removeBackend)
// bumps pendingBackendInitId, making this one stale.
var promiseId_1 = ++this.pendingBackendInitId;
var success = backend2.then(function(backendInstance) {
// Stale: do not register and leave pendingBackendInit untouched.
if (promiseId_1 < _this.pendingBackendInitId) {
return false;
}
_this.registry[backendName] = backendInstance;
_this.pendingBackendInit = null;
return true;
}).catch(function(err) {
if (promiseId_1 < _this.pendingBackendInitId) {
return false;
}
_this.pendingBackendInit = null;
console.warn("Initialization of backend " + backendName + " failed");
console.warn(err.stack || err.message);
// Rejections are converted into a false result, so `success` never rejects.
return false;
});
this.pendingBackendInit = success;
return {success, asyncInit: true};
} else {
this.registry[backendName] = backend2;
return {success: true, asyncInit: false};
}
} catch (err) {
console.warn("Initialization of backend " + backendName + " failed");
console.warn(err.stack || err.message);
return {success: false, asyncInit: false};
}
};
/**
 * Unregisters a backend factory and tears down its instance if one exists.
 * @throws Error if no factory is registered under backendName.
 */
Engine2.prototype.removeBackend = function(backendName) {
if (!(backendName in this.registryFactory)) {
throw new Error(backendName + " backend not found in registry");
}
// Bumping the id marks an in-flight async init of the current backend as
// stale, so initializeBackend's promise handlers return false instead of
// registering the instance.
if (this.backendName === backendName && this.pendingBackendInit != null) {
this.pendingBackendInitId++;
}
if (backendName in this.registry) {
this.disposeRegisteredKernels(backendName);
this.registry[backendName].dispose();
delete this.registry[backendName];
}
delete this.registryFactory[backendName];
// Removing the active backend leaves the engine with no selected backend.
if (this.backendName === backendName) {
this.pendingBackendInit = null;
this.backendName = null;
this.backendInstance = null;
}
};
/**
 * Registered backend names, highest priority first.
 * @throws Error when no backend is registered.
 */
Engine2.prototype.getSortedBackends = function() {
  const names = Object.keys(this.registryFactory);
  if (names.length === 0) {
    throw new Error("No backend found in registry.");
  }
  const factories = this.registryFactory;
  return names.sort((a, b) => factories[b].priority - factories[a].priority);
};
/**
 * Initializes backends in priority order and returns the first one that
 * succeeded or started an async init.
 * @returns {name, asyncInit}
 * @throws Error when every initialization failed.
 */
Engine2.prototype.initializeBackendsAndReturnBest = function() {
  const candidates = this.getSortedBackends();
  for (let idx = 0; idx < candidates.length; idx++) {
    const candidateName = candidates[idx];
    const attempt = this.initializeBackend(candidateName);
    if (attempt.asyncInit || attempt.success) {
      return {name: candidateName, asyncInit: attempt.asyncInit};
    }
  }
  throw new Error("Could not initialize any backends, all backend initializations failed.");
};
/**
 * Moves the data behind dataId from its current backend to backend2.
 * Values are read synchronously, released on the source backend, then handed
 * to backend2.move(). In leak-check mode the move is counted for the
 * innermost kernel.
 */
Engine2.prototype.moveData = function(backend2, dataId) {
  const info = this.state.tensorInfo.get(dataId);
  const sourceBackend = info.backend;
  const values = this.readSync(dataId);
  sourceBackend.disposeData(dataId);
  info.backend = backend2;
  backend2.move(dataId, values, info.shape, info.dtype);
  if (!this.shouldCheckForMemLeaks()) {
    return;
  }
  const moveCounts = this.state.numDataMovesStack;
  moveCounts[moveCounts.length - 1]++;
};
/**
 * Runs a function inside a new scope; tensors created in the scope are
 * disposed at the end unless kept or returned.
 * Accepts either (fn) or (name, fn).
 * @throws Error on invalid argument combinations.
 */
Engine2.prototype.tidy = function(nameOrFn, fn) {
  let scopeName = null;
  let body = fn;
  if (body == null) {
    if (typeof nameOrFn !== "function") {
      throw new Error("Please provide a function to tidy()");
    }
    body = nameOrFn;
  } else {
    const hasStringName = typeof nameOrFn === "string" || nameOrFn instanceof String;
    if (!hasStringName) {
      throw new Error("When calling with two arguments, the first argument to tidy() must be a string");
    }
    if (typeof body !== "function") {
      throw new Error("When calling with two arguments, the 2nd argument to tidy() must be a function");
    }
    scopeName = nameOrFn;
  }
  let result;
  const runBody = () => {
    result = body();
    // Async results escape the scope, so they cannot be tracked correctly.
    if (result instanceof Promise) {
      console.error("Cannot return a Promise inside of tidy.");
    }
    return result;
  };
  return this.scopedRun(() => this.startScope(scopeName), () => this.endScope(result), runBody);
};
/**
 * Calls start(), then f(), then end(), and returns f()'s result. end() also
 * runs when f() throws, and the error is then rethrown. If end() itself
 * throws after a successful f(), end() is invoked a second time before
 * propagating.
 */
Engine2.prototype.scopedRun = function(start, end, f) {
  start();
  let outcome;
  try {
    outcome = f();
    end();
  } catch (ex) {
    end();
    throw ex;
  }
  return outcome;
};
/** Returns the current class-wide tensor id counter, then increments it. */
Engine2.prototype.nextTensorId = function() {
  const id = Engine2.nextTensorId++;
  return id;
};
/** Returns the current class-wide variable id counter, then increments it. */
Engine2.prototype.nextVariableId = function() {
  const id = Engine2.nextVariableId++;
  return id;
};
/**
 * Creates a new Tensor sharing x's dataId (makeTensorFromDataId increments
 * the data ref count) and records a tape node for it. The node's gradient
 * casts dy to float32 through the Cast kernel.
 * NOTE(review): this unconditionally reads this.state.activeScope.name and
 * addTapeNode pushes onto this.state.activeTape. The callers visible in this
 * file (saveFunc_1 and saveTensorsForBackwardMode in runKernelFunc) only call
 * it while isTapeOn() is true — confirm no other caller uses it outside a tape.
 */
Engine2.prototype.clone = function(x) {
var y = this.makeTensorFromDataId(x.dataId, x.shape, x.dtype);
var inputs = {x};
var grad2 = function(dy) {
return {
x: function() {
var dtype = "float32";
var gradInputs = {x: dy};
var attrs = {dtype};
return ENGINE.runKernelFunc(function(backend2) {
return backend2.cast(dy, dtype);
}, gradInputs, null, Cast, attrs);
}
};
};
var saved = [];
this.addTapeNode(this.state.activeScope.name, inputs, [y], grad2, saved, {});
return y;
};
/**
 * Runs a registered kernel by name. With no forward/backward closures,
 * runKernelFunc looks the kernel up via getKernel.
 */
Engine2.prototype.runKernel = function(kernelName, inputs, attrs, inputsToSave, outputsToSave) {
  const noForwardFunc = null;
  const noBackwardsFunc = null;
  return this.runKernelFunc(noForwardFunc, inputs, noBackwardsFunc, kernelName, attrs, inputsToSave, outputsToSave);
};
/** Leak checks are enabled only when the environment flag "IS_TEST" is set. */
Engine2.prototype.shouldCheckForMemLeaks = function() {
  const env = this.ENV;
  return env.getBool("IS_TEST");
};
/**
 * Compares backend data-id counts before/after a kernel. Expected growth is
 * one id per output (three for complex64), plus the recorded data moves.
 * @throws Error if more data ids than that were created.
 */
Engine2.prototype.checkKernelForMemLeak = function(kernelName, numDataIdsBefore, outInfos) {
  const numDataIdsAfter = this.backend.numDataIds();
  const numOutputDataIds = outInfos.reduce((count, info) => count + (info.dtype === "complex64" ? 3 : 1), 0);
  const moveCounts = this.state.numDataMovesStack;
  const numMoves = moveCounts[moveCounts.length - 1];
  const dataIdsLeaked = numDataIdsAfter - numDataIdsBefore - numOutputDataIds - numMoves;
  if (dataIdsLeaked > 0) {
    throw new Error("Backend '" + this.backendName + "' has an internal memory leak (" + dataIdsLeaked + " data ids) after running '" + kernelName + "'");
  }
};
/**
 * Core kernel dispatch.
 * - If a kernel is registered for (kernelName, active backend), its kernelFunc
 *   is called and the returned infos are wrapped into Tensors.
 * - Otherwise forwardFunc(backend, saveFunc) is run inside tidy().
 * Execution is wrapped in kernelDepth++/-- so nested kernels are not taped
 * (isTapeOn requires kernelDepth === 0). It goes through the profiler when
 * "DEBUG" is set or profiling is active. When the tape is on, a tape node with
 * the saved tensors is recorded. When profiling, per-kernel byte/tensor deltas
 * are appended to activeProfile.kernels.
 * @returns An array of Tensors if the kernel produced an array, otherwise the
 *   single output Tensor.
 */
Engine2.prototype.runKernelFunc = function(forwardFunc, inputs, backwardsFunc, kernelName, attrs, inputsToSave, outputsToSave) {
var _this = this;
var outputs;
var saved = [];
var isTapeOn = this.isTapeOn();
// Unnamed calls are attributed to the active scope's name.
if (kernelName == null) {
kernelName = this.state.activeScope != null ? this.state.activeScope.name : "";
}
var startingBytecount = this.state.numBytes;
var startingNumTensors = this.state.numTensors;
// NOTE(review): a move counter is pushed here but never popped in this
// method; confirm whether a matching pop exists elsewhere.
if (this.shouldCheckForMemLeaks()) {
this.state.numDataMovesStack.push(0);
}
var kernelFunc;
var kernel = getKernel(kernelName, this.backendName);
var out;
if (kernel != null) {
// Registered-kernel path: outputs are TensorInfo-like objects
// ({dataId, shape, dtype}) turned into ref-counted Tensors.
kernelFunc = function() {
var numDataIdsBefore = _this.backend.numDataIds();
out = kernel.kernelFunc({inputs, attrs, backend: _this.backend});
var outInfos = Array.isArray(out) ? out : [out];
if (_this.shouldCheckForMemLeaks()) {
_this.checkKernelForMemLeak(kernelName, numDataIdsBefore, outInfos);
}
var outTensors = outInfos.map(function(_a) {
var dataId = _a.dataId, shape = _a.shape, dtype = _a.dtype;
return _this.makeTensorFromDataId(dataId, shape, dtype);
});
if (isTapeOn) {
// Prefer the registered gradient config; otherwise fall back to the
// caller's inputsToSave plus outputs selected by the outputsToSave mask.
var tensorsToSave = _this.getTensorsForGradient(kernelName, inputs, outTensors);
if (tensorsToSave == null) {
if (outputsToSave == null) {
outputsToSave = [];
}
var outsToSave = outTensors.filter(function(_, i) {
return outputsToSave[i];
});
tensorsToSave = (inputsToSave || []).slice().concat(outsToSave);
}
saved = _this.saveTensorsForBackwardMode(tensorsToSave);
}
return outTensors;
};
} else {
// Legacy path: forwardFunc may call saveFunc to stash tensors for backprop
// (kept clones, only when the tape is on).
var saveFunc_1 = function(tensors) {
if (!isTapeOn) {
return;
}
saved = tensors.map(function(tensor2) {
return _this.keep(_this.clone(tensor2));
});
};
kernelFunc = function() {
var numDataIdsBefore = _this.backend.numDataIds();
out = _this.tidy(function() {
return forwardFunc(_this.backend, saveFunc_1);
});
var outs = Array.isArray(out) ? out : [out];
if (_this.shouldCheckForMemLeaks()) {
_this.checkKernelForMemLeak(kernelName, numDataIdsBefore, outs);
}
return outs;
};
}
var kernelProfile;
this.scopedRun(function() {
return _this.state.kernelDepth++;
}, function() {
return _this.state.kernelDepth--;
}, function() {
if (!_this.ENV.getBool("DEBUG") && !_this.state.profiling) {
outputs = kernelFunc();
} else {
kernelProfile = _this.profiler.profileKernel(kernelName, inputs, function() {
return kernelFunc();
});
if (_this.ENV.getBool("DEBUG")) {
_this.profiler.logKernelProfile(kernelProfile);
}
outputs = kernelProfile.outputs;
}
});
if (isTapeOn) {
this.addTapeNode(kernelName, inputs, outputs, backwardsFunc, saved, attrs);
}
// kernelTimeMs/extraInfo may be promises; profile() awaits them later.
if (this.state.profiling) {
this.state.activeProfile.kernels.push({
name: kernelName,
bytesAdded: this.state.numBytes - startingBytecount,
totalBytesSnapshot: this.state.numBytes,
tensorsAdded: this.state.numTensors - startingNumTensors,
totalTensorsSnapshot: this.state.numTensors,
inputShapes: Object.keys(inputs).map(function(key) {
return inputs[key] != null ? inputs[key].shape : null;
}),
outputShapes: outputs.map(function(item) {
return item.shape;
}),
kernelTimeMs: kernelProfile.timeMs,
extraInfo: kernelProfile.extraInfo
});
}
return Array.isArray(out) ? outputs : outputs[0];
};
/** Clones each tensor and marks the clone as kept so scope cleanup skips it. */
Engine2.prototype.saveTensorsForBackwardMode = function(tensors) {
  return tensors.map((tensor2) => this.keep(this.clone(tensor2)));
};
/**
 * Selects the tensors a registered gradient needs, based on its config:
 * either all inputs (saveAllInputs) or the named inputsToSave, followed by the
 * outputs flagged in the outputsToSave mask.
 * @returns The tensors to save, or null when no gradient is registered.
 */
Engine2.prototype.getTensorsForGradient = function(kernelName, inputs, outputs) {
  const gradConfig = getGradient(kernelName);
  if (gradConfig == null) {
    return null;
  }
  const inputNames = gradConfig.inputsToSave || [];
  const outputMask = gradConfig.outputsToSave || [];
  let savedInputs;
  if (gradConfig.saveAllInputs) {
    assert(Array.isArray(inputs), () => "saveAllInputs is true, expected inputs to be an array.");
    savedInputs = Object.keys(inputs).map((key) => inputs[key]);
  } else {
    savedInputs = inputNames.map((inputName) => inputs[inputName]);
  }
  const savedOutputs = outputs.filter((_, i) => outputMask[i]);
  return savedInputs.concat(savedOutputs);
};
/**
 * Writes values to a backend and wraps the resulting dataId in a Tensor.
 * String values are encoded with encodeString first. For string tensors the
 * byte accounting is corrected to bytesFromStringArray of the encoded values.
 * @throws Error when values is null/undefined.
 */
Engine2.prototype.makeTensor = function(values, shape, dtype, backend2) {
  if (values == null) {
    throw new Error("Values passed to engine.makeTensor() are null");
  }
  const resolvedDtype = dtype || "float32";
  const targetBackend = backend2 || this.backend;
  const isStringTensor = resolvedDtype === "string";
  const backendVals = isStringTensor && isString(values[0]) ? values.map((d) => encodeString(d)) : values;
  const dataId = targetBackend.write(backendVals, shape, resolvedDtype);
  const t = new Tensor(shape, resolvedDtype, dataId, this.nextTensorId());
  this.incRef(t, targetBackend);
  if (isStringTensor) {
    const info = this.state.tensorInfo.get(dataId);
    const newBytes = bytesFromStringArray(backendVals);
    this.state.numBytes += newBytes - info.bytes;
    info.bytes = newBytes;
  }
  return t;
};
/** Wraps an existing dataId in a new, ref-counted Tensor (dtype defaults to "float32"). */
Engine2.prototype.makeTensorFromDataId = function(dataId, shape, dtype, backend2) {
  const t = new Tensor(shape, dtype || "float32", dataId, this.nextTensorId());
  this.incRef(t, backend2);
  return t;
};
/**
 * Creates and registers a Variable from initialValue (cast to dtype when it
 * differs). A missing name gets a generated numeric id.
 * @throws Error if a variable with the same name is already registered.
 */
Engine2.prototype.makeVariable = function(initialValue, trainable, name, dtype) {
  const isTrainable = trainable === void 0 ? true : trainable;
  const variableName = name || this.nextVariableId().toString();
  let value = initialValue;
  if (dtype != null && dtype !== value.dtype) {
    value = value.cast(dtype);
  }
  const v = new Variable(value, isTrainable, variableName, this.nextTensorId());
  const registry = this.state.registeredVariables;
  if (registry[v.name] != null) {
    throw new Error("Variable with name " + v.name + " was already registered");
  }
  registry[v.name] = v;
  this.incRef(v, this.backend);
  return v;
};
/**
 * Registers a new reference to a tensor's data. On the first reference it
 * records backend, dtype, shape and byte size (0 for complex64/string) in
 * tensorInfo. Non-Variable tensors are tracked in the active scope.
 */
Engine2.prototype.incRef = function(a, backend2) {
  const state = this.state;
  const refCount = state.tensorInfo.has(a.dataId) ? state.tensorInfo.get(a.dataId).refCount : 0;
  state.numTensors++;
  if (a.dtype === "string") {
    state.numStringTensors++;
  }
  if (refCount === 0) {
    state.numDataBuffers++;
    const hasFixedWidth = a.dtype !== "complex64" && a.dtype !== "string";
    const bytes = hasFixedWidth ? a.size * bytesPerElement(a.dtype) : 0;
    state.tensorInfo.set(a.dataId, {
      backend: backend2 || this.backend,
      dtype: a.dtype,
      shape: a.shape,
      bytes,
      refCount: 0
    });
    state.numBytes += bytes;
  }
  state.tensorInfo.get(a.dataId).refCount++;
  if (!(a instanceof Variable)) {
    this.track(a);
  }
};
/**
 * Drops one reference to a tensor's data. The last reference frees the
 * backend data and removes its tensorInfo entry. Unknown dataIds are ignored.
 */
Engine2.prototype.disposeTensor = function(a) {
  const state = this.state;
  if (!state.tensorInfo.has(a.dataId)) {
    return;
  }
  state.numTensors--;
  if (a.dtype === "string") {
    state.numStringTensors--;
  }
  const info = state.tensorInfo.get(a.dataId);
  if (info.refCount <= 1) {
    if (a.dtype !== "complex64") {
      state.numBytes -= info.bytes;
    }
    state.numDataBuffers--;
    info.backend.disposeData(a.dataId);
    state.tensorInfo.delete(a.dataId);
    return;
  }
  info.refCount--;
};
/** Disposes every registered variable. */
Engine2.prototype.disposeVariables = function() {
  const registry = this.state.registeredVariables;
  for (const varName in registry) {
    this.disposeVariable(registry[varName]);
  }
};
/** Releases a variable's data and removes it from the variable registry. */
Engine2.prototype.disposeVariable = function(v) {
  this.disposeTensor(v);
  const registry = this.state.registeredVariables;
  if (registry[v.name] != null) {
    delete registry[v.name];
  }
};
/**
 * Backend memory info augmented with the engine's tensor/buffer/byte counts.
 * Marked unreliable when string tensors exist (their size is approximate).
 */
Engine2.prototype.memory = function() {
  const info = this.backend.memory();
  const state = this.state;
  info.numTensors = state.numTensors;
  info.numDataBuffers = state.numDataBuffers;
  info.numBytes = state.numBytes;
  if (state.numStringTensors <= 0) {
    return info;
  }
  info.unreliable = true;
  if (info.reasons == null) {
    info.reasons = [];
  }
  info.reasons.push("Memory usage by string tensors is approximate (2 bytes per character)");
  return info;
};
/**
 * Profiles the kernels executed while `query` runs.
 * Enables state.profiling and clears activeProfile.kernels before awaiting
 * query(). Afterwards it fills in result, peakBytes (max totalBytesSnapshot
 * over recorded kernels; -Infinity when none ran), newBytes and newTensors,
 * and awaits each kernel's kernelTimeMs and extraInfo in place.
 * @param query Function returning a value or promise whose result is stored in activeProfile.result.
 * @returns Promise resolving to state.activeProfile.
 *
 * Fix: state.profiling is reset in a `finally`, so a query that throws or
 * rejects no longer leaves profiling permanently enabled. Previously every
 * later kernel kept being profiled and appended to activeProfile.kernels.
 * The peak is computed with a loop instead of Math.max.apply, which avoids
 * the engine's argument-count limit on very long kernel lists.
 */
Engine2.prototype.profile = function(query) {
  const _this = this;
  // __awaiter drives native generators the same way TypeScript emits for ES2015 targets.
  return __awaiter(this, void 0, void 0, function*() {
    _this.state.profiling = true;
    const startBytes = _this.state.numBytes;
    const startNumTensors = _this.state.numTensors;
    _this.state.activeProfile.kernels = [];
    const resultTarget = _this.state.activeProfile;
    try {
      resultTarget.result = yield query();
    } finally {
      _this.state.profiling = false;
    }
    const activeProfile = _this.state.activeProfile;
    const kernels = activeProfile.kernels;
    let peakBytes = -Infinity;
    for (let i = 0; i < kernels.length; i++) {
      peakBytes = Math.max(peakBytes, kernels[i].totalBytesSnapshot);
    }
    activeProfile.peakBytes = peakBytes;
    activeProfile.newBytes = _this.state.numBytes - startBytes;
    activeProfile.newTensors = _this.state.numTensors - startNumTensors;
    // Timing and extra info may be promises; resolve them sequentially in place.
    for (let i = 0; i < kernels.length; i++) {
      const kernel = kernels[i];
      kernel.kernelTimeMs = yield kernel.kernelTimeMs;
      kernel.extraInfo = yield kernel.extraInfo;
    }
    return _this.state.activeProfile;
  });
};
/** The tape records only inside a gradient computation and outside nested kernels. */
Engine2.prototype.isTapeOn = function() {
  const state = this.state;
  return state.kernelDepth === 0 && state.gradientDepth > 0;
};
/**
 * Appends a tape node for a kernel invocation. A registered gradient config
 * overrides gradientsFunc. Missing upstream gradients (null dys) are replaced
 * by zero tensors shaped like the corresponding outputs.
 */
Engine2.prototype.addTapeNode = function(kernelName, inputs, outputs, gradientsFunc, saved, attrs) {
  const tapeNode = {id: this.state.nextTapeNodeId++, kernelName, inputs, outputs, saved};
  const gradConfig = getGradient(kernelName);
  const gradFn = gradConfig != null ? gradConfig.gradFunc : gradientsFunc;
  if (gradFn != null) {
    tapeNode.gradient = (dys) => {
      const filledDys = dys.map((dy, i) => {
        if (dy != null) {
          return dy;
        }
        const output = outputs[i];
        const zeros = makeZerosTypedArray(output.size, output.dtype);
        return this.makeTensor(zeros, output.shape, output.dtype);
      });
      return gradFn(filledDys.length > 1 ? filledDys : filledDys[0], saved, attrs);
    };
  }
  this.state.activeTape.push(tapeNode);
};
/** Marks a tensor so endScope will not dispose it; returns it for chaining. */
Engine2.prototype.keep = function(result) {
  result.kept = true;
  return result;
};
/** Enters a gradient computation; the outermost entry starts a fresh tape. */
Engine2.prototype.startTape = function() {
  const isOutermost = this.state.gradientDepth === 0;
  if (isOutermost) {
    this.state.activeTape = [];
  }
  this.state.gradientDepth++;
};
/** Leaves a gradient computation (the tape itself is kept). */
Engine2.prototype.endTape = function() {
  this.state.gradientDepth--;
};
/**
 * Opens a new tracking scope and makes it active.
 *
 * @param name Optional scope name; a falsy name falls back to "unnamed scope".
 */
Engine2.prototype.startScope = function(name) {
  const state = this.state;
  const scope = {
    track: [],
    name: name ? name : "unnamed scope",
    id: state.nextScopeId++
  };
  state.scopeStack.push(scope);
  state.activeScope = scope;
};
/**
 * Closes the active scope.
 *
 * The method runs in three steps:
 * 1. Every tensor the scope tracked is disposed, unless it was kept or is contained
 *    in `result`.
 * 2. The scope is popped and the previous one on the stack becomes active (null when
 *    the stack is empty).
 * 3. Tensors from `result` that were created in the closed scope and are not kept are
 *    re-tracked in the new active scope.
 *
 * @param result Value returned from the scope; its tensors survive the scope.
 */
Engine2.prototype.endScope = function(result) {
  const returned = getTensorsInContainer(result);
  const returnedIds = new Set(returned.map((t) => t.id));
  // Index loop on purpose: the track array is re-read on every iteration.
  for (let k = 0; k < this.state.activeScope.track.length; k++) {
    const tracked = this.state.activeScope.track[k];
    if (tracked.kept || returnedIds.has(tracked.id)) {
      continue;
    }
    tracked.dispose();
  }
  const closed = this.state.scopeStack.pop();
  const stack = this.state.scopeStack;
  this.state.activeScope = stack.length > 0 ? stack[stack.length - 1] : null;
  returned.forEach((t) => {
    if (!t.kept && t.scopeId === closed.id) {
      this.track(t);
    }
  });
};
/**
 * Computes the gradients of y = f() with respect to each tensor in `xs`.
 *
 * f runs through scopedRun with startTape/endTape callbacks and inside a tidy scope
 * named "forward". The tape is then filtered with getFilteredNodesXToY to the nodes
 * linking xs to y, and backpropagated inside a tidy scope named "backward".
 *
 * @param f Function producing y; its result must be a Tensor.
 * @param xs Tensors to differentiate with respect to; must be non-empty.
 * @param dy Optional upstream gradient for y; must have dtype float32. When null or
 *   undefined, ones(y.shape) is used instead.
 * @param allowNoGradients If false (the default), an Error is thrown when the filtered
 *   tape is empty, i.e. no recorded op leads from xs to y.
 * @returns {value: y, grads}, where grads[i] is accumulatedGradientMap[xs[i].id].
 *   NOTE(review): that entry may be undefined if no gradient was accumulated for xs[i];
 *   confirm against backpropagateGradients.
 */
Engine2.prototype.gradients = function(f, xs, dy, allowNoGradients) {
var _this = this;
if (allowNoGradients === void 0) {
allowNoGradients = false;
}
assert(xs.length > 0, function() {
return "gradients() received an empty list of xs.";
});
if (dy != null && dy.dtype !== "float32") {
throw new Error("dy must have 'float32' dtype, but has '" + dy.dtype + "'");
}
// Forward pass: startTape/endTape are handed to scopedRun as the enter/exit callbacks
// (presumably run before and after the tidy "forward" call), so f's kernels get taped.
var y = this.scopedRun(function() {
return _this.startTape();
}, function() {
return _this.endTape();
}, function() {
return _this.tidy("forward", f);
});
assert(y instanceof Tensor, function() {
return "The result y returned by f() must be a tensor.";
});
// Keep only the tape nodes that lie between xs and y.
var filteredTape = getFilteredNodesXToY(this.state.activeTape, xs, y);
if (!allowNoGradients && filteredTape.length === 0 && xs.length > 0) {
throw new Error("Cannot compute gradient of y=f(x) with respect to x. Make sure that the f you passed encloses all operations that lead from x to y.");
}
return this.tidy("backward", function() {
// Gradients keyed by tensor id, seeded with dy (or ones) for y itself.
var accumulatedGradientMap = {};
accumulatedGradientMap[y.id] = dy == null ? ones(y.shape) : dy;
backpropagateGradients(accumulatedGradientMap, filteredTape, function(f2) {
return _this.tidy(f2);
}, add);
var grads2 = xs.map(function(x) {
return accumulatedGradientMap[x.id];
});
// Once no enclosing gradient region remains (depth back at 0), dispose every tensor
// saved on the tape for backprop and drop the tape.
if (_this.state.gradientDepth === 0) {
_this.state.activeTape.forEach(function(node) {
for (var _i2 = 0, _a = node.saved; _i2 < _a.length; _i2++) {
var tensor2 = _a[_i2];
tensor2.dispose();
}
});
_this.state.activeTape = null;
}
return {value: y, grads: grads2};
});
};
/**
 * Wraps `f` into a function whose gradient is supplied by `f` itself.
 *
 * f is called as f(...inputs, save) and must return {value, gradFunc}:
 * - `value` is the forward result and must be a Tensor.
 * - `gradFunc(dy, saved)` returns one Tensor per input; a bare Tensor is accepted and
 *   treated as a one-element list.
 *
 * @param f The forward/gradient definition; must be a function.
 * @returns A function that accepts Tensors only. It runs f through runKernelFunc, with
 *   inputs keyed by argument position, and returns `res.value`.
 */
Engine2.prototype.customGrad = function(f) {
var _this = this;
assert(isFunction(f), function() {
return "The f passed in customGrad(f) must be a function.";
});
return function() {
var inputs = [];
for (var _i2 = 0; _i2 < arguments.length; _i2++) {
inputs[_i2] = arguments[_i2];
}
assert(inputs.every(function(t) {
return t instanceof Tensor;
}), function() {
return "The args passed in customGrad(f)(x1, x2,...) must all be tensors";
});
// `res` is written by the forward closure and read by the gradient closure below,
// so the forward function must have run before the gradient is requested.
var res;
// Inputs keyed by position ("0", "1", ...) for runKernelFunc.
var inputMap = {};
inputs.forEach(function(input, i) {
inputMap[i] = input;
});
return _this.runKernelFunc(function(_, save) {
res = f.apply(void 0, inputs.concat([save]));
assert(res.value instanceof Tensor, function() {
return "The function f passed in customGrad(f) must return an object where `obj.value` is a tensor";
});
assert(isFunction(res.gradFunc), function() {
return "The function f passed in customGrad(f) must return an object where `obj.gradFunc` is a function.";
});
return res.value;
}, inputMap, function(dy, saved) {
var gradRes = res.gradFunc(dy, saved);
var grads2 = Array.isArray(gradRes) ? gradRes : [gradRes];
assert(grads2.length === inputs.length, function() {
return "The function f passed in customGrad(f) must return an object where `obj.gradFunc` is a function that returns the same number of tensors as inputs passed to f(...).";
});
assert(grads2.every(function(t) {
return t instanceof Tensor;
}), function() {
return "The function f passed in customGrad(f) must return an object where `obj.gradFunc` is a function that returns a list of only tensors.";
});
// Gradient map keyed by input position; each entry is a thunk returning that gradient.
var gradMap = {};
grads2.forEach(function(grad2, i) {
gradMap[i] = function() {
return grad2;
};
});
return gradMap;
});
};
};
/**
 * Synchronously reads the values stored under `dataId` from the backend that owns it.
 */
Engine2.prototype.readSync = function(dataId) {
  const owner = this.state.tensorInfo.get(dataId).backend;
  return owner.readSync(dataId);
};
/**
 * Reads the values stored under `dataId` via the owning backend's `read`; the
 * backend's return value (presumably a Promise) is passed through unchanged.
 */
Engine2.prototype.read = function(dataId) {
  const owner = this.state.tensorInfo.get(dataId).backend;
  return owner.read(dataId);
};
/**
 * Times `query` on the active backend and adds `wallMs`: the now()-measured milliseconds
 * from just before calling backend.time(query) until its result is available.
 * Down-leveled async function (__awaiter/__generator state machine); the case labels
 * below are resumption points and must stay in this order.
 *
 * @returns Promise resolving to the backend's timing-info object, mutated with `wallMs`.
 */
Engine2.prototype.time = function(query) {
return __awaiter(this, void 0, void 0, function() {
var start, timingInfo;
return __generator(this, function(_a) {
switch (_a.label) {
case 0:
start = now();
// [4, p] awaits p; execution resumes at case 1 with the settled value.
return [4, this.backend.time(query)];
case 1:
timingInfo = _a.sent();
timingInfo.wallMs = now() - start;
return [2, timingInfo];
}
});
});
};
/**
 * Registers `result` with the active scope, if any, by stamping its scopeId and adding
 * it to the scope's track list. Returns `result` unchanged.
 */
Engine2.prototype.track = function(result) {
  const scope = this.state.activeScope;
  if (scope == null) {
    return result;
  }
  result.scopeId = scope.id;
  scope.track.push(result);
  return result;
};
// Read-only accessor exposing the engine state's registered-variable collection.
Object.defineProperty(Engine2.prototype, "registeredVariables", {
  configurable: true,
  enumerable: true,
  get() {
    return this.state.registeredVariables;
  }
});
/**
 * Tears the engine down to a fresh state.
 *
 * In order, it:
 * - disposes the current EngineState and resets ENV;
 * - installs a new EngineState;
 * - for every registered backend, disposes its registered kernels, disposes the backend
 *   and removes it from the registry;
 * - clears the active backend name/instance and any pending backend initialization.
 */
Engine2.prototype.reset = function() {
// NOTE(review): presumably lets an in-flight backend initialization detect that it was
// superseded; the reader of pendingBackendInitId is outside this view — confirm.
this.pendingBackendInitId++;
this.state.dispose();
this.ENV.reset();
this.state = new EngineState();
// Kernels are unregistered before the backend itself is disposed.
for (var backendName in this.registry) {
this.disposeRegisteredKernels(backendName);
this.registry[backendName].dispose();
delete this.registry[backendName];
}
this.backendName = null;
this.backendInstance = null;
this.pendingBackendInit = null;
};
Engine2.nextTensorId = 0;
Engine2.nextVariableId = 0;
return Engine2;
}();
/**
 * Creates a float32 tensor of the given shape filled with ones.
 * Used by gradients() as the default upstream gradient.
 */
function ones(shape) {
  const count = sizeFromShape(shape);
  const filled = makeOnesTypedArray(count, "float32");
  return ENGINE.makeTensor(filled, shape, "float32");
}
/**
 * Returns the engine stored on the global namespace as `_tfengine`, creating it (with a
 * new Environment) on first use.
 *
 * Side effects: sets the environment global to the engine's ENV and installs a tensor
 * tracker. The tracker reads `ns._tfengine` lazily on every call.
 */
function getOrMakeEngine() {
  const ns = getGlobalNamespace();
  if (ns._tfengine == null) {
    ns._tfengine = new Engine(new Environment(ns));
  }
  setEnvironmentGlobal(ns._tfengine.ENV);
  setTensorTracker(() => ns._tfengine);
  return ns._tfengine;
}
// Module-wide engine instance. getOrMakeEngine() caches it on the object returned by
// getGlobalNamespace() (as `_tfengine`), so later calls reuse the same engine.
var ENGINE = getOrMakeEngine();
/**
 * Element-wise a + b, run as the `Add` kernel.
 * Both operands are saved for the backward pass; no custom gradient is supplied (null).
 */
function add(a, b) {
  const forward = (backend2, save) => {
    const sum = backend2.add(a, b);
    save([a, b]);
    return sum;
  };
  return ENGINE.runKernelFunc(forward, {a, b}, null, Add);
}
/**
* @license
* Copyright 2017 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/** True when a non-null global `navigator` object is available. */
function _isNavigatorDefined() {
  if (typeof navigator === "undefined") {
    return false;
  }
  return navigator != null;
}
/**
 * Heuristically detects a mobile device from the user-agent string, using the
 * "detectmobilebrowsers" regexes.
 *
 * The UA is taken from navigator.userAgent, then navigator.vendor, then the legacy
 * window.opera. The first regex is tested against the whole UA and the second against
 * its first four characters.
 *
 * @returns false when no navigator is defined or no UA string can be found.
 */
function isMobile() {
  if (!_isNavigatorDefined()) {
    return false;
  }
  // `navigator` can exist without `window` (e.g. in workers or non-browser runtimes that
  // define a global navigator), so window.opera must be guarded to avoid a ReferenceError.
  var a = navigator.userAgent || navigator.vendor || (typeof window !== "undefined" ? window.opera : void 0);
  if (!a) {
    // Previously an empty/absent UA reached `a.substr`, which throws on undefined.
    return false;
  }
  var ua = String(a);
  return /(android|bb\d+|meego).+mobile|avantgo|bada\/|blackberry|blazer|compal|elaine|fennec|hiptop|iemobile|ip(hone|od)|iris|kindle|lge |maemo|midp|mmp|mobile.+firefox|netfront|opera m(ob|in)i|palm( os)?|phone|p(ixi|re)\/|plucker|pocket|psp|series(4|6)0|symbian|treo|up\.(browser|link)|vodafone|wap|windows ce|xda|xiino/i.test(ua) || /1207|6310|6590|3gso|4thp|50[1-6]i|770s|802s|a wa|abac|ac(er|oo|s\-)|ai(ko|rn)|al(av|ca|co)|amoi|an(ex|ny|yw)|aptu|ar(ch|go)|as(te|us)|attw|au(di|\-m|r |s )|avan|be(ck|ll|nq)|bi(lb|rd)|bl(ac|az)|br(e|v)w|bumb|bw\-(n|u)|c55\/|capi|ccwa|cdm\-|cell|chtm|cldc|cmd\-|co(mp|nd)|craw|da(it|ll|ng)|dbte|dc\-s|devi|dica|dmob|do(c|p)o|ds(12|\-d)|el(49|ai)|em(l2|ul)|er(ic|k0)|esl8|ez([4-7]0|os|wa|ze)|fetc|fly(\-|_)|g1 u|g560|gene|gf\-5|g\-mo|go(\.w|od)|gr(ad|un)|haie|hcit|hd\-(m|p|t)|hei\-|hi(pt|ta)|hp( i|ip)|hs\-c|ht(c(\-| |_|a|g|p|s|t)|tp)|hu(aw|tc)|i\-(20|go|ma)|i230|iac( |\-|\/)|ibro|idea|ig01|ikom|im1k|inno|ipaq|iris|ja(t|v)a|jbro|jemu|jigs|kddi|keji|kgt( |\/)|klon|kpt |kwc\-|kyo(c|k)|le(no|xi)|lg( g|\/(k|l|u)|50|54|\-[a-w])|libw|lynx|m1\-w|m3ga|m50\/|ma(te|ui|xo)|mc(01|21|ca)|m\-cr|me(rc|ri)|mi(o8|oa|ts)|mmef|mo(01|02|bi|de|do|t(\-| |o|v)|zz)|mt(50|p1|v )|mwbp|mywa|n10[0-2]|n20[2-3]|n30(0|2)|n50(0|2|5)|n7(0(0|1)|10)|ne((c|m)\-|on|tf|wf|wg|wt)|nok(6|i)|nzph|o2im|op(ti|wv)|oran|owg1|p800|pan(a|d|t)|pdxg|pg(13|\-([1-8]|c))|phil|pire|pl(ay|uc)|pn\-2|po(ck|rt|se)|prox|psio|pt\-g|qa\-a|qc(07|12|21|32|60|\-[2-7]|i\-)|qtek|r380|r600|raks|rim9|ro(ve|zo)|s55\/|sa(ge|ma|mm|ms|ny|va)|sc(01|h\-|oo|p\-)|sdk\/|se(c(\-|0|1)|47|mc|nd|ri)|sgh\-|shar|sie(\-|m)|sk\-0|sl(45|id)|sm(al|ar|b3|it|t5)|so(ft|ny)|sp(01|h\-|v\-|v )|sy(01|mb)|t2(18|50)|t6(00|10|18)|ta(gt|lk)|tcl\-|tdg\-|tel(i|m)|tim\-|t\-mo|to(pl|sh)|ts(70|m\-|m3|m5)|tx\-9|up(\.b|g1|si)|utst|v400|v750|veri|vi(rg|te)|vk(40|5[0-3]|\-v)|vm40|voda|vulc|vx(52|53|60|61|70|80|81|83|85|98)|w3c(\-| )|webc|whit|wi(g |nc|nw)|wmlb|wonu|x700|yas\-|your|zeto|zte\-/i.test(ua.slice(0, 4));
}
/**
 * True in a window with a document, or wherever WorkerGlobalScope is defined
 * (i.e. inside a web worker).
 */
function isBrowser() {
  const hasDocument = typeof window !== "undefined" && window.document != null;
  return hasDocument || typeof WorkerGlobalScope !== "undefined";
}
// Namespace object (null prototype) bundling the device-detection helpers isMobile and isBrowser.
var device_util = {
__proto__: null,
isMobile,
isBrowser
};
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
// Core environment flags. Each registerFlag call supplies a lazy evaluator for the
// flag's default value and, for DEBUG, a callback run when the value is set.
var ENV = env();
// DEBUG: off by default; enabling it prints a performance warning.
ENV.registerFlag("DEBUG", function() {
return false;
}, function(debugValue) {
if (debugValue) {
console.warn("Debugging mode is ON. The output of every math call will be downloaded to CPU and checked for NaNs. This significantly impacts performance.");
}
});
// IS_BROWSER: delegates to isBrowser() (window with document, or a web worker).
ENV.registerFlag("IS_BROWSER", function() {
return isBrowser();
});
// IS_NODE: true when process.versions.node is present.
ENV.registerFlag("IS_NODE", function() {
return typeof process !== "undefined" && typeof process.versions !== "undefined" && typeof process.versions.node !== "undefined";
});
// IS_CHROME: user agent mentions "Chrome" and vendor mentions "Google Inc".
ENV.registerFlag("IS_CHROME", function() {
return typeof navigator !== "undefined" && navigator != null && navigator.userAgent != null && /Chrome/.test(navigator.userAgent) && /Google Inc/.test(navigator.vendor);
});
ENV.registerFlag("PROD", function() {
return false;
});
// Deep shape-consistency checks in inferShape() follow the DEBUG flag by default.
ENV.registerFlag("TENSORLIKE_CHECK_SHAPE_CONSISTENCY", function() {
return ENV.getBool("DEBUG");
});
ENV.registerFlag("DEPRECATION_WARNINGS_ENABLED", function() {
return true;
});
ENV.registerFlag("IS_TEST", function() {
return false;
});
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Infers a tensor shape from a (possibly nested) array, a TypedArray or a scalar.
 *
 * The shape is read by following the first element at each nesting level. TypedArrays
 * count as a nesting level except for "string" dtype, where a top-level TypedArray yields
 * shape []. When the TENSORLIKE_CHECK_SHAPE_CONSISTENCY flag is on, nested arrays are
 * also validated against the inferred shape.
 *
 * @returns The inferred shape; [] for scalars and other non-array values.
 */
function inferShape(val, dtype) {
  if (isTypedArray(val)) {
    return dtype === "string" ? [] : [val.length];
  }
  if (!Array.isArray(val)) {
    return [];
  }
  const descends = (node) => Array.isArray(node) || (isTypedArray(node) && dtype !== "string");
  const shape = [];
  for (let level = val; descends(level); level = level[0]) {
    shape.push(level.length);
  }
  // `val` is known to be an array here, so only the flag needs checking.
  if (env().getBool("TENSORLIKE_CHECK_SHAPE_CONSISTENCY")) {
    deepAssertShapeConsistency(val, shape, []);
  }
  return shape;
}
/**
 * Recursively asserts that a nested array/TypedArray exactly matches `shape`.
 *
 * @param val Value (or sub-value) being checked.
 * @param shape Expected shape of `val`; [] means `val` must be a primitive.
 * @param indices Path to `val` inside the root array, used in error messages.
 */
function deepAssertShapeConsistency(val, shape, indices) {
  const path = indices || [];
  // Built lazily: the message is only needed when an assertion fails.
  const at = () => `Element arr[${path.join("][")}]`;
  const isContainer = Array.isArray(val) || isTypedArray(val);
  if (!isContainer) {
    assert(shape.length === 0, () => `${at()} is a primitive, but should be an array/TypedArray of ${shape[0]} elements`);
    return;
  }
  assert(shape.length > 0, () => `${at()} should be a primitive, but is an array of ${val.length} elements`);
  assert(val.length === shape[0], () => `${at()} should have ${shape[0]} elements, but has ${val.length} elements`);
  const innerShape = shape.slice(1);
  for (let k = 0; k < val.length; ++k) {
    deepAssertShapeConsistency(val[k], innerShape, path.concat(k));
  }
}
/**
 * Throws when `actualDType` does not satisfy `expectedDtype`.
 *
 * "numeric" accepts any dtype except "string". Any other expected dtype must match
 * exactly. A null/undefined expectation disables the check.
 */
function assertDtype(expectedDtype, actualDType, argName, functionName) {
  if (expectedDtype == null) {
    return;
  }
  const mismatch = expectedDtype === "numeric" ? actualDType === "string" : expectedDtype !== actualDType;
  if (mismatch) {
    throw new Error(`Argument '${argName}' passed to '${functionName}' must be ${expectedDtype} tensor, but got ${actualDType} tensor`);
  }
}
/**
 * Converts a Tensor or TensorLike (number, boolean, string, nested array or TypedArray)
 * into a Tensor.
 *
 * Tensors are returned as-is after a dtype check. For other inputs the dtype is inferred,
 * and is overridden by `parseAsDtype` when that is bool/int32/float32 and the input is
 * not string data.
 *
 * @param parseAsDtype Expected dtype; defaults to "numeric" (anything but string).
 * @throws Error on dtype mismatch or when `x` is not tensor-like.
 */
function convertToTensor(x, argName, functionName, parseAsDtype) {
  const targetDtype = parseAsDtype === void 0 ? "numeric" : parseAsDtype;
  if (x instanceof Tensor) {
    assertDtype(targetDtype, x.dtype, argName, functionName);
    return x;
  }
  const detected = inferDtype(x);
  const coerce = detected !== "string" && ["bool", "int32", "float32"].indexOf(targetDtype) >= 0;
  const dtype = coerce ? targetDtype : detected;
  assertDtype(targetDtype, dtype, argName, functionName);
  const isTensorLike = x != null && (isTypedArray(x) || Array.isArray(x) || typeof x === "number" || typeof x === "boolean" || typeof x === "string");
  if (!isTensorLike) {
    const type = x == null ? "null" : x.constructor.name;
    throw new Error(`Argument '${argName}' passed to '${functionName}' must be a Tensor or TensorLike, but got '${type}'`);
  }
  const shape = inferShape(x, dtype);
  const container = isTypedArray(x) || Array.isArray(x) ? x : [x];
  const values = dtype === "string" ? flatten(container, [], true) : toTypedArray(container, dtype);
  return ENGINE.makeTensor(values, shape, dtype);
}
/**
 * Converts every element of `arg` with convertToTensor, naming each one `argName[i]` in
 * error messages.
 *
 * @param arg Array of Tensors or TensorLikes.
 * @param argName Argument name used in error messages.
 * @param functionName Calling op name used in error messages.
 * @param parseAsDtype Dtype forwarded to convertToTensor for every element; defaults to "numeric".
 * @returns Array of Tensors, in input order.
 * @throws Error when `arg` is not an array, or when any element fails conversion.
 */
function convertToTensorArray(arg, argName, functionName, parseAsDtype) {
  if (parseAsDtype === void 0) {
    parseAsDtype = "numeric";
  }
  if (!Array.isArray(arg)) {
    throw new Error("Argument " + argName + " passed to " + functionName + " must be a `Tensor[]` or `TensorLike[]`");
  }
  // parseAsDtype must reach convertToTensor. It used to be passed as Array#map's thisArg,
  // so it was silently ignored and every element was parsed as "numeric".
  return arg.map(function(t, i) {
    return convertToTensor(t, argName + "[" + i + "]", functionName, parseAsDtype);
  });
}
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
// Appended to every op name to form the engine scope name used while the op runs.
var OP_SCOPE_SUFFIX = "__op";
/**
 * Wraps a single op implementation so that each call runs inside its own engine scope.
 *
 * @param f Object with exactly one key mapping an op name to its implementation. A
 *   trailing "_" on the key is dropped before OP_SCOPE_SUFFIX is appended.
 * @returns A wrapper whose `name` is the scope name. Each call does the following:
 *   - opens the scope;
 *   - logs an error if the implementation returned a Promise;
 *   - closes the scope with the result, or with null when the implementation throws
 *     (the error is rethrown).
 * @throws Error when `f` does not have exactly one key.
 */
function op(f) {
  const keys = Object.keys(f);
  if (keys.length !== 1) {
    throw new Error(`Please provide an object with a single key (operation name) mapping to a function. Got an object with ${keys.length} keys.`);
  }
  const [rawName] = keys;
  const impl = f[rawName];
  const baseName = rawName.endsWith("_") ? rawName.slice(0, -1) : rawName;
  const scopeName = baseName + OP_SCOPE_SUFFIX;
  const wrapped = function(...args) {
    ENGINE.startScope(scopeName);
    try {
      const result = impl(...args);
      if (isPromise(result)) {
        console.error("Cannot return a Promise inside of tidy.");
      }
      ENGINE.endScope(result);
      return result;
    } catch (err) {
      ENGINE.endScope(null);
      throw err;
    }
  };
  Object.defineProperty(wrapped, "name", {value: scopeName, configurable: true});
  return wrapped;
}
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Builds a complex64 tensor from matching real and imaginary parts via the Complex kernel.
 *
 * @param real2 Real part (Tensor or TensorLike).
 * @param imag2 Imaginary part; must have the same shape as `real2`.
 */
function complex_(real2, imag2) {
  const $real = convertToTensor(real2, "real", "complex");
  const $imag = convertToTensor(imag2, "imag", "complex");
  assertShapesMatch($real.shape, $imag.shape, `real and imag shapes, ${$real.shape} and ${$imag.shape}, must match in call to tf.complex().`);
  const inputs = {real: $real, imag: $imag};
  return ENGINE.runKernelFunc((backend2) => backend2.complex($real, $imag), inputs, null, Complex);
}
var complex = op({complex_});
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Validates `values` against an optional explicit shape and creates an engine tensor.
 *
 * @param values Scalar (number/boolean/string), nested array, or TypedArray.
 * @param shape Optional explicit shape. Its size must equal the inferred size, and each
 *   axis must match the inferred one, except that the last inferred axis may stand for the
 *   flattened remainder of `shape`.
 * @param inferredShape Shape inferred from `values`.
 * @param dtype Optional dtype; inferred from `values` when null/undefined.
 * @throws Error for complex64 (use tf.complex) or unsupported value types; shape
 *   mismatches are reported through assert().
 */
function makeTensor(values, shape, inferredShape, dtype) {
  const dt = dtype == null ? inferDtype(values) : dtype;
  if (dt === "complex64") {
    throw new Error("Cannot construct a complex64 tensor directly. Please use tf.complex(real, imag).");
  }
  const isContainer = isTypedArray(values) || Array.isArray(values);
  const isScalarLike = typeof values === "number" || typeof values === "boolean" || typeof values === "string";
  if (!isContainer && !isScalarLike) {
    throw new Error("values passed to tensor(values) must be a number/boolean/string or an array of numbers/booleans/strings, or a TypedArray");
  }
  if (shape != null) {
    assertNonNegativeIntegerDimensions(shape);
    const providedSize = sizeFromShape(shape);
    const inferredSize = sizeFromShape(inferredShape);
    assert(providedSize === inferredSize, () => `Based on the provided shape, [${shape}], the tensor should have ${providedSize} values but has ${inferredSize}`);
    const lastAxis = inferredShape.length - 1;
    for (let axis = 0; axis < inferredShape.length; ++axis) {
      const dim = inferredShape[axis];
      const flatDimsDontMatch = axis === lastAxis ? dim !== sizeFromShape(shape.slice(axis)) : true;
      assert(dim === shape[axis] || !flatDimsDontMatch, () => `Error creating a new Tensor. Inferred shape (${inferredShape}) does not match the provided shape (${shape}). `);
    }
  }
  const source = isContainer ? values : [values];
  const finalShape = shape || inferredShape;
  const data = dt !== "string" ? toTypedArray(source, dt) : flatten(source, [], true);
  return ENGINE.makeTensor(data, finalShape, dt);
}
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Creates a tensor from `values`. The shape is inferred with inferShape and checked
 * against the optional explicit `shape` by makeTensor.
 */
function tensor(values, shape, dtype) {
  return makeTensor(values, shape, inferShape(values, dtype), dtype);
}
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
// Bytes per element for each dtype, used to size slices of serialized weight buffers
// (see decodeWeights). complex64 stores two float32 values (real, imaginary) per element.
var DTYPE_VALUE_SIZE_MAP = {
float32: 4,
float16: 2,
int32: 4,
uint16: 2,
uint8: 1,
bool: 1,
complex64: 8
};
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
// Size in bytes of the Uint32 length prefix written before each string element in
// serialized weight data (see encodeWeights / decodeWeights).
var NUM_BYTES_STRING_LENGTH = 4;
/**
 * Serializes tensors into one concatenated binary buffer plus a list of weight specs.
 *
 * @param tensors Either an array of `{name, tensor}` entries or an object mapping weight
 *   names to tensors.
 * @param group Optional group name copied onto every spec when not null/undefined.
 * @returns Promise resolving to `{data, specs}`:
 *   - `data` is concatenateTypedArrays() of every tensor's values, in spec order.
 *   - each spec is `{name, shape, dtype[, group]}`.
 *   String tensors are written element by element: a NUM_BYTES_STRING_LENGTH-byte Uint32
 *   length (platform byte order) followed by that element's bytes.
 * Errors: a dtype other than float32/int32/bool/string/complex64 raises
 *   "Unsupported dtype in weight ..." from inside the __awaiter body (presumably surfacing
 *   as a rejection of the returned promise).
 */
function encodeWeights(tensors, group) {
  return __awaiter(this, void 0, void 0, function() {
    var specs, dataPromises, names, _loop_1, i, tensorValues;
    var _this = this;
    return __generator(this, function(_a) {
      switch (_a.label) {
        case 0:
          specs = [];
          dataPromises = [];
          names = Array.isArray(tensors) ? tensors.map(function(tensor2) {
            return tensor2.name;
          }) : Object.keys(tensors);
          _loop_1 = function(i2) {
            var name_1 = names[i2];
            var t = Array.isArray(tensors) ? tensors[i2].tensor : tensors[name_1];
            if (t.dtype !== "float32" && t.dtype !== "int32" && t.dtype !== "bool" && t.dtype !== "string" && t.dtype !== "complex64") {
              throw new Error("Unsupported dtype in weight '" + name_1 + "': " + t.dtype);
            }
            var spec = {name: name_1, shape: t.shape, dtype: t.dtype};
            if (t.dtype === "string") {
              // Use the __awaiter promise itself. The previous `new Promise(function(resolve) {...})`
              // wrapper only ever called resolve(), so if t.bytes() rejected it stayed pending
              // forever and the Promise.all(dataPromises) below never settled.
              var utf8bytes = __awaiter(_this, void 0, void 0, function() {
                var vals, totalNumBytes, bytes, offset, i_1, val, bytesOfLength;
                return __generator(this, function(_a2) {
                  switch (_a2.label) {
                    case 0:
                      return [4, t.bytes()];
                    case 1:
                      vals = _a2.sent();
                      totalNumBytes = vals.reduce(function(p, c) {
                        return p + c.length;
                      }, 0) + NUM_BYTES_STRING_LENGTH * vals.length;
                      bytes = new Uint8Array(totalNumBytes);
                      offset = 0;
                      for (i_1 = 0; i_1 < vals.length; i_1++) {
                        val = vals[i_1];
                        // Length prefix in the platform's native byte order, then the raw bytes.
                        bytesOfLength = new Uint8Array(new Uint32Array([val.length]).buffer);
                        bytes.set(bytesOfLength, offset);
                        offset += NUM_BYTES_STRING_LENGTH;
                        bytes.set(val, offset);
                        offset += val.length;
                      }
                      // [2, value] returns `bytes`, which becomes the resolution value of utf8bytes.
                      return [2, bytes];
                  }
                });
              });
              dataPromises.push(utf8bytes);
            } else {
              dataPromises.push(t.data());
            }
            if (group != null) {
              spec.group = group;
            }
            specs.push(spec);
          };
          for (i = 0; i < names.length; ++i) {
            _loop_1(i);
          }
          return [4, Promise.all(dataPromises)];
        case 1:
          tensorValues = _a.sent();
          return [2, {data: concatenateTypedArrays(tensorValues), specs}];
      }
    });
  });
}
/**
 * Decodes a concatenated weight buffer into a map of tensors, one per spec.
 *
 * Specs are consumed in order, each advancing a running byte offset into `buffer2`. Three
 * cases are handled:
 * - Quantized specs (uint8/uint16 with `min`/`scale`, or float16) are dequantized to
 *   float32; uint8/uint16 can also be dequantized to int32 (rounded).
 * - String specs are read as repeated (Uint32 length prefix, bytes) pairs.
 * - complex64 values are stored interleaved (real, imag) and rebuilt with complex().
 *
 * @param buffer2 ArrayBuffer holding the weight data.
 * @param specs Weight specs `{name, shape, dtype[, quantization]}`, in buffer order.
 * @returns Object mapping each spec name to its decoded tensor.
 * @throws Error for unknown quantization dtypes, missing min/scale metadata, or
 *   unsupported dtype/quantization combinations.
 */
function decodeWeights(buffer2, specs) {
var out = {};
// Lazily built on the first float16-quantized weight, then reused.
var float16Decode;
var offset = 0;
for (var _i2 = 0, specs_1 = specs; _i2 < specs_1.length; _i2++) {
var spec = specs_1[_i2];
var name_2 = spec.name;
var dtype = spec.dtype;
var shape = spec.shape;
var size = sizeFromShape(shape);
var values = void 0;
if ("quantization" in spec) {
var quantization = spec.quantization;
// Validate the quantization metadata before reading any bytes.
if (quantization.dtype === "uint8" || quantization.dtype === "uint16") {
if (!("min" in quantization && "scale" in quantization)) {
throw new Error("Weight " + spec.name + " with quantization " + quantization.dtype + " doesn't have corresponding metadata min and scale.");
}
} else if (quantization.dtype === "float16") {
if (dtype !== "float32") {
throw new Error("Weight " + spec.name + " is quantized with " + quantization.dtype + " " + ("which only supports weights of type float32 not " + dtype + "."));
}
} else {
throw new Error("Weight " + spec.name + " has unknown " + ("quantization dtype " + quantization.dtype + ". ") + "Supported quantization dtypes are: 'uint8', 'uint16', and 'float16'.");
}
// The stored element size comes from the quantized dtype, not the target dtype.
var quantizationSizeFactor = DTYPE_VALUE_SIZE_MAP[quantization.dtype];
var byteBuffer = buffer2.slice(offset, offset + size * quantizationSizeFactor);
var quantizedArray = quantization.dtype === "uint8" ? new Uint8Array(byteBuffer) : new Uint16Array(byteBuffer);
if (dtype === "float32") {
if (quantization.dtype === "uint8" || quantization.dtype === "uint16") {
// Affine dequantization: value = q * scale + min.
values = new Float32Array(quantizedArray.length);
for (var i = 0; i < quantizedArray.length; i++) {
var v = quantizedArray[i];
values[i] = v * quantization.scale + quantization.min;
}
} else if (quantization.dtype === "float16") {
if (float16Decode === void 0) {
float16Decode = getFloat16Decoder();
}
values = float16Decode(quantizedArray);
} else {
throw new Error("Unsupported quantization type " + quantization.dtype + " for weight type float32.");
}
} else if (dtype === "int32") {
if (quantization.dtype !== "uint8" && quantization.dtype !== "uint16") {
throw new Error("Unsupported quantization type " + quantization.dtype + " for weight type int32.");
}
values = new Int32Array(quantizedArray.length);
for (var i = 0; i < quantizedArray.length; i++) {
var v = quantizedArray[i];
values[i] = Math.round(v * quantization.scale + quantization.min);
}
} else {
throw new Error("Unsupported dtype in weight '" + name_2 + "': " + dtype);
}
offset += size * quantizationSizeFactor;
} else if (dtype === "string") {
// Each element: Uint32 byte length (platform byte order), then that many bytes.
var size_1 = sizeFromShape(spec.shape);
values = [];
for (var i = 0; i < size_1; i++) {
var byteLength = new Uint32Array(buffer2.slice(offset, offset + NUM_BYTES_STRING_LENGTH))[0];
offset += NUM_BYTES_STRING_LENGTH;
var bytes = new Uint8Array(buffer2.slice(offset, offset + byteLength));
values.push(bytes);
offset += byteLength;
}
} else {
var dtypeFactor = DTYPE_VALUE_SIZE_MAP[dtype];
var byteBuffer = buffer2.slice(offset, offset + size * dtypeFactor);
if (dtype === "float32") {
values = new Float32Array(byteBuffer);
} else if (dtype === "int32") {
values = new Int32Array(byteBuffer);
} else if (dtype === "bool") {
values = new Uint8Array(byteBuffer);
} else if (dtype === "complex64") {
// Interleaved (real, imag) float32 pairs; split them and build the complex tensor here.
values = new Float32Array(byteBuffer);
var real2 = new Float32Array(values.length / 2);
var image2 = new Float32Array(values.length / 2);
for (var i = 0; i < real2.length; i++) {
real2[i] = values[i * 2];
image2[i] = values[i * 2 + 1];
}
var realTensor = tensor(real2, shape, "float32");
var imageTensor = tensor(image2, shape, "float32");
out[name_2] = complex(realTensor, imageTensor);
realTensor.dispose();
imageTensor.dispose();
} else {
throw new Error("Unsupported dtype in weight '" + name_2 + "': " + dtype);
}
offset += size * dtypeFactor;
}
// complex64 outputs were already stored above.
if (dtype !== "complex64") {
out[name_2] = tensor(values, shape, dtype);
}
}
return out;
}
/**
 * Concatenates the raw bytes of several typed arrays into one new ArrayBuffer.
 *
 * Only the bytes each array actually views (byteOffset .. byteOffset + byteLength) are
 * copied, so subarray views over a larger buffer contribute just their own bytes.
 *
 * @param xs Array of Float32Array / Int32Array / Uint8Array instances.
 * @returns A new ArrayBuffer whose length is the sum of the inputs' byteLengths.
 * @throws Error when `xs` is null/undefined, or when it contains any other kind of value.
 */
function concatenateTypedArrays(xs) {
  // `== null` also rejects undefined, which previously crashed with a TypeError on forEach.
  if (xs == null) {
    throw new Error("Invalid input value: " + JSON.stringify(xs));
  }
  // Validate every element up front, before any bytes are allocated or copied.
  xs.forEach((x) => {
    if (!(x instanceof Float32Array || x instanceof Int32Array || x instanceof Uint8Array)) {
      throw new Error("Unsupported TypedArray subtype: " + x.constructor.name);
    }
  });
  const totalByteLength = xs.reduce((sum, x) => sum + x.byteLength, 0);
  const y = new Uint8Array(totalByteLength);
  let offset = 0;
  xs.forEach((x) => {
    // A byte view of exactly x's window; avoids the temporary copy subarrays used to need.
    y.set(new Uint8Array(x.buffer, x.byteOffset, x.byteLength), offset);
    offset += x.byteLength;
  });
  return y.buffer;
}
// Use Node's Buffer only when it exists and at least one of the browser APIs
// (Blob, atob, btoa) is missing.
var useNodeBuffer = typeof Buffer !== "undefined" && !(typeof Blob !== "undefined" && typeof atob !== "undefined" && typeof btoa !== "undefined");
/**
 * Number of bytes `str` occupies when encoded (UTF-8 for both Buffer.byteLength's
 * default encoding and Blob string parts).
 */
function stringByteLength(str) {
  return useNodeBuffer ? Buffer.byteLength(str) : new Blob([str]).size;
}
/**
 * Encodes the bytes of `buffer2` as a base64 string, using Buffer in Node-style
 * environments and btoa over a binary string otherwise.
 */
function arrayBufferToBase64String(buffer2) {
  if (useNodeBuffer) {
    return Buffer.from(buffer2).toString("base64");
  }
  let binary = "";
  new Uint8Array(buffer2).forEach((byte) => {
    binary += String.fromCharCode(byte);
  });
  return btoa(binary);
}
/**
 * Decodes a base64 string into a new ArrayBuffer holding exactly the decoded bytes.
 *
 * @param str Base64-encoded data.
 * @returns ArrayBuffer whose byteLength equals the number of decoded bytes.
 */
function base64StringToArrayBuffer(str) {
  if (useNodeBuffer) {
    const buf = Buffer.from(str, "base64");
    // Small Buffers may be views into a shared pool, so copy out only this Buffer's bytes.
    return buf.buffer.slice(buf.byteOffset, buf.byteOffset + buf.byteLength);
  }
  const binary = atob(str);
  const bytes = new Uint8Array(binary.length);
  for (let i = 0; i < binary.length; ++i) {
    // Direct element write. The previous `set([code], i)` allocated a temporary array per byte.
    bytes[i] = binary.charCodeAt(i);
  }
  return bytes.buffer;
}
/**
 * Joins ArrayBuffers end to end into a new ArrayBuffer.
 * A single-element list is returned as that same buffer (no copy).
 */
function concatenateArrayBuffers(buffers) {
  if (buffers.length === 1) {
    return buffers[0];
  }
  const total = buffers.reduce((sum, buf) => sum + buf.byteLength, 0);
  const joined = new Uint8Array(total);
  let cursor = 0;
  buffers.forEach((buf) => {
    joined.set(new Uint8Array(buf), cursor);
    cursor += buf.byteLength;
  });
  return joined.buffer;
}
/**
 * Returns the last "/"-separated component of `path`, after trimming whitespace and
 * dropping any trailing separators ("a/b/" -> "b").
 */
function basename(path) {
  const stripped = path.trim().replace(/\/+$/, "");
  return stripped.slice(stripped.lastIndexOf("/") + 1);
}
/**
 * Summarizes JSON model artifacts: the save date, the topology type, and the byte sizes of
 * the serialized topology, weight specs and weight data (0 for any missing part).
 *
 * @throws Error when the model topology is an ArrayBuffer rather than JSON.
 */
function getModelArtifactsInfoForJSON(modelArtifacts) {
  if (modelArtifacts.modelTopology instanceof ArrayBuffer) {
    throw new Error("Expected JSON model topology, received ArrayBuffer.");
  }
  const jsonBytes = (value) => value == null ? 0 : stringByteLength(JSON.stringify(value));
  const weightData = modelArtifacts.weightData;
  return {
    dateSaved: new Date(),
    modelTopologyType: "JSON",
    modelTopologyBytes: jsonBytes(modelArtifacts.modelTopology),
    weightSpecsBytes: jsonBytes(modelArtifacts.weightSpecs),
    weightDataBytes: weightData == null ? 0 : weightData.byteLength
  };
}
/**
 * Builds the 2048-entry mantissa lookup table for float16 -> float32 decoding.
 *
 * The table is laid out as follows:
 * - Entry 0 is 0.
 * - Entries 1..1023 hold subnormal half mantissas, normalized into float32
 *   mantissa + exponent bits.
 * - Entries 1024..2047 hold normal mantissas shifted into float32 position on top of a
 *   0x38000000 exponent bias.
 */
function computeFloat16MantisaTable() {
  const convertMantissa = (index) => {
    let mantissa = index << 13;
    let exponent = 0;
    // Shift until the implicit leading bit (bit 23) is set, decrementing the exponent each time.
    while ((mantissa & 0x800000) === 0) {
      exponent -= 0x800000;
      mantissa <<= 1;
    }
    mantissa &= ~0x800000;
    exponent += 0x38800000; // 947912704
    return mantissa | exponent;
  };
  const table = new Uint32Array(2048);
  table[0] = 0;
  for (let index = 1; index < 1024; index++) {
    table[index] = convertMantissa(index);
  }
  for (let index = 1024; index < 2048; index++) {
    table[index] = 0x38000000 + ((index - 1024) << 13); // 0x38000000 === 939524096
  }
  return table;
}
/**
 * Builds the 64-entry exponent lookup table for float16 -> float32 decoding, indexed by
 * the half's top 6 bits (sign + exponent).
 *
 * The table is laid out as follows:
 * - Indices 1..30 hold positive normal exponents; 33..62 hold the same values with the
 *   sign bit set.
 * - 0 and 32 are the (signed) zero/subnormal entries.
 * - 31 and 63 are the (signed) Inf/NaN entries.
 */
function computeFloat16ExponentTable() {
  const table = new Uint32Array(64);
  for (let e = 1; e < 31; e++) {
    table[e] = e << 23;
    table[e + 32] = 0x80000000 + (e << 23);
  }
  table[0] = 0;
  table[31] = 0x47800000; // 1199570944
  table[32] = 0x80000000; // 2147483648
  table[63] = 0xC7800000; // 3347054592
  return table;
}
/**
 * Builds the 64-entry offset table for float16 decoding. Every entry is 1024 (the
 * normal-mantissa half of the mantissa table), except 0 and 32, the zero/subnormal
 * exponents for each sign.
 */
function computeFloat16OffsetTable() {
  const table = new Uint32Array(64).fill(1024);
  table[0] = 0;
  table[32] = 0;
  return table;
}
/**
 * Returns a function that converts an array of float16 bit patterns into a Float32Array,
 * using precomputed mantissa/exponent/offset lookup tables (built once per decoder).
 */
function getFloat16Decoder() {
  const mantissas = computeFloat16MantisaTable();
  const exponents = computeFloat16ExponentTable();
  const offsets = computeFloat16OffsetTable();
  return (quantizedArray) => {
    const bits = new Uint32Array(quantizedArray.length);
    for (let k = 0; k < quantizedArray.length; k++) {
      const half = quantizedArray[k];
      const top = half >> 10; // sign + exponent bits
      bits[k] = mantissas[offsets[top] + (half & 1023)] + exponents[top];
    }
    return new Float32Array(bits.buffer);
  };
}
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Lazily created singleton that holds the registered IO routers.
 *
 * A router is a function `(url, loadOptions) => handler | null`. getHandlers
 * calls every router of the requested kind and keeps each result that is not
 * strictly `null`.
 */
var IORouterRegistry = class IORouterRegistry2 {
  constructor() {
    this.saveRouters = [];
    this.loadRouters = [];
  }
  /** Returns the shared instance, creating it on first use. */
  static getInstance() {
    if (IORouterRegistry2.instance == null) {
      IORouterRegistry2.instance = new IORouterRegistry2();
    }
    return IORouterRegistry2.instance;
  }
  /** Appends a router consulted by getSaveHandlers. */
  static registerSaveRouter(saveRouter) {
    IORouterRegistry2.getInstance().saveRouters.push(saveRouter);
  }
  /** Appends a router consulted by getLoadHandlers. */
  static registerLoadRouter(loadRouter) {
    IORouterRegistry2.getInstance().loadRouters.push(loadRouter);
  }
  static getSaveHandlers(url) {
    return IORouterRegistry2.getHandlers(url, "save");
  }
  static getLoadHandlers(url, loadOptions) {
    return IORouterRegistry2.getHandlers(url, "load", loadOptions);
  }
  /**
   * @param {string|Array} url
   * @param {string} handlerType "load" selects load routers; anything else selects save routers.
   * @param {Object=} loadOptions Forwarded to every router.
   * @returns {Array} Non-null handlers, in router registration order.
   */
  static getHandlers(url, handlerType, loadOptions) {
    const registry = IORouterRegistry2.getInstance();
    const routers = handlerType === "load" ? registry.loadRouters : registry.saveRouters;
    const matches = [];
    routers.forEach((router) => {
      const candidate = router(url, loadOptions);
      if (candidate !== null) {
        matches.push(candidate);
      }
    });
    return matches;
  }
};
/** Registers a save router with the global IORouterRegistry. */
var registerSaveRouter = (saveRouter) => IORouterRegistry.registerSaveRouter(saveRouter);
/** Registers a load router with the global IORouterRegistry. */
var registerLoadRouter = (loadRouter) => IORouterRegistry.registerLoadRouter(loadRouter);
/** Returns every save handler whose router accepts `url`. */
var getSaveHandlers = (url) => IORouterRegistry.getSaveHandlers(url);
/** Returns every load handler whose router accepts `url` (with `loadOptions`). */
var getLoadHandlers = (url, loadOptions) => IORouterRegistry.getLoadHandlers(url, loadOptions);
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var DATABASE_NAME = "tensorflowjs";
var DATABASE_VERSION = 1;
var MODEL_STORE_NAME = "models_store";
var INFO_STORE_NAME = "model_info_store";
/**
 * Returns the browser's IndexedDB factory, trying the standard global first
 * and then the vendor-prefixed and shim variants.
 *
 * @returns {IDBFactory}
 * @throws {Error} If the environment is not a browser, or no factory is found.
 */
function getIndexedDBFactory() {
  if (!env().getBool("IS_BROWSER")) {
    throw new Error("Failed to obtain IndexedDB factory because the current environment is not a web browser.");
  }
  // When `window` is absent (e.g. inside a Web Worker), look on `self` instead.
  var theWindow = typeof window === "undefined" ? self : window;
  var factory = theWindow.indexedDB || theWindow.mozIndexedDB || theWindow.webkitIndexedDB || theWindow.msIndexedDB || theWindow.shimIndexedDB;
  if (factory == null) {
    throw new Error("The current browser does not appear to support IndexedDB.");
  }
  return factory;
}
/**
 * `onupgradeneeded` handler: creates the model store and the model-info store.
 * Both are keyed by the record's `modelPath` property.
 *
 * @param {IDBOpenDBRequest} openRequest The open request being upgraded.
 */
function setUpDatabase(openRequest) {
  var db = openRequest.result;
  db.createObjectStore(MODEL_STORE_NAME, {keyPath: "modelPath"});
  db.createObjectStore(INFO_STORE_NAME, {keyPath: "modelPath"});
}
var BrowserIndexedDB = function() {
  /**
   * IOHandler that persists model artifacts in the browser's IndexedDB.
   *
   * A save writes two records keyed by `modelPath`:
   *   - a summary record ({modelPath, modelArtifactsInfo}) in INFO_STORE_NAME;
   *   - the full record ({modelPath, modelArtifacts, modelArtifactsInfo}) in MODEL_STORE_NAME.
   *
   * @param {string} modelPath Non-empty key under which the model is stored.
   * @throws {Error} If IndexedDB is unavailable or modelPath is empty.
   */
  function BrowserIndexedDB2(modelPath) {
    this.indexedDB = getIndexedDBFactory();
    if (modelPath == null || !modelPath) {
      throw new Error("For IndexedDB, modelPath must not be null, undefined or empty.");
    }
    this.modelPath = modelPath;
  }
  /**
   * Saves JSON-topology model artifacts.
   * @returns {Promise<{modelArtifactsInfo: Object}>}
   * @throws {Error} (as a rejection) for binary (ArrayBuffer) topologies.
   */
  BrowserIndexedDB2.prototype.save = function(modelArtifacts) {
    return __awaiter(this, void 0, void 0, function() {
      return __generator(this, function(_a) {
        if (modelArtifacts.modelTopology instanceof ArrayBuffer) {
          throw new Error("BrowserIndexedDB.save() does not support saving model topology in binary formats yet.");
        }
        return [2, this.databaseAction(this.modelPath, modelArtifacts)];
      });
    });
  };
  /**
   * Loads the stored model artifacts.
   * @returns {Promise<Object>} Rejects if no model exists at modelPath.
   */
  BrowserIndexedDB2.prototype.load = function() {
    return __awaiter(this, void 0, void 0, function() {
      return __generator(this, function(_a) {
        return [2, this.databaseAction(this.modelPath)];
      });
    });
  };
  /**
   * Opens the database, then reads (when modelArtifacts is null/undefined) or
   * writes the model stored under `modelPath`. Closes the connection when the
   * work finishes or fails.
   *
   * @param {string} modelPath Key of the model record.
   * @param {Object=} modelArtifacts Artifacts to write; omit to read.
   * @returns {Promise} Read: the stored modelArtifacts. Write: {modelArtifactsInfo}.
   */
  BrowserIndexedDB2.prototype.databaseAction = function(modelPath, modelArtifacts) {
    var _this = this;
    return new Promise(function(resolve, reject) {
      var openRequest = _this.indexedDB.open(DATABASE_NAME, DATABASE_VERSION);
      openRequest.onupgradeneeded = function() {
        return setUpDatabase(openRequest);
      };
      openRequest.onsuccess = function() {
        var db = openRequest.result;
        if (modelArtifacts == null) {
          var readTx = db.transaction(MODEL_STORE_NAME, "readonly");
          var readStore = readTx.objectStore(MODEL_STORE_NAME);
          var getRequest = readStore.get(modelPath);
          getRequest.onsuccess = function() {
            if (getRequest.result == null) {
              db.close();
              return reject(new Error("Cannot find model with path '" + modelPath + "' in IndexedDB."));
            }
            resolve(getRequest.result.modelArtifacts);
          };
          getRequest.onerror = function() {
            db.close();
            return reject(getRequest.error);
          };
          readTx.oncomplete = function() {
            return db.close();
          };
        } else {
          var modelArtifactsInfo = getModelArtifactsInfoForJSON(modelArtifacts);
          var infoTx = db.transaction(INFO_STORE_NAME, "readwrite");
          var infoStore = infoTx.objectStore(INFO_STORE_NAME);
          var putInfoRequest = infoStore.put({modelPath: modelPath, modelArtifactsInfo: modelArtifactsInfo});
          var writeModelTx;
          // The model record is written only after the info record succeeds.
          putInfoRequest.onsuccess = function() {
            writeModelTx = db.transaction(MODEL_STORE_NAME, "readwrite");
            var modelStore = writeModelTx.objectStore(MODEL_STORE_NAME);
            var putModelRequest = modelStore.put({
              modelPath: modelPath,
              modelArtifacts: modelArtifacts,
              modelArtifactsInfo: modelArtifactsInfo
            });
            putModelRequest.onsuccess = function() {
              return resolve({modelArtifactsInfo: modelArtifactsInfo});
            };
            putModelRequest.onerror = function() {
              var rejectWithPutError = function() {
                db.close();
                return reject(putModelRequest.error);
              };
              // Roll back the info record so no orphan summary remains. infoTx has
              // already finished by now (this request belongs to a later
              // transaction), so its store can no longer be used. Use a fresh
              // transaction instead; if even that fails, still settle the promise.
              var deleteInfoRequest;
              try {
                deleteInfoRequest = db.transaction(INFO_STORE_NAME, "readwrite").objectStore(INFO_STORE_NAME).delete(modelPath);
              } catch (rollbackError) {
                return rejectWithPutError();
              }
              deleteInfoRequest.onsuccess = rejectWithPutError;
              deleteInfoRequest.onerror = rejectWithPutError;
            };
          };
          putInfoRequest.onerror = function() {
            db.close();
            return reject(putInfoRequest.error);
          };
          infoTx.oncomplete = function() {
            // Close now if the model write never started; otherwise defer
            // closing until the model transaction completes.
            if (writeModelTx == null) {
              db.close();
            } else {
              writeModelTx.oncomplete = function() {
                return db.close();
              };
            }
          };
        }
      };
      openRequest.onerror = function() {
        return reject(openRequest.error);
      };
    });
  };
  BrowserIndexedDB2.URL_SCHEME = "indexeddb://";
  return BrowserIndexedDB2;
}();
/**
 * IO router for "indexeddb://" URLs. It returns a BrowserIndexedDB handler
 * for the path after the scheme. It returns null outside a browser, for array
 * URLs, and for any other scheme.
 */
var indexedDBRouter = (url) => {
  if (!env().getBool("IS_BROWSER") || Array.isArray(url) || !url.startsWith(BrowserIndexedDB.URL_SCHEME)) {
    return null;
  }
  const modelPath = url.slice(BrowserIndexedDB.URL_SCHEME.length);
  return browserIndexedDB(modelPath);
};
IORouterRegistry.registerSaveRouter(indexedDBRouter);
IORouterRegistry.registerLoadRouter(indexedDBRouter);
/**
 * Factory for the IndexedDB IOHandler.
 * @param {string} modelPath Model key, without the "indexeddb://" scheme.
 * @returns {BrowserIndexedDB}
 */
function browserIndexedDB(modelPath) {
  const handler = new BrowserIndexedDB(modelPath);
  return handler;
}
/** Removes a leading "indexeddb://" scheme from `key`, if present. */
function maybeStripScheme(key) {
  const scheme = BrowserIndexedDB.URL_SCHEME;
  if (!key.startsWith(scheme)) {
    return key;
  }
  return key.slice(scheme.length);
}
var BrowserIndexedDBManager = function() {
  /**
   * Model store manager for models saved by BrowserIndexedDB.
   * @throws {Error} From getIndexedDBFactory() when IndexedDB is unavailable.
   */
  function BrowserIndexedDBManager2() {
    this.indexedDB = getIndexedDBFactory();
  }
  /**
   * Lists the stored models.
   * @returns {Promise<Object>} Maps each stored modelPath to its modelArtifactsInfo record.
   */
  BrowserIndexedDBManager2.prototype.listModels = function() {
    return __awaiter(this, void 0, void 0, function() {
      var _this = this;
      return __generator(this, function(_a) {
        return [2, new Promise(function(resolve, reject) {
          var openRequest = _this.indexedDB.open(DATABASE_NAME, DATABASE_VERSION);
          openRequest.onupgradeneeded = function() {
            return setUpDatabase(openRequest);
          };
          openRequest.onsuccess = function() {
            var db = openRequest.result;
            var tx = db.transaction(INFO_STORE_NAME, "readonly");
            var store = tx.objectStore(INFO_STORE_NAME);
            var getAllInfoRequest = store.getAll();
            getAllInfoRequest.onsuccess = function() {
              var out = {};
              for (var _i2 = 0, _a2 = getAllInfoRequest.result; _i2 < _a2.length; _i2++) {
                var item = _a2[_i2];
                out[item.modelPath] = item.modelArtifactsInfo;
              }
              resolve(out);
            };
            getAllInfoRequest.onerror = function(error) {
              db.close();
              return reject(getAllInfoRequest.error);
            };
            tx.oncomplete = function() {
              return db.close();
            };
          };
          openRequest.onerror = function(error) {
            return reject(openRequest.error);
          };
        })];
      });
    });
  };
  /**
   * Deletes the model at `path` (a leading "indexeddb://" is stripped). The
   * info record is deleted first, then the model record.
   *
   * @param {string} path Model path, optionally with the scheme.
   * @returns {Promise<Object>} The removed model's modelArtifactsInfo. Rejects if
   *     no model exists at `path` or a delete request fails; the rejection
   *     carries the failing request's error.
   */
  BrowserIndexedDBManager2.prototype.removeModel = function(path) {
    return __awaiter(this, void 0, void 0, function() {
      var _this = this;
      return __generator(this, function(_a) {
        path = maybeStripScheme(path);
        return [2, new Promise(function(resolve, reject) {
          var openRequest = _this.indexedDB.open(DATABASE_NAME, DATABASE_VERSION);
          openRequest.onupgradeneeded = function() {
            return setUpDatabase(openRequest);
          };
          openRequest.onsuccess = function() {
            var db = openRequest.result;
            var infoTx = db.transaction(INFO_STORE_NAME, "readwrite");
            var infoStore = infoTx.objectStore(INFO_STORE_NAME);
            var getInfoRequest = infoStore.get(path);
            var modelTx;
            getInfoRequest.onsuccess = function() {
              if (getInfoRequest.result == null) {
                db.close();
                return reject(new Error("Cannot find model with path '" + path + "' in IndexedDB."));
              } else {
                var deleteInfoRequest = infoStore.delete(path);
                var deleteModelData_1 = function() {
                  modelTx = db.transaction(MODEL_STORE_NAME, "readwrite");
                  var modelStore = modelTx.objectStore(MODEL_STORE_NAME);
                  var deleteModelRequest = modelStore.delete(path);
                  deleteModelRequest.onsuccess = function() {
                    return resolve(getInfoRequest.result.modelArtifactsInfo);
                  };
                  deleteModelRequest.onerror = function() {
                    // Per the IndexedDB spec an unhandled request error aborts
                    // modelTx, so the deferred oncomplete close never runs.
                    db.close();
                    return reject(deleteModelRequest.error);
                  };
                };
                deleteInfoRequest.onsuccess = deleteModelData_1;
                deleteInfoRequest.onerror = function() {
                  // Still try to remove the model record, then report the info-delete failure.
                  deleteModelData_1();
                  db.close();
                  return reject(deleteInfoRequest.error);
                };
              }
            };
            getInfoRequest.onerror = function() {
              db.close();
              return reject(getInfoRequest.error);
            };
            infoTx.oncomplete = function() {
              if (modelTx == null) {
                db.close();
              } else {
                modelTx.oncomplete = function() {
                  return db.close();
                };
              }
            };
          };
          openRequest.onerror = function() {
            return reject(openRequest.error);
          };
        })];
      });
    });
  };
  return BrowserIndexedDBManager2;
}();
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var PATH_SEPARATOR = "/";
var PATH_PREFIX = "tensorflowjs_models";
var INFO_SUFFIX = "info";
var MODEL_TOPOLOGY_SUFFIX = "model_topology";
var WEIGHT_SPECS_SUFFIX = "weight_specs";
var WEIGHT_DATA_SUFFIX = "weight_data";
var MODEL_METADATA_SUFFIX = "model_metadata";
/**
 * Builds the localStorage keys for every stored component of a model, in the
 * form "tensorflowjs_models/<path>/<suffix>".
 *
 * @param {string} path Model path (no scheme).
 * @returns {{info: string, topology: string, weightSpecs: string, weightData: string, modelMetadata: string}}
 */
function getModelKeys(path) {
  const keyFor = (suffix) => [PATH_PREFIX, path, suffix].join(PATH_SEPARATOR);
  return {
    info: keyFor(INFO_SUFFIX),
    topology: keyFor(MODEL_TOPOLOGY_SUFFIX),
    weightSpecs: keyFor(WEIGHT_SPECS_SUFFIX),
    weightData: keyFor(WEIGHT_DATA_SUFFIX),
    modelMetadata: keyFor(MODEL_METADATA_SUFFIX)
  };
}
/**
 * Recovers the model path from a key built by getModelKeys. It drops the first
 * and last "/"-separated segments, so paths may themselves contain "/".
 *
 * @param {string} key
 * @returns {string}
 * @throws {Error} If the key has fewer than three segments.
 */
function getModelPathFromKey(key) {
  const segments = key.split(PATH_SEPARATOR);
  if (segments.length < 3) {
    throw new Error("Invalid key format: " + key);
  }
  return segments.slice(1, -1).join(PATH_SEPARATOR);
}
/** Removes a leading "localstorage://" scheme from `key`, if present. */
function maybeStripScheme$1(key) {
  const scheme = BrowserLocalStorage.URL_SCHEME;
  return key.startsWith(scheme) ? key.substring(scheme.length) : key;
}
var BrowserLocalStorage = function() {
  /**
   * IOHandler that stores model artifacts in `window.localStorage`. Each
   * component goes under its own key from getModelKeys(modelPath): info,
   * topology, weight specs, base64 weight data and metadata.
   *
   * @param {string} modelPath Non-empty model path (no scheme).
   * @throws {Error} Outside a browser with localStorage, or if modelPath is empty.
   */
  function BrowserLocalStorage2(modelPath) {
    const supported = env().getBool("IS_BROWSER") && typeof window !== "undefined" && typeof window.localStorage !== "undefined";
    if (!supported) {
      throw new Error("The current environment does not support local storage.");
    }
    this.LS = window.localStorage;
    if (!modelPath) {
      throw new Error("For local storage, modelPath must not be null, undefined or empty.");
    }
    this.modelPath = modelPath;
    this.keys = getModelKeys(this.modelPath);
  }
  /**
   * Writes all five entries. If any write throws, every entry is removed and a
   * descriptive error (which mentions a possible quota overflow) is thrown instead.
   *
   * @returns {Promise<{modelArtifactsInfo: Object}>}
   */
  BrowserLocalStorage2.prototype.save = function(modelArtifacts) {
    return __awaiter(this, void 0, void 0, function() {
      return __generator(this, function(_a) {
        if (modelArtifacts.modelTopology instanceof ArrayBuffer) {
          throw new Error("BrowserLocalStorage.save() does not support saving model topology in binary formats yet.");
        }
        const storage = this.LS;
        const keys = this.keys;
        const topologyJSON = JSON.stringify(modelArtifacts.modelTopology);
        const weightSpecsJSON = JSON.stringify(modelArtifacts.weightSpecs);
        const info = getModelArtifactsInfoForJSON(modelArtifacts);
        try {
          storage.setItem(keys.info, JSON.stringify(info));
          storage.setItem(keys.topology, topologyJSON);
          storage.setItem(keys.weightSpecs, weightSpecsJSON);
          storage.setItem(keys.weightData, arrayBufferToBase64String(modelArtifacts.weightData));
          storage.setItem(keys.modelMetadata, JSON.stringify({
            format: modelArtifacts.format,
            generatedBy: modelArtifacts.generatedBy,
            convertedBy: modelArtifacts.convertedBy,
            userDefinedMetadata: modelArtifacts.userDefinedMetadata
          }));
          return [2, {modelArtifactsInfo: info}];
        } catch (err) {
          // Undo partial writes so no half-saved model remains.
          [keys.info, keys.topology, keys.weightSpecs, keys.weightData, keys.modelMetadata].forEach((key) => storage.removeItem(key));
          throw new Error(`Failed to save model '${this.modelPath}' to local storage: size quota being exceeded is a possible cause of this failure: modelTopologyBytes=${info.modelTopologyBytes}, weightSpecsBytes=${info.weightSpecsBytes}, weightDataBytes=${info.weightDataBytes}.`);
        }
      });
    });
  };
  /**
   * Reads the model back. Info, topology, weight specs and weight data are
   * required; the metadata entry is optional.
   *
   * @returns {Promise<Object>} Model artifacts with weightData as an ArrayBuffer.
   */
  BrowserLocalStorage2.prototype.load = function() {
    return __awaiter(this, void 0, void 0, function() {
      return __generator(this, function(_a) {
        const storage = this.LS;
        const keys = this.keys;
        const readJSON = (key) => JSON.parse(storage.getItem(key));
        const info = readJSON(keys.info);
        if (info == null) {
          throw new Error("In local storage, there is no model with name '" + this.modelPath + "'");
        }
        if (info.modelTopologyType !== "JSON") {
          throw new Error("BrowserLocalStorage does not support loading non-JSON model topology yet.");
        }
        const modelTopology = readJSON(keys.topology);
        if (modelTopology == null) {
          throw new Error("In local storage, the topology of model '" + this.modelPath + "' is missing.");
        }
        const weightSpecs = readJSON(keys.weightSpecs);
        if (weightSpecs == null) {
          throw new Error("In local storage, the weight specs of model '" + this.modelPath + "' are missing.");
        }
        const out = {modelTopology, weightSpecs};
        const metadataString = storage.getItem(keys.modelMetadata);
        if (metadataString != null) {
          const metadata = JSON.parse(metadataString);
          out.format = metadata.format;
          out.generatedBy = metadata.generatedBy;
          out.convertedBy = metadata.convertedBy;
          out.userDefinedMetadata = metadata.userDefinedMetadata;
        }
        const weightDataBase64 = storage.getItem(keys.weightData);
        if (weightDataBase64 == null) {
          throw new Error("In local storage, the binary weight values of model '" + this.modelPath + "' are missing.");
        }
        out.weightData = base64StringToArrayBuffer(weightDataBase64);
        return [2, out];
      });
    });
  };
  BrowserLocalStorage2.URL_SCHEME = "localstorage://";
  return BrowserLocalStorage2;
}();
/**
 * IO router for "localstorage://" URLs. In a browser it maps a matching string
 * URL to a BrowserLocalStorage handler; otherwise it returns null.
 */
var localStorageRouter = (url) => {
  const handles = env().getBool("IS_BROWSER") && !Array.isArray(url) && url.startsWith(BrowserLocalStorage.URL_SCHEME);
  return handles ? browserLocalStorage(url.slice(BrowserLocalStorage.URL_SCHEME.length)) : null;
};
IORouterRegistry.registerSaveRouter(localStorageRouter);
IORouterRegistry.registerLoadRouter(localStorageRouter);
/**
 * Factory for the localStorage IOHandler.
 * @param {string} modelPath Model path, without the "localstorage://" scheme.
 * @returns {BrowserLocalStorage}
 */
function browserLocalStorage(modelPath) {
  const handler = new BrowserLocalStorage(modelPath);
  return handler;
}
var BrowserLocalStorageManager = function() {
  /**
   * Model store manager for models saved by BrowserLocalStorage.
   * @throws {Error} (via assert) outside a browser or when localStorage is unavailable.
   */
  function BrowserLocalStorageManager2() {
    assert(env().getBool("IS_BROWSER"), function() {
      return "Current environment is not a web browser";
    });
    // Same availability check as the BrowserLocalStorage constructor. Without a
    // `window` (e.g. in a Web Worker), the next line would otherwise fail with a
    // bare ReferenceError instead of this message.
    assert(typeof window !== "undefined" && typeof window.localStorage !== "undefined", function() {
      return "Current browser does not appear to support localStorage";
    });
    this.LS = window.localStorage;
  }
  /**
   * Lists the stored models by scanning for "tensorflowjs_models/.../info" keys.
   * @returns {Promise<Object>} Maps each model path to its parsed info record.
   */
  BrowserLocalStorageManager2.prototype.listModels = function() {
    return __awaiter(this, void 0, void 0, function() {
      var out, prefix, suffix, i, key, modelPath;
      return __generator(this, function(_a) {
        out = {};
        prefix = PATH_PREFIX + PATH_SEPARATOR;
        suffix = PATH_SEPARATOR + INFO_SUFFIX;
        for (i = 0; i < this.LS.length; ++i) {
          key = this.LS.key(i);
          if (key.startsWith(prefix) && key.endsWith(suffix)) {
            modelPath = getModelPathFromKey(key);
            out[modelPath] = JSON.parse(this.LS.getItem(key));
          }
        }
        return [2, out];
      });
    });
  };
  /**
   * Removes every localStorage entry of the model at `path` (a leading
   * "localstorage://" is stripped).
   *
   * @param {string} path
   * @returns {Promise<Object>} The removed model's info record.
   * @throws {Error} (as a rejection) if no model exists at `path`.
   */
  BrowserLocalStorageManager2.prototype.removeModel = function(path) {
    return __awaiter(this, void 0, void 0, function() {
      var keys, info;
      return __generator(this, function(_a) {
        path = maybeStripScheme$1(path);
        keys = getModelKeys(path);
        if (this.LS.getItem(keys.info) == null) {
          throw new Error("Cannot find model at path '" + path + "'");
        }
        info = JSON.parse(this.LS.getItem(keys.info));
        this.LS.removeItem(keys.info);
        this.LS.removeItem(keys.topology);
        this.LS.removeItem(keys.weightSpecs);
        this.LS.removeItem(keys.weightData);
        // BrowserLocalStorage.save() also writes this entry; without this line it was left behind.
        this.LS.removeItem(keys.modelMetadata);
        return [2, info];
      });
    });
  };
  return BrowserLocalStorageManager2;
}();
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var URL_SCHEME_SUFFIX = "://";
/**
 * Lazily created singleton that maps URL schemes (without "://") to model
 * store managers.
 */
var ModelStoreManagerRegistry = class ModelStoreManagerRegistry2 {
  constructor() {
    this.managers = {};
  }
  /** Returns the shared instance, creating it on first use. */
  static getInstance() {
    if (ModelStoreManagerRegistry2.instance == null) {
      ModelStoreManagerRegistry2.instance = new ModelStoreManagerRegistry2();
    }
    return ModelStoreManagerRegistry2.instance;
  }
  /**
   * Registers `manager` for `scheme`. A trailing "://" is accepted; the text
   * before the first "://" is then used as the scheme.
   * @throws via assert if scheme is null/empty or already registered.
   */
  static registerManager(scheme, manager) {
    assert(scheme != null, () => "scheme must not be undefined or null.");
    const bareScheme = scheme.endsWith(URL_SCHEME_SUFFIX) ? scheme.slice(0, scheme.indexOf(URL_SCHEME_SUFFIX)) : scheme;
    assert(bareScheme.length > 0, () => "scheme must not be an empty string.");
    const registry = ModelStoreManagerRegistry2.getInstance();
    assert(registry.managers[bareScheme] == null, () => "A model store manager is already registered for scheme '" + bareScheme + "'.");
    registry.managers[bareScheme] = manager;
  }
  /**
   * @returns The manager registered for `scheme`.
   * @throws {Error} If none is registered.
   */
  static getManager(scheme) {
    const found = ModelStoreManagerRegistry2.getInstance().managers[scheme];
    if (found == null) {
      throw new Error("Cannot find model manager for scheme '" + scheme + "'");
    }
    return found;
  }
  /** @returns {string[]} All registered schemes. */
  static getSchemes() {
    return Object.keys(ModelStoreManagerRegistry2.getInstance().managers);
  }
};
/**
 * Splits a model URL such as "indexeddb://my-model" into scheme and path.
 *
 * Only the first "://" separates the two, so the path keeps any later "://".
 * This matches the IO routers, which strip just the scheme prefix. (Splitting
 * on every "://" used to truncate such paths.)
 *
 * @param {string} url
 * @returns {{scheme: string, path: string}}
 * @throws {Error} If `url` contains no "://"; the message lists the registered schemes.
 */
function parseURL(url) {
  var separatorIndex = url.indexOf(URL_SCHEME_SUFFIX);
  if (separatorIndex === -1) {
    throw new Error("The url string provided does not contain a scheme. Supported schemes are: " + ("" + ModelStoreManagerRegistry.getSchemes().join(",")));
  }
  return {
    scheme: url.slice(0, separatorIndex),
    path: url.slice(separatorIndex + URL_SCHEME_SUFFIX.length)
  };
}
/**
 * Copies a model from `sourceURL` to `destURL`. With `deleteSource`, it moves
 * the model instead.
 *
 * Exactly one load handler must match the source and exactly one save handler
 * must match the destination.
 *
 * When moving within the same medium (same URL scheme), the source is removed
 * before saving, presumably to free space in the shared store. Across media, it
 * is removed only after the save succeeds, so a failed save cannot lose the model.
 *
 * Written as a compiled async state machine (tslib __generator). Each `case`
 * is a resume point.
 *
 * @param {string} sourceURL
 * @param {string} destURL
 * @param {boolean=} deleteSource Defaults to false.
 * @returns {Promise<Object>} The modelArtifactsInfo returned by the save handler.
 */
function cloneModelInternal(sourceURL, destURL, deleteSource) {
  if (deleteSource === void 0) {
    deleteSource = false;
  }
  return __awaiter(this, void 0, void 0, function() {
    var loadHandlers, loadHandler, saveHandlers, saveHandler, sourceScheme, sourcePath, sameMedium, modelArtifacts, saveResult;
    return __generator(this, function(_a) {
      switch (_a.label) {
        case 0:
          assert(sourceURL !== destURL, function() {
            return "Old path and new path are the same: '" + sourceURL + "'";
          });
          loadHandlers = IORouterRegistry.getLoadHandlers(sourceURL);
          assert(loadHandlers.length > 0, function() {
            return "Copying failed because no load handler is found for source URL " + sourceURL + ".";
          });
          assert(loadHandlers.length < 2, function() {
            return "Copying failed because more than one (" + loadHandlers.length + ") " + ("load handlers for source URL " + sourceURL + ".");
          });
          loadHandler = loadHandlers[0];
          saveHandlers = IORouterRegistry.getSaveHandlers(destURL);
          assert(saveHandlers.length > 0, function() {
            return "Copying failed because no save handler is found for destination " + ("URL " + destURL + ".");
          });
          assert(saveHandlers.length < 2, function() {
            return "Copying failed because more than one (" + saveHandlers.length + ") " + ("save handlers for destination URL " + destURL + ".");
          });
          saveHandler = saveHandlers[0];
          sourceScheme = parseURL(sourceURL).scheme;
          sourcePath = parseURL(sourceURL).path;
          // Compare with the destination's scheme. The previous code compared the
          // source with itself, so every move counted as same-medium and deleted
          // the source before saving. Destinations without "://" (including
          // non-string URLs) are treated as a different medium rather than
          // letting parseURL throw.
          sameMedium = destURL.indexOf(URL_SCHEME_SUFFIX) !== -1 && sourceScheme === parseURL(destURL).scheme;
          return [4, loadHandler.load()];
        case 1:
          modelArtifacts = _a.sent();
          if (!(deleteSource && sameMedium))
            return [3, 3];
          return [4, ModelStoreManagerRegistry.getManager(sourceScheme).removeModel(sourcePath)];
        case 2:
          _a.sent();
          _a.label = 3;
        case 3:
          return [4, saveHandler.save(modelArtifacts)];
        case 4:
          saveResult = _a.sent();
          if (!(deleteSource && !sameMedium))
            return [3, 6];
          return [4, ModelStoreManagerRegistry.getManager(sourceScheme).removeModel(sourcePath)];
        case 5:
          _a.sent();
          _a.label = 6;
        case 6:
          return [2, saveResult.modelArtifactsInfo];
      }
    });
  });
}
/**
 * Lists the models held by every registered model store manager.
 *
 * Compiled async state machine following tslib's __generator protocol:
 * [2, v] returns v, [3, L] jumps to label L, and [4, p] awaits p and resumes
 * at the next label with the result available via _a.sent().
 *
 * Label 1 checks the loop bound, label 2 merges one manager's listing,
 * label 3 advances to the next scheme and label 4 returns the result.
 * Managers are queried one after another, in getSchemes() order.
 *
 * @returns {Promise<Object>} Maps "<scheme>://<path>" to the entry the
 *     manager's listModels() reported for that path.
 */
function listModels() {
  return __awaiter(this, void 0, void 0, function() {
    var schemes, out, _i2, schemes_1, scheme, schemeOut, path, url;
    return __generator(this, function(_a) {
      switch (_a.label) {
        case 0:
          schemes = ModelStoreManagerRegistry.getSchemes();
          out = {};
          _i2 = 0, schemes_1 = schemes;
          _a.label = 1;
        case 1:
          // Loop condition: once every scheme is visited, jump to label 4.
          if (!(_i2 < schemes_1.length))
            return [3, 4];
          scheme = schemes_1[_i2];
          // Await this manager's listing; execution resumes at label 2.
          return [4, ModelStoreManagerRegistry.getManager(scheme).listModels()];
        case 2:
          schemeOut = _a.sent();
          // Prefix each path with its scheme so results from different managers cannot collide.
          for (path in schemeOut) {
            url = scheme + URL_SCHEME_SUFFIX + path;
            out[url] = schemeOut[path];
          }
          _a.label = 3;
        case 3:
          _i2++;
          return [3, 1];
        case 4:
          return [2, out];
      }
    });
  });
}
/**
 * Removes the model at `url` through the manager registered for its scheme.
 * @param {string} url A "<scheme>://<path>" model URL.
 * @returns {Promise<Object>} Whatever the manager's removeModel resolves to.
 */
function removeModel(url) {
  return __awaiter(this, void 0, void 0, function() {
    return __generator(this, function(_a) {
      const parsed = parseURL(url);
      const manager = ModelStoreManagerRegistry.getManager(parsed.scheme);
      return [2, manager.removeModel(parsed.path)];
    });
  });
}
/**
 * Copies a model from `sourceURL` to `destURL`, leaving the source in place.
 * @returns {Promise<Object>} The destination's modelArtifactsInfo.
 */
function copyModel(sourceURL, destURL) {
  return __awaiter(this, void 0, void 0, function() {
    return __generator(this, function(_a) {
      const keepSource = false;
      return [2, cloneModelInternal(sourceURL, destURL, keepSource)];
    });
  });
}
/**
 * Moves a model from `sourceURL` to `destURL` (copy, then remove the source).
 * @returns {Promise<Object>} The destination's modelArtifactsInfo.
 */
function moveModel(sourceURL, destURL) {
  return __awaiter(this, void 0, void 0, function() {
    return __generator(this, function(_a) {
      const removeSource = true;
      return [2, cloneModelInternal(sourceURL, destURL, removeSource)];
    });
  });
}
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Platform implementation backed by browser globals: fetch,
 * performance.now, TextEncoder and TextDecoder.
 */
var PlatformBrowser = class PlatformBrowser2 {
  /** Delegates to the global fetch. */
  fetch(path, init) {
    return fetch(path, init);
  }
  /** High-resolution timestamp in milliseconds. */
  now() {
    return performance.now();
  }
  /**
   * UTF-8 encodes `text`; the TextEncoder is created on first use.
   * @throws {Error} For any encoding other than "utf-8"/"utf8".
   */
  encode(text, encoding) {
    const isUtf8 = encoding === "utf-8" || encoding === "utf8";
    if (!isUtf8) {
      throw new Error("Browser's encoder only supports utf-8, but got " + encoding);
    }
    if (this.textEncoder == null) {
      this.textEncoder = new TextEncoder();
    }
    return this.textEncoder.encode(text);
  }
  /** Decodes `bytes` with a fresh TextDecoder for `encoding`. */
  decode(bytes, encoding) {
    const decoder = new TextDecoder(encoding);
    return decoder.decode(bytes);
  }
};
// In a browser: install the browser platform, then register the built-in
// model store managers for "localstorage://" and "indexeddb://".
if (env().get("IS_BROWSER")) {
  env().setPlatform("browser", new PlatformBrowser());
  // Registration is best-effort. A manager constructor throws when its backend
  // is unavailable (see the asserts and getIndexedDBFactory checks), and that
  // scheme is then simply left unregistered. The empty catches are deliberate.
  try {
    ModelStoreManagerRegistry.registerManager(BrowserLocalStorage.URL_SCHEME, new BrowserLocalStorageManager());
  } catch (err) {
  }
  try {
    ModelStoreManagerRegistry.registerManager(BrowserIndexedDB.URL_SCHEME, new BrowserIndexedDBManager());
  } catch (err) {
  }
}
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
// Fallback fetch loader for Node. PlatformNode.fetch calls it on first use when
// no global fetch exists and caches the result in `systemFetch`.
var getNodeFetch = {
  importFetch() {
    return require_browser();
  }
};
var systemFetch;
/**
 * Platform implementation for Node. It uses the `util` module's TextEncoder
 * and TextDecoder, process.hrtime for timing, and the global fetch when one is
 * present (otherwise the lazily imported fallback).
 */
var PlatformNode = class PlatformNode2 {
  constructor() {
    this.util = require_util();
    this.textEncoder = new this.util.TextEncoder();
  }
  fetch(path, requestInits) {
    if (env().global.fetch != null) {
      return env().global.fetch(path, requestInits);
    }
    if (systemFetch == null) {
      systemFetch = getNodeFetch.importFetch();
    }
    return systemFetch(path, requestInits);
  }
  /** Milliseconds from process.hrtime(). */
  now() {
    const [seconds, nanoseconds] = process.hrtime();
    return seconds * 1e3 + nanoseconds / 1e6;
  }
  /**
   * UTF-8 encodes `text`.
   * @throws {Error} For any encoding other than "utf-8"/"utf8".
   */
  encode(text, encoding) {
    const isUtf8 = encoding === "utf-8" || encoding === "utf8";
    if (!isUtf8) {
      throw new Error("Node built-in encoder only supports utf-8, but got " + encoding);
    }
    return this.textEncoder.encode(text);
  }
  /** Decodes `bytes`; an empty input yields "" without creating a decoder. */
  decode(bytes, encoding) {
    return bytes.length === 0 ? "" : new this.util.TextDecoder(encoding).decode(bytes);
  }
};
// Install the Node platform when running under Node.
if (env().get("IS_NODE")) {
  env().setPlatform("node", new PlatformNode());
}
/**
* @license
* Copyright 2020 Google Inc. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Creates a TensorBuffer of the given shape.
 *
 * @param {number[]} shape Must contain non-negative integers (asserted).
 * @param {string=} dtype Any falsy value (undefined, null, "") means "float32".
 * @param {*=} values Optional backing values, passed through unchanged.
 * @returns {TensorBuffer}
 */
function buffer(shape, dtype, values) {
  const resolvedDtype = dtype || "float32";
  assertNonNegativeIntegerDimensions(shape);
  return new TensorBuffer(shape, resolvedDtype, values);
}
/**
* @license
* Copyright 2020 Google Inc. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Casts `x` to `dtype` with the Cast kernel.
 *
 * @throws {Error} If `dtype` is not a valid dtype, or if exactly one of the
 *     source and target dtypes is "string".
 */
function cast_(x, dtype) {
  const $x = convertToTensor(x, "x", "cast");
  if (!isValidDtype(dtype)) {
    throw new Error("Failed to cast to unknown dtype " + dtype);
  }
  const toString = dtype === "string";
  const fromString = $x.dtype === "string";
  if (toString !== fromString) {
    throw new Error("Only strings can be casted to strings");
  }
  const forward = (backend2) => backend2.cast($x, dtype);
  return ENGINE.runKernelFunc(forward, {x: $x}, null, Cast, {dtype});
}
var cast = op({cast_});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Returns a tensor with the same shape and dtype as `x`, built through the
 * Identity kernel. The forward function wraps `x`'s existing dataId via
 * ENGINE.makeTensorFromDataId.
 */
function clone_(x) {
  const $x = convertToTensor(x, "x", "clone", null);
  const inputs = {x: $x};
  return ENGINE.runKernelFunc(() => ENGINE.makeTensorFromDataId($x.dataId, $x.shape, $x.dtype), inputs, null, Identity);
}
var clone = op({clone_});
/**
* @license
* Copyright 2020 Google Inc. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Logs `x.toString(verbose)` to the console.
 * @param {Object} x Object whose toString accepts a verbosity flag (a tensor).
 * @param {boolean=} verbose Defaults to false when omitted.
 */
function print(x, verbose) {
  const rendered = x.toString(verbose === void 0 ? false : verbose);
  console.log(rendered);
}
/**
* @license
* Copyright 2020 Google Inc. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
// Make sure the engine exists, then hand the core ops to setOpHandler
// (presumably so tensor methods can call them without a circular import — confirm).
getOrMakeEngine();
var opHandler$1 = {
  buffer: buffer,
  cast: cast,
  clone: clone,
  print: print
};
setOpHandler(opHandler$1);
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
// Default naming used by the browser-download save handler.
var DEFAULT_FILE_NAME_PREFIX = "model";
var DEFAULT_JSON_EXTENSION_NAME = ".json";
var DEFAULT_WEIGHT_DATA_EXTENSION_NAME = ".weights.bin";
/**
 * Runs `f` after a zero-delay `setTimeout`, i.e. on a later macrotask.
 *
 * @param f Callback; invoked with `undefined`.
 * @returns Promise resolving to `f`'s return value (rejecting if `f` throws).
 */
function defer(f) {
  const nextTask = new Promise((resolve) => {
    setTimeout(resolve);
  });
  return nextTask.then(f);
}
var BrowserDownloads = function() {
  /**
   * IOHandler that saves model artifacts by triggering browser downloads of a
   * `<prefix>.json` topology/manifest file and a `<prefix>.weights.bin` file.
   *
   * @param {string} fileNamePrefix Download name prefix. A leading
   *     `downloads://` scheme is stripped; a null or empty prefix falls back to
   *     DEFAULT_FILE_NAME_PREFIX.
   * @throws {Error} When the current environment is not a browser.
   */
  function BrowserDownloads2(fileNamePrefix) {
    if (!env().getBool("IS_BROWSER")) {
      throw new Error("browserDownloads() cannot proceed because the current environment is not a browser.");
    }
    // Null-check first: the fallback below is meant to cover null, but calling
    // startsWith on it would throw a TypeError before reaching it.
    if (fileNamePrefix != null && fileNamePrefix.startsWith(BrowserDownloads2.URL_SCHEME)) {
      fileNamePrefix = fileNamePrefix.slice(BrowserDownloads2.URL_SCHEME.length);
    }
    if (fileNamePrefix == null || fileNamePrefix.length === 0) {
      fileNamePrefix = DEFAULT_FILE_NAME_PREFIX;
    }
    this.modelTopologyFileName = fileNamePrefix + DEFAULT_JSON_EXTENSION_NAME;
    this.weightDataFileName = fileNamePrefix + DEFAULT_WEIGHT_DATA_EXTENSION_NAME;
  }
  /**
   * Downloads the topology/manifest JSON, then (if present) the weight data,
   * each via a synthetic click on an anchor element (`this.jsonAnchor` /
   * `this.weightDataAnchor` when set, otherwise freshly created).
   *
   * @param modelArtifacts Artifacts to save; `modelTopology` must not be an ArrayBuffer.
   * @returns Promise of `{modelArtifactsInfo}`.
   * @throws {Error} When `document` is unavailable or the topology is binary.
   */
  BrowserDownloads2.prototype.save = function(modelArtifacts) {
    return __awaiter(this, void 0, void 0, function() {
      var weightsManifest, modelTopologyAndWeightManifest, modelTopologyAndWeightManifestURL, jsonAnchor_1, weightsURL, weightDataAnchor_1;
      return __generator(this, function(_a) {
        switch (_a.label) {
          case 0:
            if (typeof document === "undefined") {
              throw new Error("Browser downloads are not supported in this environment since `document` is not present");
            }
            // Validate before creating any object URL so a rejected save leaks nothing.
            if (modelArtifacts.modelTopology instanceof ArrayBuffer) {
              throw new Error("BrowserDownloads.save() does not support saving model topology in binary formats yet.");
            }
            weightsManifest = [{
              paths: ["./" + this.weightDataFileName],
              weights: modelArtifacts.weightSpecs
            }];
            modelTopologyAndWeightManifest = {
              modelTopology: modelArtifacts.modelTopology,
              format: modelArtifacts.format,
              generatedBy: modelArtifacts.generatedBy,
              convertedBy: modelArtifacts.convertedBy,
              // Kept in the saved JSON, matching what HTTPRequest.save() writes.
              userDefinedMetadata: modelArtifacts.userDefinedMetadata,
              weightsManifest
            };
            modelTopologyAndWeightManifestURL = window.URL.createObjectURL(new Blob([JSON.stringify(modelTopologyAndWeightManifest)], {type: "application/json"}));
            jsonAnchor_1 = this.jsonAnchor == null ? document.createElement("a") : this.jsonAnchor;
            jsonAnchor_1.download = this.modelTopologyFileName;
            jsonAnchor_1.href = modelTopologyAndWeightManifestURL;
            // Dispatch on a later task so the browser processes the downloads one at a time.
            return [4, defer(function() {
              return jsonAnchor_1.dispatchEvent(new MouseEvent("click"));
            })];
          case 1:
            _a.sent();
            if (modelArtifacts.weightData == null) {
              return [3, 3];
            }
            // Only create the weights blob URL when there is weight data to download.
            weightsURL = window.URL.createObjectURL(new Blob([modelArtifacts.weightData], {type: "application/octet-stream"}));
            weightDataAnchor_1 = this.weightDataAnchor == null ? document.createElement("a") : this.weightDataAnchor;
            weightDataAnchor_1.download = this.weightDataFileName;
            weightDataAnchor_1.href = weightsURL;
            return [4, defer(function() {
              return weightDataAnchor_1.dispatchEvent(new MouseEvent("click"));
            })];
          case 2:
            _a.sent();
            _a.label = 3;
          case 3:
            return [2, {modelArtifactsInfo: getModelArtifactsInfoForJSON(modelArtifacts)}];
        }
      });
    });
  };
  BrowserDownloads2.URL_SCHEME = "downloads://";
  return BrowserDownloads2;
}();
var BrowserFiles = function() {
  /**
   * IOHandler that loads model artifacts from user-provided File objects.
   *
   * @param files Array-like of Files (an Array or a FileList). The first is the
   *     model JSON; the rest are the binary weight files named in its manifest.
   * @throws {Error} If `files` is null or empty.
   */
  function BrowserFiles2(files) {
    if (files == null || files.length < 1) {
      throw new Error("When calling browserFiles, at least 1 file is required, " + ("but received " + files));
    }
    this.files = files;
  }
  /**
   * Reads the model JSON and any weight files.
   *
   * @returns Promise of model artifacts. With no weight files only
   *     `{modelTopology}` is returned. Otherwise the contents of every manifest
   *     path are concatenated, in manifest order, into `weightData`. Rejects
   *     (always with an Error) on read failures, malformed JSON, missing
   *     `modelTopology`/`weightsManifest`, or a manifest/file mismatch.
   */
  BrowserFiles2.prototype.load = function() {
    var _this = this;
    return new Promise(function(resolve, reject) {
      // Array.from also accepts a FileList taken straight from <input type="file">.
      const files = Array.from(_this.files);
      const jsonFile = files[0];
      const weightFiles = files.slice(1);
      const jsonReader = new FileReader();
      jsonReader.onload = function(event) {
        let modelJSON;
        try {
          modelJSON = JSON.parse(event.target.result);
        } catch (err) {
          // A throw escaping this callback would leave the promise pending forever.
          reject(new Error("Failed to parse model JSON from file '" + jsonFile.name + "': " + err.message));
          return;
        }
        const modelTopology = modelJSON == null ? null : modelJSON.modelTopology;
        if (modelTopology == null) {
          reject(new Error("modelTopology field is missing from file " + jsonFile.name));
          return;
        }
        if (weightFiles.length === 0) {
          resolve({modelTopology});
          return;
        }
        const weightsManifest = modelJSON.weightsManifest;
        if (weightsManifest == null) {
          reject(new Error("weightManifest field is missing from file " + jsonFile.name));
          return;
        }
        let pathToFile;
        try {
          pathToFile = _this.checkManifestAndWeightFiles(weightsManifest, weightFiles);
        } catch (err) {
          reject(err);
          return;
        }
        const weightSpecs = [];
        const paths = [];
        // One slot per manifest path; null until that file has been read.
        const perFileBuffers = [];
        weightsManifest.forEach(function(weightsGroup) {
          weightsGroup.paths.forEach(function(path) {
            paths.push(path);
            perFileBuffers.push(null);
          });
          weightSpecs.push.apply(weightSpecs, weightsGroup.weights);
        });
        weightsManifest.forEach(function(weightsGroup) {
          weightsGroup.paths.forEach(function(path) {
            const weightFileReader = new FileReader();
            weightFileReader.onload = function(event2) {
              perFileBuffers[paths.indexOf(path)] = event2.target.result;
              // Resolve once every manifest path has been read (reads may finish in any order).
              if (perFileBuffers.indexOf(null) === -1) {
                resolve({
                  modelTopology,
                  weightSpecs,
                  weightData: concatenateArrayBuffers(perFileBuffers),
                  format: modelJSON.format,
                  generatedBy: modelJSON.generatedBy,
                  convertedBy: modelJSON.convertedBy,
                  userDefinedMetadata: modelJSON.userDefinedMetadata
                });
              }
            };
            weightFileReader.onerror = function() {
              reject(new Error("Failed to read weights data from file of path '" + path + "'."));
            };
            weightFileReader.readAsArrayBuffer(pathToFile[path]);
          });
        });
      };
      jsonReader.onerror = function() {
        reject(new Error("Failed to read model topology and weights manifest JSON " + ("from file '" + jsonFile.name + "'. BrowserFiles supports loading ") + "Keras-style tf.Model artifacts only."));
      };
      jsonReader.readAsText(jsonFile);
    });
  };
  /**
   * Maps each manifest path to the provided File that has the same basename.
   *
   * @param manifest Weights manifest (array of groups with `paths`).
   * @param files Weight Files (an Array).
   * @returns Object mapping manifest path -> File.
   * @throws {Error} On duplicate basenames in the manifest, a manifest path
   *     with no matching file, or a path/file count mismatch.
   */
  BrowserFiles2.prototype.checkManifestAndWeightFiles = function(manifest, files) {
    const basenames = [];
    const fileNames = files.map(function(file) {
      return basename(file.name);
    });
    const pathToFile = {};
    for (let groupIndex = 0; groupIndex < manifest.length; groupIndex++) {
      manifest[groupIndex].paths.forEach(function(path) {
        const pathBasename = basename(path);
        if (basenames.indexOf(pathBasename) !== -1) {
          throw new Error("Duplicate file basename found in weights manifest: " + ("'" + pathBasename + "'"));
        }
        basenames.push(pathBasename);
        const fileIndex = fileNames.indexOf(pathBasename);
        if (fileIndex === -1) {
          throw new Error("Weight file with basename '" + pathBasename + "' is not provided.");
        }
        pathToFile[path] = files[fileIndex];
      });
    }
    if (basenames.length !== files.length) {
      throw new Error("Mismatch in the number of files in weights manifest " + ("(" + basenames.length + ") and the number of weight files provided ") + ("(" + files.length + ")."));
    }
    return pathToFile;
  };
  return BrowserFiles2;
}();
/**
 * Save router for `downloads://` URLs.
 *
 * @param url Model URL, or an array of URLs (arrays are never handled here).
 * @returns A BrowserDownloads handler for `downloads://<prefix>` when running in
 *     a browser; otherwise null.
 */
var browserDownloadsRouter = function(url) {
  if (!env().getBool("IS_BROWSER")) {
    return null;
  }
  const scheme = BrowserDownloads.URL_SCHEME;
  const matchesScheme = !Array.isArray(url) && url.startsWith(scheme);
  return matchesScheme ? browserDownloads(url.slice(scheme.length)) : null;
};
IORouterRegistry.registerSaveRouter(browserDownloadsRouter);
/**
 * Creates an IOHandler that saves a model by triggering browser downloads.
 *
 * @param {string} [fileNamePrefix="model"] Prefix for the downloaded file names.
 * @returns {BrowserDownloads}
 */
function browserDownloads(fileNamePrefix) {
  const prefix = fileNamePrefix === undefined ? "model" : fileNamePrefix;
  return new BrowserDownloads(prefix);
}
/**
 * Creates an IOHandler that loads a model from user-selected files.
 *
 * @param files Model JSON file followed by its weight files.
 * @returns {BrowserFiles}
 */
function browserFiles(files) {
  const handler = new BrowserFiles(files);
  return handler;
}
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Waits for all `promises`, calling `onProgress` each time one of them resolves
 * with a fraction that moves linearly from `startFraction` to `endFraction`.
 *
 * @param {Promise[]} promises Non-empty array of promises to monitor.
 * @param {function(number)} onProgress Receives the current progress fraction.
 * @param {number} [startFraction=0] Fraction corresponding to zero resolved promises; in [0, 1].
 * @param {number} [endFraction=1] Fraction reported once all have resolved; in
 *     [0, 1] and not less than `startFraction`.
 * @returns Promise of the resolved values in input order. Rejects if any input
 *     rejects or if `onProgress` throws.
 */
function monitorPromisesProgress(promises, onProgress, startFraction, endFraction) {
  assert(promises != null && Array.isArray(promises) && promises.length > 0, function() {
    return "promises must be a non-empty array";
  });
  const start = startFraction == null ? 0 : startFraction;
  const end = endFraction == null ? 1 : endFraction;
  assert(start >= 0 && start <= 1, function() {
    return "Progress fraction must be in range [0, 1], but " + ("got startFraction " + start);
  });
  assert(end >= 0 && end <= 1, function() {
    return "Progress fraction must be in range [0, 1], but " + ("got endFraction " + end);
  });
  assert(end >= start, function() {
    return "startFraction must be no more than endFraction, but " + ("got startFraction " + start + " and endFraction ") + ("" + end);
  });
  let resolvedCount = 0;
  // Hand the `.then` chains themselves to Promise.all. Previously the chains
  // were dropped, so every rejected input also produced an unhandled rejection.
  const monitored = promises.map(function(promise) {
    return promise.then(function(value) {
      resolvedCount += 1;
      onProgress(start + resolvedCount / promises.length * (end - start));
      return value;
    });
  });
  return Promise.all(monitored);
}
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Fetches every URL in `fetchURLs` and reads each response body as an ArrayBuffer.
 *
 * @param {string[]} fetchURLs URLs to fetch; each fetch gets `{isBinary: true}`
 *     as its third argument.
 * @param {Object} [loadOptions] Optional `fetchFunc` (defaults to
 *     `env().platform.fetch`), `requestInit` (forwarded to each fetch) and
 *     `onProgress`. When `onProgress` is set, the fetch phase reports fractions
 *     0 -> 0.5 and the body-reading phase 0.5 -> 1.
 * @returns Promise of the ArrayBuffers, in `fetchURLs` order.
 */
function loadWeightsAsArrayBuffer(fetchURLs, loadOptions) {
  // The executor runs synchronously: fetches start immediately and any
  // synchronous throw becomes a rejection of the returned promise.
  return new Promise(function(resolve) {
    const options = loadOptions == null ? {} : loadOptions;
    const fetchFunc = options.fetchFunc == null ? env().platform.fetch : options.fetchFunc;
    const awaitAll = function(pending, fromFraction, toFraction) {
      return options.onProgress == null
        ? Promise.all(pending)
        : monitorPromisesProgress(pending, options.onProgress, fromFraction, toFraction);
    };
    const requests = fetchURLs.map(function(fetchURL) {
      return fetchFunc(fetchURL, options.requestInit, {isBinary: true});
    });
    resolve(awaitAll(requests, 0, 0.5).then(function(responses) {
      const bufferPromises = responses.map(function(response) {
        return response.arrayBuffer();
      });
      return awaitAll(bufferPromises, 0.5, 1);
    }));
  });
}
/**
 * Loads the weights described by `manifest`, fetching shards with
 * loadWeightsAsArrayBuffer (platform fetch, no progress reporting).
 *
 * @param manifest Weights manifest (array of `{paths, weights}` groups).
 * @param [filePathPrefix=""] Prefix joined to each manifest path.
 * @param [weightNames] Names of the weights to load; all weights when null/undefined.
 * @param [requestInit] Forwarded to every fetch call.
 * @returns Promise of a weight-name -> tensor map.
 */
function loadWeights(manifest, filePathPrefix, weightNames, requestInit) {
  const prefix = filePathPrefix === undefined ? "" : filePathPrefix;
  return new Promise((resolve) => {
    const fetchShards = (fetchUrls) => loadWeightsAsArrayBuffer(fetchUrls, {requestInit});
    const loadFromManifest = weightsLoaderFactory(fetchShards);
    resolve(loadFromManifest(manifest, prefix, weightNames));
  });
}
/**
 * Builds a weights-loading function around a caller-supplied fetcher.
 *
 * The returned function walks a weights manifest (an array of groups, each with
 * `paths` and `weights`), marks which groups hold requested weights, fetches only
 * those groups' files through `fetchWeightsFunction`, and decodes each requested
 * weight with `decodeWeights`.
 *
 * @param fetchWeightsFunction Called once with the array of URLs to fetch. Its
 *     awaited result is indexed like an array of ArrayBuffers, one per URL
 *     (assumed to be in the same order as the URLs — TODO confirm for custom fetchers).
 * @returns A function `(manifest, filePathPrefix = "", weightNames?)` returning a
 *     Promise of a `{[weightName]: tensor}` map. When `weightNames` is null or
 *     undefined every weight in the manifest is loaded. The promise rejects if a
 *     requested name does not appear in the manifest.
 */
function weightsLoaderFactory(fetchWeightsFunction) {
var _this = this;
return function(manifest, filePathPrefix, weightNames) {
if (filePathPrefix === void 0) {
filePathPrefix = "";
}
return __awaiter(_this, void 0, void 0, function() {
var groupIndicesToFetchMap, groupWeightsToFetch, weightsFound, allManifestWeightNames, weightsNotFound, groupIndicesToFetch, fetchUrls, buffers, weightsTensorMap, bufferIndexOffset;
return __generator(this, function(_a) {
switch (_a.label) {
case 0:
// One flag per manifest group: set once any weight in the group is requested.
groupIndicesToFetchMap = manifest.map(function() {
return false;
});
// groupIndex -> list of {manifestEntry, groupOffset, sizeBytes} for requested weights.
groupWeightsToFetch = {};
// Parallel to weightNames: which requested names were found in the manifest.
weightsFound = weightNames != null ? weightNames.map(function() {
return false;
}) : [];
allManifestWeightNames = [];
manifest.forEach(function(manifestGroupConfig, groupIndex) {
// Running byte offset of the current weight within the group's concatenated data.
var groupOffset = 0;
manifestGroupConfig.weights.forEach(function(weightsEntry) {
// For quantized entries the stored byte size comes from quantization.dtype, not dtype.
var rawDtype = "quantization" in weightsEntry ? weightsEntry.quantization.dtype : weightsEntry.dtype;
var weightsBytes = DTYPE_VALUE_SIZE_MAP[rawDtype] * sizeFromShape(weightsEntry.shape);
var enqueueWeightsForFetchingFn = function() {
groupIndicesToFetchMap[groupIndex] = true;
if (groupWeightsToFetch[groupIndex] == null) {
groupWeightsToFetch[groupIndex] = [];
}
groupWeightsToFetch[groupIndex].push({
manifestEntry: weightsEntry,
groupOffset,
sizeBytes: weightsBytes
});
};
if (weightNames != null) {
weightNames.forEach(function(weightName, weightIndex) {
if (weightName === weightsEntry.name) {
enqueueWeightsForFetchingFn();
weightsFound[weightIndex] = true;
}
});
} else {
// No filter given: every weight is loaded.
enqueueWeightsForFetchingFn();
}
allManifestWeightNames.push(weightsEntry.name);
// Advance even for skipped weights so later offsets stay correct.
groupOffset += weightsBytes;
});
});
if (!weightsFound.every(function(found) {
return found;
})) {
weightsNotFound = weightNames.filter(function(_, i) {
return !weightsFound[i];
});
throw new Error("Could not find weights in manifest with names: " + (weightsNotFound.join(", ") + ". \n") + "Manifest JSON has weights with names: " + (allManifestWeightNames.join(", ") + "."));
}
groupIndicesToFetch = groupIndicesToFetchMap.reduce(function(accumulator, shouldFetch, i) {
if (shouldFetch) {
accumulator.push(i);
}
return accumulator;
}, []);
fetchUrls = [];
groupIndicesToFetch.forEach(function(i) {
manifest[i].paths.forEach(function(filepath) {
// Joins with "/" unless the prefix already ends in one; note the default
// empty prefix therefore yields "/" + filepath.
var fetchUrl = filePathPrefix + (!filePathPrefix.endsWith("/") ? "/" : "") + filepath;
fetchUrls.push(fetchUrl);
});
});
return [4, fetchWeightsFunction(fetchUrls)];
case 1:
buffers = _a.sent();
weightsTensorMap = {};
// Index of the first buffer belonging to the group currently being decoded.
bufferIndexOffset = 0;
groupIndicesToFetch.forEach(function(i) {
var numBuffers = manifest[i].paths.length;
// Concatenate this group's shard buffers into one contiguous ArrayBuffer.
var groupBytes = 0;
for (var i_1 = 0; i_1 < numBuffers; i_1++) {
groupBytes += buffers[bufferIndexOffset + i_1].byteLength;
}
var groupBuffer = new ArrayBuffer(groupBytes);
var groupByteBuffer = new Uint8Array(groupBuffer);
var groupBufferOffset = 0;
for (var i_2 = 0; i_2 < numBuffers; i_2++) {
var buffer2 = new Uint8Array(buffers[bufferIndexOffset + i_2]);
groupByteBuffer.set(buffer2, groupBufferOffset);
groupBufferOffset += buffer2.byteLength;
}
var weightsEntries = groupWeightsToFetch[i];
weightsEntries.forEach(function(weightsEntry) {
// Copy out just this weight's bytes and decode it on its own.
var byteBuffer = groupBuffer.slice(weightsEntry.groupOffset, weightsEntry.groupOffset + weightsEntry.sizeBytes);
var nameToTensorMap = decodeWeights(byteBuffer, [weightsEntry.manifestEntry]);
for (var name_1 in nameToTensorMap) {
weightsTensorMap[name_1] = nameToTensorMap[name_1];
}
});
bufferIndexOffset += numBuffers;
});
return [2, weightsTensorMap];
}
});
});
};
}
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var OCTET_STREAM_MIME_TYPE = "application/octet-stream";
var JSON_TYPE = "application/json";
var HTTPRequest = function() {
  /**
   * IOHandler that saves a model with an HTTP request, and loads a model from a
   * model-JSON URL plus the weight files listed in its manifest.
   *
   * @param {string|string[]} path Model URL. An array must have length 2; in
   *     that case loadWeights() resolves weight URLs against `path[1]`.
   * @param {Object} [loadOptions] Optional `weightPathPrefix`, `onProgress`,
   *     `weightUrlConverter`, `fetchFunc` (defaults to `env().platform.fetch`)
   *     and `requestInit`.
   * @throws {Error} If `loadOptions.requestInit.body` is already set (save()
   *     supplies its own body). A non-function `fetchFunc` or an empty `path`
   *     fails an `assert` check.
   */
  function HTTPRequest2(path, loadOptions) {
    this.DEFAULT_METHOD = "POST";
    if (loadOptions == null) {
      loadOptions = {};
    }
    this.weightPathPrefix = loadOptions.weightPathPrefix;
    this.onProgress = loadOptions.onProgress;
    this.weightUrlConverter = loadOptions.weightUrlConverter;
    if (loadOptions.fetchFunc != null) {
      assert(typeof loadOptions.fetchFunc === "function", function() {
        return "Must pass a function that matches the signature of `fetch` (see https://developer.mozilla.org/en-US/docs/Web/API/Fetch_API)";
      });
      this.fetch = loadOptions.fetchFunc;
    } else {
      this.fetch = env().platform.fetch;
    }
    assert(path != null && path.length > 0, function() {
      return "URL path for http must not be null, undefined or empty.";
    });
    if (Array.isArray(path)) {
      assert(path.length === 2, function() {
        return "URL paths for http must have a length of 2, " + ("(actual length is " + path.length + ").");
      });
    }
    this.path = path;
    if (loadOptions.requestInit != null && loadOptions.requestInit.body != null) {
      throw new Error("requestInit is expected to have no pre-existing body, but has one.");
    }
    this.requestInit = loadOptions.requestInit || {};
  }
  /**
   * Sends the model as multipart form data: a `model.json` part (topology,
   * weights manifest and metadata) and, when weight data exists, a
   * `model.weights.bin` part. The method is POST unless `requestInit.method`
   * overrides it.
   *
   * @returns Promise of `{modelArtifactsInfo, responses: [response]}`.
   * @throws {Error} For a binary (ArrayBuffer) topology or a non-ok response.
   */
  HTTPRequest2.prototype.save = function(modelArtifacts) {
    return __awaiter(this, void 0, void 0, function() {
      var init, weightsManifest, modelTopologyAndWeightManifest, response;
      return __generator(this, function(_a) {
        switch (_a.label) {
          case 0:
            if (modelArtifacts.modelTopology instanceof ArrayBuffer) {
              throw new Error("BrowserHTTPRequest.save() does not support saving model topology in binary formats yet.");
            }
            init = Object.assign({method: this.DEFAULT_METHOD}, this.requestInit);
            init.body = new FormData();
            weightsManifest = [{
              paths: ["./model.weights.bin"],
              weights: modelArtifacts.weightSpecs
            }];
            modelTopologyAndWeightManifest = {
              modelTopology: modelArtifacts.modelTopology,
              format: modelArtifacts.format,
              generatedBy: modelArtifacts.generatedBy,
              convertedBy: modelArtifacts.convertedBy,
              userDefinedMetadata: modelArtifacts.userDefinedMetadata,
              weightsManifest
            };
            init.body.append("model.json", new Blob([JSON.stringify(modelTopologyAndWeightManifest)], {type: JSON_TYPE}), "model.json");
            if (modelArtifacts.weightData != null) {
              init.body.append("model.weights.bin", new Blob([modelArtifacts.weightData], {type: OCTET_STREAM_MIME_TYPE}), "model.weights.bin");
            }
            return [4, this.fetch(this.path, init)];
          case 1:
            response = _a.sent();
            if (response.ok) {
              return [2, {
                modelArtifactsInfo: getModelArtifactsInfoForJSON(modelArtifacts),
                responses: [response]
              }];
            } else {
              throw new Error("BrowserHTTPRequest.save() failed due to HTTP response status " + (response.status + "."));
            }
        }
      });
    });
  };
  /**
   * Fetches and parses the model JSON, then loads the weights it lists.
   *
   * @returns Promise of model artifacts (`modelInitializer` is added when present).
   * @throws {Error} On a non-ok response, unparseable JSON, a JSON `null` body,
   *     or JSON with neither `modelTopology` nor `weightsManifest`.
   */
  HTTPRequest2.prototype.load = function() {
    return __awaiter(this, void 0, void 0, function() {
      var modelConfigRequest, modelConfig, e_1, message, modelTopology, weightsManifest, generatedBy, convertedBy, format, userDefinedMetadata, weightSpecs, weightData, results, artifacts, initializer;
      return __generator(this, function(_a) {
        switch (_a.label) {
          case 0:
            return [4, this.fetch(this.path, this.requestInit)];
          case 1:
            modelConfigRequest = _a.sent();
            if (!modelConfigRequest.ok) {
              throw new Error("Request to " + this.path + " failed with status code " + (modelConfigRequest.status + ". Please verify this URL points to ") + "the model JSON of the model to load.");
            }
            _a.label = 2;
          case 2:
            _a.trys.push([2, 4, , 5]);
            return [4, modelConfigRequest.json()];
          case 3:
            modelConfig = _a.sent();
            return [3, 5];
          case 4:
            e_1 = _a.sent();
            message = "Failed to parse model JSON of response from " + this.path + ".";
            // `path` may be an array; only a string path can carry a .pb extension.
            if (typeof this.path === "string" && this.path.endsWith(".pb")) {
              message += " Your path contains a .pb file extension. Support for .pb models has been removed in TensorFlow.js 1.0 in favor of .json models. You can re-convert your Python TensorFlow model using the TensorFlow.js 1.0 conversion scripts or you can convert your .pb models with the 'pb2json' NPM script in the tensorflow/tfjs-converter repository.";
            } else {
              message += " Please make sure the server is serving valid JSON for this request.";
            }
            throw new Error(message);
          case 5:
            // A literal JSON `null` body would otherwise fail with a TypeError below.
            if (modelConfig == null) {
              throw new Error("The JSON from HTTP path " + this.path + " contains neither model topology or manifest for weights.");
            }
            modelTopology = modelConfig.modelTopology;
            weightsManifest = modelConfig.weightsManifest;
            generatedBy = modelConfig.generatedBy;
            convertedBy = modelConfig.convertedBy;
            format = modelConfig.format;
            userDefinedMetadata = modelConfig.userDefinedMetadata;
            if (modelTopology == null && weightsManifest == null) {
              throw new Error("The JSON from HTTP path " + this.path + " contains neither model topology or manifest for weights.");
            }
            if (!(weightsManifest != null))
              return [3, 7];
            return [4, this.loadWeights(weightsManifest)];
          case 6:
            results = _a.sent();
            weightSpecs = results[0], weightData = results[1];
            _a.label = 7;
          case 7:
            artifacts = {
              modelTopology,
              weightSpecs,
              weightData,
              userDefinedMetadata,
              generatedBy,
              convertedBy,
              format
            };
            initializer = modelConfig.modelInitializer;
            if (initializer) {
              artifacts.modelInitializer = initializer;
            }
            return [2, artifacts];
        }
      });
    });
  };
  /**
   * Resolves weight-file URLs from the manifest and fetches them.
   *
   * URLs are `weightPathPrefix` (or the directory of the model URL) + path +
   * the model URL's query string, unless `weightUrlConverter` is set, in which
   * case its (awaited) results are used instead.
   *
   * @returns Promise of `[weightSpecs, concatenatedWeightData]`.
   */
  HTTPRequest2.prototype.loadWeights = function(weightsManifest) {
    return __awaiter(this, void 0, void 0, function() {
      var weightPath, _a, prefix, suffix, pathPrefix, weightSpecs, _i2, weightsManifest_1, entry, fetchURLs, urlPromises, _b, weightsManifest_2, weightsGroup, _c, _d, path, _e, _f, _g, buffers;
      return __generator(this, function(_h) {
        switch (_h.label) {
          case 0:
            weightPath = Array.isArray(this.path) ? this.path[1] : this.path;
            _a = parseUrl(weightPath), prefix = _a[0], suffix = _a[1];
            pathPrefix = this.weightPathPrefix || prefix;
            weightSpecs = [];
            for (_i2 = 0, weightsManifest_1 = weightsManifest; _i2 < weightsManifest_1.length; _i2++) {
              entry = weightsManifest_1[_i2];
              weightSpecs.push.apply(weightSpecs, entry.weights);
            }
            fetchURLs = [];
            urlPromises = [];
            for (_b = 0, weightsManifest_2 = weightsManifest; _b < weightsManifest_2.length; _b++) {
              weightsGroup = weightsManifest_2[_b];
              for (_c = 0, _d = weightsGroup.paths; _c < _d.length; _c++) {
                path = _d[_c];
                if (this.weightUrlConverter != null) {
                  urlPromises.push(this.weightUrlConverter(path));
                } else {
                  fetchURLs.push(pathPrefix + path + suffix);
                }
              }
            }
            if (!this.weightUrlConverter)
              return [3, 2];
            _f = (_e = fetchURLs.push).apply;
            _g = [fetchURLs];
            return [4, Promise.all(urlPromises)];
          case 1:
            _f.apply(_e, _g.concat([_h.sent()]));
            _h.label = 2;
          case 2:
            return [4, loadWeightsAsArrayBuffer(fetchURLs, {
              requestInit: this.requestInit,
              fetchFunc: this.fetch,
              onProgress: this.onProgress
            })];
          case 3:
            buffers = _h.sent();
            return [2, [weightSpecs, concatenateArrayBuffers(buffers)]];
        }
      });
    });
  };
  HTTPRequest2.URL_SCHEME_REGEX = /^https?:\/\//;
  return HTTPRequest2;
}();
/**
 * Splits a model URL into the directory prefix used for weight files and the
 * query-string suffix appended to each weight URL.
 *
 * @param {string} url Model URL, e.g. "http://host/dir/model.json?v=1".
 * @returns {[string, string]} `[prefix, suffix]`, e.g. ["http://host/dir/", "?v=1"].
 *     A URL with no "/" before its query yields the prefix "/".
 */
function parseUrl(url) {
  // Separate the query first so a "/" inside it is not taken as a path separator.
  const queryStart = url.indexOf("?");
  const pathPart = queryStart === -1 ? url : url.substring(0, queryStart);
  const suffix = queryStart === -1 ? "" : url.substring(queryStart);
  const lastSlash = pathPart.lastIndexOf("/");
  const prefix = pathPart.substring(0, lastSlash);
  return [prefix + "/", suffix];
}
/**
 * Tells whether `url` starts with "http://" or "https://".
 *
 * @param {string} url
 * @returns {boolean}
 */
function isHTTPScheme(url) {
  const schemeMatch = url.match(HTTPRequest.URL_SCHEME_REGEX);
  return schemeMatch !== null;
}
/**
 * Save/load router for http:// and https:// URLs.
 *
 * Returns null when no fetch implementation is available (neither a global
 * `fetch` nor `loadOptions.fetchFunc`), or when `url` (or any element of a URL
 * array) is not HTTP(S). Otherwise returns an HTTPRequest handler.
 */
var httpRouter = function(url, loadOptions) {
  const fetchUnavailable = typeof fetch === "undefined" && (loadOptions == null || loadOptions.fetchFunc == null);
  if (fetchUnavailable) {
    return null;
  }
  const candidates = Array.isArray(url) ? url : [url];
  const allHttp = candidates.every((candidate) => isHTTPScheme(candidate));
  return allHttp ? http(url, loadOptions) : null;
};
IORouterRegistry.registerSaveRouter(httpRouter);
IORouterRegistry.registerLoadRouter(httpRouter);
/**
 * Creates an IOHandler that saves/loads a model over HTTP(S).
 *
 * @param path Model URL (or a 2-element URL array).
 * @param [loadOptions] Options forwarded to the HTTPRequest constructor.
 * @returns {HTTPRequest}
 */
function http(path, loadOptions) {
  const handler = new HTTPRequest(path, loadOptions);
  return handler;
}
/**
 * Alias of `http(path, loadOptions)`.
 *
 * @returns The handler returned by `http`.
 */
function browserHTTPRequest(path, loadOptions) {
  const handler = http(path, loadOptions);
  return handler;
}
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var PassthroughLoader = function() {
  /**
   * IOHandler whose load() hands back artifacts that are already in memory.
   *
   * @param modelArtifacts Returned unchanged by load().
   */
  function PassthroughLoader2(modelArtifacts) {
    this.modelArtifacts = modelArtifacts;
  }
  /** @returns A new Promise resolving to the stored model artifacts. */
  PassthroughLoader2.prototype.load = function() {
    return new Promise((resolve) => {
      resolve(this.modelArtifacts);
    });
  };
  return PassthroughLoader2;
}();
var PassthroughSaver = function() {
  /**
   * IOHandler whose save() delegates to a user-supplied function.
   *
   * @param saveHandler Called with the model artifacts on save().
   */
  function PassthroughSaver2(saveHandler) {
    this.saveHandler = saveHandler;
  }
  /**
   * @returns A Promise of whatever `saveHandler` returns; rejects if it throws.
   */
  PassthroughSaver2.prototype.save = function(modelArtifacts) {
    return new Promise((resolve) => {
      resolve(this.saveHandler(modelArtifacts));
    });
  };
  return PassthroughSaver2;
}();
/**
 * Creates an IOHandler that loads a model from in-memory artifacts.
 *
 * Preferred form: a single ModelArtifacts argument (recognized by a non-null
 * `modelTopology` or `weightSpecs`). Any other single argument is treated as the
 * topology itself, and the multi-argument form is still accepted; both of those
 * log a deprecation warning.
 *
 * @returns {PassthroughLoader}
 */
function fromMemory(modelArtifacts, weightSpecs, weightData, trainingConfig) {
  const deprecationWarning = "Please call tf.io.fromMemory() with only one argument. The argument should be of type ModelArtifacts. The multi-argument signature of tf.io.fromMemory() has been deprecated and will be removed in a future release.";
  // `arguments.length` (not undefined-checks) decides which signature was used.
  if (arguments.length !== 1) {
    console.warn(deprecationWarning);
    return new PassthroughLoader({
      modelTopology: modelArtifacts,
      weightSpecs,
      weightData,
      trainingConfig
    });
  }
  const looksLikeArtifacts = modelArtifacts.modelTopology != null || modelArtifacts.weightSpecs != null;
  if (looksLikeArtifacts) {
    return new PassthroughLoader(modelArtifacts);
  }
  console.warn(deprecationWarning);
  return new PassthroughLoader({modelTopology: modelArtifacts});
}
/**
 * Creates an IOHandler whose save() calls `saveHandler` with the model artifacts.
 *
 * @param saveHandler Function that performs the actual save.
 * @returns {PassthroughSaver}
 */
function withSaveHandler(saveHandler) {
  const saver = new PassthroughSaver(saveHandler);
  return saver;
}
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
// The `io` namespace object: groups the model save/load helpers defined above
// (plus copyModel/listModels/moveModel/removeModel from elsewhere in the file).
// `__proto__: null` gives the object no prototype, so only the listed members
// are reachable on it.
var io = {
__proto__: null,
browserFiles,
browserHTTPRequest,
concatenateArrayBuffers,
decodeWeights,
encodeWeights,
fromMemory,
getLoadHandlers,
getModelArtifactsInfoForJSON,
getSaveHandlers,
http,
isHTTPScheme,
loadWeights,
registerLoadRouter,
registerSaveRouter,
weightsLoaderFactory,
withSaveHandler,
copyModel,
listModels,
moveModel,
removeModel
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Reshapes `x` to `shape` without changing its element count.
 *
 * `shape` is resolved through `inferFromImplicitShape` against `x.size`
 * (presumably to fill in a single -1 entry — confirm in that helper).
 *
 * @param x Tensor or tensor-like value.
 * @param shape Target shape.
 * @returns The reshaped tensor.
 */
function reshape_(x, shape) {
  const $x = convertToTensor(x, "x", "reshape", null);
  // attrs records the shape exactly as requested, before any inference.
  const attrs = {shape};
  const inputs = {x: $x};
  const forward = (backend2, save) => {
    const targetShape = inferFromImplicitShape(shape, $x.size);
    assert($x.size === sizeFromShape(targetShape), () => "new shape and old shape must have the same number of elements.");
    save([$x]);
    return backend2.reshape($x, targetShape);
  };
  return ENGINE.runKernelFunc(forward, inputs, null, Reshape, attrs);
}
var reshape = op({reshape_});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function matMul_(a, b, transposeA, transposeB) {
var _a;
if (transposeA === void 0) {
transposeA = false;
}
if (transposeB === void 0) {
transposeB = false;
}
var $a = convertToTensor(a, "a", "matMul");
var $b = convertToTensor(b, "b", "matMul");
_a = makeTypesMatch($a, $b), $a = _a[0], $b = _a[1];
var forward = function(backend2, save) {
save([$a, $b]);
var innerShapeA = transposeA ? $a.shape[$a.rank - 2] : $a.shape[$a.rank - 1];
var innerShapeB = transposeB ? $b.shape[$b.rank - 1] : $b.shape[$b.rank - 2];
var outerShapeA = transposeA ? $a.shape[$a.rank - 1] : $a.shape[$a.rank - 2];
var outerShapeB = transposeB ? $b.shape[$b.rank - 2] : $b.shape[$b.rank - 1];
var outerDimsA = $a.shape.slice(0, -2);
var outerDimsB = $b.shape.slice(0, -2);
var batchDimA = sizeFromShape(outerDimsA);
var batchDimB = sizeFromShape(outerDimsB);
var batchDimsCompatible = batchDimA === batchDimB || batchDimA === 1 || batchDimB === 1;
assert($a.rank >= 2 && $b.rank >= 2 && batchDimsCompatible, function() {
return "Error in matMul: the input batch dimensions must either be the same or at least one input batch dimension must be 1. Got input " + ("batch dimensions of (" + outerDimsA + ") and (" + outerDimsB + ").");
});
assert(innerShapeA === innerShapeB, function() {
return "Error in matMul: inner shapes (" + innerShapeA + ") and (" + (innerShapeB + ") of Tensors with shapes " + $a.shape + " and ") + ($b.shape + " and transposeA=" + transposeA) + (" and transposeB=" + transposeB + " must match.");
});
var outShapeOuterDims = batchDimA > batchDimB ? outerDimsA : outerDimsB;
var outShape = outShapeOuterDims.concat([outerShapeA, outerShapeB]);
var a3D = transposeA ? reshape($a, [batchDimA, innerShapeA, outerShapeA]) : reshape($a, [batchDimA, outerShapeA, innerShapeA]);
var b3D = transposeB ? reshape($b, [batchDimB, outerShapeB, innerShapeB]) : reshape($b, [batchDimB, innerShapeB, outerShapeB]);
var res3d = backend2.batchMatMul(a3D, b3D, transposeA, transposeB);
return reshape(res3d, outShape);
};
var inputs = {a: $a, b: $b};
var attrs = {transposeA, transposeB};
return ENGINE.runKernelFunc(forward, inputs, null, BatchMatMul, attrs);
}
// Public `matMul` op: the `matMul_` implementation wrapped by the file's `op()` helper.
var matMul = op({matMul_});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * One-hot encodes integer `indices` into vectors of length `depth`.
 *
 * The position named by each index receives `onValue` (default 1); every other
 * position receives `offValue` (default 0). The result has shape
 * `indices.shape` with `depth` appended.
 *
 * @param indices tensor-like of indices, converted to int32
 * @param depth length of each one-hot vector; must be at least 2
 * @param onValue value written at the hot position (default 1)
 * @param offValue value written everywhere else (default 0)
 * @returns tensor of shape [...indices.shape, depth]
 * @throws Error when depth < 2
 */
function oneHot_(indices, depth, onValue, offValue) {
  if (onValue === void 0) {
    onValue = 1;
  }
  if (offValue === void 0) {
    offValue = 0;
  }
  if (depth < 2) {
    throw new Error("Error in oneHot: depth must be >=2, but it is " + depth);
  }
  var idx = convertToTensor(indices, "indices", "oneHot", "int32");
  var resultShape = idx.shape.concat([depth]);
  function kernelImpl(backend2, save) {
    save([idx]);
    // The backend kernel works on a flat index list; restore the leading dims afterwards.
    var flatIndices = reshape(idx, [idx.size]);
    var encoded = backend2.oneHot(flatIndices, depth, onValue, offValue);
    return reshape(encoded, resultShape);
  }
  return ENGINE.runKernelFunc(kernelImpl, {indices: idx}, null, OneHot, {depth: depth, onValue: onValue, offValue: offValue});
}
var oneHot = op({oneHot_});
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Permutes the axes of `x` according to `perm`.
 *
 * When `perm` is omitted the axes are reversed ([rank-1, ..., 1, 0]).
 * Inputs of rank 0 or 1 are returned as a clone, since there is nothing to permute.
 *
 * @param x tensor or tensor-like value
 * @param perm optional permutation; its length must equal the rank of `x` and
 *     every entry must lie in [0, rank - 1]
 * @returns the transposed tensor
 */
function transpose_(x, perm) {
  var input = convertToTensor(x, "x", "transpose");
  var rank = input.rank;
  if (perm == null) {
    perm = [];
    for (var axis = input.shape.length - 1; axis >= 0; axis--) {
      perm.push(axis);
    }
  }
  assert(rank === perm.length, function() {
    return "Error in transpose: rank of input " + rank + " must match length of perm " + perm + ".";
  });
  perm.forEach(function(entry) {
    assert(entry >= 0 && entry < rank, function() {
      return "All entries in 'perm' must be between 0 and " + (rank - 1) + " but got " + perm;
    });
  });
  if (rank <= 1) {
    return input.clone();
  }
  return ENGINE.runKernelFunc(function(backend2) {
    return backend2.transpose(input, perm);
  }, {x: input}, null, Transpose, {perm: perm});
}
var transpose = op({transpose_});
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Computes a confusion matrix from 1-D label and prediction tensors.
 *
 * Entry (i, j) of the result counts the examples whose label is i and whose
 * prediction is j. It is obtained as transpose(oneHot(labels)) x oneHot(predictions).
 *
 * @param labels rank-1 tensor-like of class ids
 * @param predictions rank-1 tensor-like of class ids, same length as `labels`
 * @param numClasses positive integer number of classes (required)
 * @returns int32 tensor of shape [numClasses, numClasses]
 */
function confusionMatrix_(labels, predictions, numClasses) {
  function isPositiveInteger(n) {
    return n > 0 && Number.isInteger(n);
  }
  var labelsT = convertToTensor(labels, "labels", "confusionMatrix");
  var predsT = convertToTensor(predictions, "predictions", "confusionMatrix");
  assert(numClasses == null || isPositiveInteger(numClasses), function() {
    return "If provided, numClasses must be a positive integer, but got " + numClasses;
  });
  assert(labelsT.rank === 1, function() {
    return "Expected the rank of labels to be 1, but got " + labelsT.rank;
  });
  assert(predsT.rank === 1, function() {
    return "Expected the rank of predictions to be 1, but got " + predsT.rank;
  });
  assert(labelsT.shape[0] === predsT.shape[0], function() {
    return "Mismatch in the number of examples: " + labelsT.shape[0] + " vs. " + predsT.shape[0] + ". Labels and predictions should have the same number of elements.";
  });
  assert(isPositiveInteger(numClasses), function() {
    return "numClasses is required to be a positive integer, but got " + numClasses;
  });
  var labelsOneHot = oneHot(cast(labelsT, "int32"), numClasses);
  var predsOneHot = oneHot(cast(predsT, "int32"), numClasses);
  // [numClasses, N] x [N, numClasses] -> per-(label, prediction) counts.
  var counts = matMul(transpose(labelsOneHot), predsOneHot);
  return cast(counts, "int32");
}
var confusionMatrix = op({confusionMatrix_});
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
// `math` namespace object (null prototype); currently it exposes only `confusionMatrix`.
var math = {
  __proto__: null,
  confusionMatrix
};
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Creates a rank-3 tensor.
 *
 * `values` may be nested (number[][][], shape inferred) or flat (array or
 * TypedArray), in which case `shape` must be supplied.
 *
 * @param values nested or flat values (must not be null)
 * @param shape optional shape with exactly three entries
 * @param dtype optional dtype forwarded to inferShape and makeTensor
 * @returns the tensor created by makeTensor
 * @throws Error on a bad shape length, wrong nesting depth, or a flat array without shape
 */
function tensor3d(values, shape, dtype) {
  assertNonNull(values);
  var shapeGiven = shape != null;
  if (shapeGiven && shape.length !== 3) {
    throw new Error("tensor3d() requires shape to have three numbers");
  }
  var inferredShape = inferShape(values, dtype);
  var inferredRank = inferredShape.length;
  var isFlat = inferredRank === 1;
  if (!isFlat && inferredRank !== 3) {
    throw new Error("tensor3d() requires values to be number[][][] or flat/TypedArray");
  }
  if (isFlat && !shapeGiven) {
    throw new Error("tensor3d() requires shape to be provided when `values` are a flat array");
  }
  return makeTensor(values, shape, inferredShape, dtype);
}
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var fromPixels2DContext;
/**
 * Creates an int32 tensor of shape [height, width, numChannels] from pixel data.
 *
 * Accepted `pixels`:
 * - an object whose `data` is a Uint8Array (plus `width`/`height`);
 * - ImageData, HTMLVideoElement or HTMLImageElement;
 * - any canvas-like object exposing `getContext` (HTMLCanvasElement, OffscreenCanvas).
 *
 * If a FromPixels kernel is registered for the active backend, the work is
 * delegated to it.
 *
 * @param pixels pixel source (see above)
 * @param numChannels number of leading channels kept from each 4-value pixel
 *     (default 3, at most 4)
 * @returns int32 tensor3d of shape [height, width, numChannels]
 * @throws Error if numChannels > 4, if pixels is null or of an unsupported type,
 *     or if a video element has no current frame yet
 */
function fromPixels_(pixels, numChannels) {
  if (numChannels === void 0) {
    numChannels = 3;
  }
  if (numChannels > 4) {
    throw new Error("Cannot construct Tensor with more than 4 channels from pixels.");
  }
  if (pixels == null) {
    throw new Error("pixels passed to tf.browser.fromPixels() can not be null");
  }
  var isPixelData = false;
  var isImageData = false;
  var isVideo = false;
  var isImage = false;
  var isCanvasLike = false;
  if (pixels.data instanceof Uint8Array) {
    isPixelData = true;
  } else if (typeof ImageData !== "undefined" && pixels instanceof ImageData) {
    isImageData = true;
  } else if (typeof HTMLVideoElement !== "undefined" && pixels instanceof HTMLVideoElement) {
    isVideo = true;
  } else if (typeof HTMLImageElement !== "undefined" && pixels instanceof HTMLImageElement) {
    isImage = true;
  } else if (pixels.getContext != null) {
    isCanvasLike = true;
  } else {
    // Null-prototype objects have no `constructor`. Fall back to `typeof` so this
    // descriptive error is thrown instead of a TypeError.
    var inputKind = pixels.constructor != null ? pixels.constructor.name : typeof pixels;
    // The raw-data branch above tests for Uint8Array, so the message names Uint8Array.
    throw new Error("pixels passed to tf.browser.fromPixels() must be either an HTMLVideoElement, HTMLImageElement, HTMLCanvasElement, ImageData in browser, or OffscreenCanvas, ImageData in webworker or {data: Uint8Array, width: number, height: number}, " + ("but was " + inputKind));
  }
  if (isVideo) {
    // Corresponds to HTMLMediaElement.HAVE_CURRENT_DATA: a frame must be available to draw.
    var HAVE_CURRENT_DATA_READY_STATE = 2;
    if (pixels.readyState < HAVE_CURRENT_DATA_READY_STATE) {
      throw new Error("The video element has not loaded data yet. Please wait for `loadeddata` event on the <video> element.");
    }
  }
  var kernel = getKernel(FromPixels, ENGINE.backendName);
  if (kernel != null) {
    var inputs = {pixels};
    var attrs = {numChannels};
    return ENGINE.runKernel(FromPixels, inputs, attrs);
  }
  var width = isVideo ? pixels.videoWidth : pixels.width;
  var height = isVideo ? pixels.videoHeight : pixels.height;
  var vals;
  if (isCanvasLike) {
    vals = pixels.getContext("2d").getImageData(0, 0, width, height).data;
  } else if (isImageData || isPixelData) {
    vals = pixels.data;
  } else if (isImage || isVideo) {
    // Images and videos are read back by drawing them onto a shared, lazily created 2D canvas.
    if (fromPixels2DContext == null) {
      fromPixels2DContext = document.createElement("canvas").getContext("2d");
    }
    fromPixels2DContext.canvas.width = width;
    fromPixels2DContext.canvas.height = height;
    fromPixels2DContext.drawImage(pixels, 0, 0, width, height);
    vals = fromPixels2DContext.getImageData(0, 0, width, height).data;
  }
  var values;
  if (numChannels === 4) {
    values = new Int32Array(vals);
  } else {
    // The source holds 4 values per pixel; keep only the first `numChannels` of each.
    var numPixels = width * height;
    values = new Int32Array(numPixels * numChannels);
    for (var i = 0; i < numPixels; i++) {
      for (var channel = 0; channel < numChannels; ++channel) {
        values[i * numChannels + channel] = vals[i * 4 + channel];
      }
    }
  }
  var outShape = [height, width, numChannels];
  return tensor3d(values, outShape, "int32");
}
/**
 * Converts a tensor into RGBA bytes and optionally draws them onto a canvas.
 *
 * This is a transpiled async function (`__awaiter`/`__generator`):
 * - `case 0` validates the input and awaits `$img.data()`;
 * - `case 1` converts the values to bytes.
 *
 * @param img rank-2 or rank-3 tensor (or tensor-like value) with depth 1, 3 or 4.
 *     Its dtype must be float32 (values in [0, 1], scaled by 255) or int32
 *     (values in [0, 255]). Non-Tensor inputs are converted and then cast to int32.
 * @param canvas optional canvas. When provided, it is resized to width x height
 *     and the bytes are written with `putImageData`.
 * @returns promise resolving to a Uint8ClampedArray of length width * height * 4 (RGBA).
 *     Validation errors are thrown inside the __awaiter body, so the returned promise
 *     presumably rejects with them.
 */
function toPixels(img, canvas) {
return __awaiter(this, void 0, void 0, function() {
var $img, originalImgTensor, _a, height, width, depth, data, multiplier, bytes, i, rgba, d, value, j, ctx, imageData;
return __generator(this, function(_b) {
switch (_b.label) {
case 0:
$img = convertToTensor(img, "img", "toPixels");
if (!(img instanceof Tensor)) {
// Non-Tensor input: swap in an int32 copy and free the converted intermediate.
// NOTE(review): if a validation check below throws, this int32 copy is never
// disposed; disposal only happens at the end of case 1.
originalImgTensor = $img;
$img = cast(originalImgTensor, "int32");
originalImgTensor.dispose();
}
if ($img.rank !== 2 && $img.rank !== 3) {
throw new Error("toPixels only supports rank 2 or 3 tensors, got rank " + $img.rank + ".");
}
_a = $img.shape.slice(0, 2), height = _a[0], width = _a[1];
// Rank-2 input is treated as a single-channel image.
depth = $img.rank === 2 ? 1 : $img.shape[2];
if (depth > 4 || depth === 2) {
throw new Error("toPixels only supports depth of size " + ("1, 3 or 4 but got " + depth));
}
if ($img.dtype !== "float32" && $img.dtype !== "int32") {
throw new Error("Unsupported type for toPixels: " + $img.dtype + ". Please use float32 or int32 tensors.");
}
return [4, $img.data()];
case 1:
data = _b.sent();
// float32 values in [0, 1] are scaled to [0, 255]; int32 values are used as-is.
multiplier = $img.dtype === "float32" ? 255 : 1;
bytes = new Uint8ClampedArray(width * height * 4);
for (i = 0; i < height * width; ++i) {
// Alpha defaults to fully opaque unless a 4th channel overwrites it.
rgba = [0, 0, 0, 255];
for (d = 0; d < depth; d++) {
value = data[i * depth + d];
if ($img.dtype === "float32") {
if (value < 0 || value > 1) {
throw new Error("Tensor values for a float32 Tensor must be in the " + ("range [0 - 1] but encountered " + value + "."));
}
} else if ($img.dtype === "int32") {
if (value < 0 || value > 255) {
throw new Error("Tensor values for a int32 Tensor must be in the " + ("range [0 - 255] but encountered " + value + "."));
}
}
// Grayscale (depth 1) is replicated into R, G and B.
if (depth === 1) {
rgba[0] = value * multiplier;
rgba[1] = value * multiplier;
rgba[2] = value * multiplier;
} else {
rgba[d] = value * multiplier;
}
}
j = i * 4;
bytes[j + 0] = Math.round(rgba[0]);
bytes[j + 1] = Math.round(rgba[1]);
bytes[j + 2] = Math.round(rgba[2]);
bytes[j + 3] = Math.round(rgba[3]);
}
if (canvas != null) {
canvas.width = width;
canvas.height = height;
ctx = canvas.getContext("2d");
imageData = new ImageData(bytes, width, height);
ctx.putImageData(imageData, 0, 0);
}
// Dispose only the tensor created here, never the caller's own tensor.
if ($img !== img) {
$img.dispose();
}
return [2, bytes];
}
});
});
}
// Public `fromPixels` op: `fromPixels_` wrapped by the file's `op()` helper.
var fromPixels = op({fromPixels_});
// `browser` namespace object (null prototype) grouping the pixel <-> tensor helpers.
var browser = {
  __proto__: null,
  toPixels,
  fromPixels
};
/**
 * Validates gatherND arguments and derives the gather geometry.
 *
 * The last dimension of `indices` (the slice rank) says how many leading axes of
 * `tensor2` each index addresses. The remaining axes of `tensor2` form the slice
 * that is copied per index.
 *
 * @param tensor2 input tensor (rank >= 1, non-empty)
 * @param indices int32 index tensor (rank >= 1)
 * @returns [resultShape, nResult, sliceSize, strides]:
 *   - resultShape is indices.shape without its last dim, plus the trailing input dims;
 *   - nResult is the number of index tuples;
 *   - sliceSize is the element count per slice;
 *   - strides holds the per-indexed-axis strides, measured in slices.
 * @throws Error on invalid ranks, dtype, index depth, or empty input
 */
function prepareAndValidate(tensor2, indices) {
  if (tensor2.rank < 1) {
    throw new Error("tf.gatherND() expects the input to be rank 1 or higher, but the rank was " + tensor2.rank + ".");
  }
  if (indices.rank < 1) {
    throw new Error("tf.gatherND() expects the indices to be rank 1 or higher, but the rank was " + indices.rank + ".");
  }
  if (indices.dtype !== "int32") {
    throw new Error("tf.gatherND() expects the indices to be int32 type, but the dtype was " + indices.dtype + ".");
  }
  var indexDepth = indices.shape[indices.rank - 1];
  if (indexDepth > tensor2.rank) {
    throw new Error("index innermost dimension length must be <= tensor rank; saw: " + indexDepth + " vs. " + tensor2.rank);
  }
  if (tensor2.size === 0) {
    throw new Error("Requested more than 0 entries, but input is empty. Input shape: " + tensor2.shape + ".");
  }
  var indicesShape = indices.shape;
  var sliceRank = indicesShape[indicesShape.length - 1];
  var batchDims = indicesShape.slice(0, -1);
  var nResult = batchDims.reduce(function(acc, dim) {
    return acc * dim;
  }, 1);
  var inputShape = tensor2.shape;
  var resultShape = batchDims.slice();
  var sliceSize = 1;
  for (var axis = sliceRank; axis < tensor2.rank; axis++) {
    sliceSize *= inputShape[axis];
    resultShape.push(inputShape[axis]);
  }
  var strides = computeStrides(inputShape).map(function(s) {
    return s / sliceSize;
  });
  strides.push(1);
  return [resultShape, nResult, sliceSize, strides.slice(0, sliceRank)];
}
// `gather_nd_util` namespace object (null prototype) exposing the gatherND validation helper.
var gather_nd_util = {
  __proto__: null,
  prepareAndValidate
};
function validateUpdateShape(shape, indices, updates) {
var sliceDim = indices.rank > 1 ? indices.shape[indices.rank - 1] : 1;
var batchDim = indices.rank > 1 ? indices.rank - 1 : 1;
var shapeError = "Must have updates.shape = indices.shape[:batchDim] + " + ("shape[sliceDim:], got updates.shape: " + updates.shape) + (", indices.shape: " + indices.shape + ", shape: " + shape) + (", sliceDim: " + sliceDim + ", and batchDim: " + batchDim + ".");
if (updates.rank < batchDim) {
throw new Error(shapeError + (" update.rank < " + batchDim + ". "));
}
if (shape.length < sliceDim + (updates.rank - batchDim)) {
throw new Error(shapeError + (" Output shape length < " + (sliceDim + (updates.rank - batchDim))));
}
if (updates.rank !== batchDim + shape.length - sliceDim) {
throw new Error(shapeError + (" update.rank != " + (batchDim + shape.length - sliceDim)));
}
for (var d = 0; d < batchDim; ++d) {
if (updates.shape[d] !== indices.shape[d]) {
throw new Error(shapeError + (" updates.shape[" + d + "] (" + updates.shape[d] + ") != indices.shape[" + d + "] (" + indices.shape[d] + ")."));
}
}
for (var d = 0; d < updates.rank - batchDim; ++d) {
if (updates.shape[d + batchDim] !== shape[d + sliceDim]) {
throw new Error(shapeError + (" updates.shape[" + (d + batchDim) + "] (" + updates.shape[d + batchDim] + ") != shape[" + (d + batchDim) + "] (" + shape[d + batchDim] + ")"));
}
}
}
/**
 * Validates the arguments of scatterND before the update shapes are checked.
 *
 * @param updates update tensor (rank, size and shape are read)
 * @param indices int32 index tensor (rank, dtype, size and shape are read)
 * @param shape output shape (number[]), rank >= 1
 * @throws Error on invalid ranks or dtype, on non-empty indices/updates for an
 *     output with zero elements, or on a shape mismatch (via validateUpdateShape)
 */
function validateInput(updates, indices, shape) {
  if (indices.rank < 1) {
    throw new Error("tf.scatterND() expects the indices to be rank 1 or higher," + (" but the rank was " + indices.rank + "."));
  }
  if (updates.rank < 1) {
    throw new Error("tf.scatterND() expects the updates to be rank 1 or higher," + (" but the rank was " + updates.rank + "."));
  }
  if (indices.dtype !== "int32") {
    throw new Error("The dtype of 'indices' should be int32, but got dtype: " + indices.dtype);
  }
  if (shape.length < 1) {
    throw new Error("Output rank must be greater or equal to 1, but got shape: " + shape);
  }
  // An output containing a zero-sized dimension has no elements, so any index or
  // update would be out of range. (The previous check tested `shape.length === 0`,
  // which the throw above makes unreachable, and its size tests were inverted.)
  var outputSize = 1;
  for (var i = 0; i < shape.length; i++) {
    outputSize *= shape[i];
  }
  if (outputSize === 0) {
    if (indices.size > 0) {
      throw new Error("Indices specified for empty output. indices shape: " + indices.shape);
    }
    if (updates.size > 0) {
      throw new Error("Updates specified for empty output. updates shape: " + updates.shape);
    }
  }
  validateUpdateShape(shape, indices, updates);
}
/**
 * Derives the scatterND geometry from the indices shape and the output shape.
 *
 * @param updates unused here; kept for signature compatibility
 * @param indices index tensor; only `indices.shape` is read
 * @param shape output shape (number[])
 * @returns {sliceRank, numUpdates, sliceSize, strides, outputSize}:
 *   - sliceRank is the number of leading output axes addressed by each index
 *     (the last indices dim, or 1 for rank-1 indices);
 *   - sliceSize is the element count of the trailing output dims;
 *   - numUpdates is the number of index tuples;
 *   - strides are those of the indexed output axes, with 1 appended;
 *   - outputSize is the total element count of the output.
 */
function calculateShapes(updates, indices, shape) {
  var indicesShape = indices.shape;
  var indicesRank = indicesShape.length;
  var sliceRank = indicesRank > 1 ? indicesShape[indicesRank - 1] : 1;
  var sliceSize = shape.slice(sliceRank).reduce(function(acc, dim) {
    return acc * dim;
  }, 1);
  // Guard against a zero slice rank when counting index tuples.
  var numUpdates = sizeFromShape(indicesShape) / (sliceRank < 1 ? 1 : sliceRank);
  var strides = computeStrides(shape.slice(0, sliceRank)).concat([1]);
  var outputSize = sizeFromShape(shape);
  return {
    sliceRank: sliceRank,
    numUpdates: numUpdates,
    sliceSize: sliceSize,
    strides: strides,
    outputSize: outputSize
  };
}
// `scatter_nd_util` namespace object (null prototype) exposing the scatterND validation/geometry helpers.
var scatter_nd_util = {
  __proto__: null,
  validateUpdateShape,
  validateInput,
  calculateShapes
};
/**
* @license
* Copyright 2017 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Asserts that `begin`/`size` describe a slice that fits inside `input`.
 *
 * @param input tensor-like object; only `input.shape` is read
 * @param begin per-axis start offsets; length must equal the input rank
 * @param size per-axis extents; length must equal the input rank
 */
function assertParamsValid(input, begin, size) {
  var rank = input.shape.length;
  var prefix = "Error in slice" + rank + "D: ";
  assert(rank === begin.length, function() {
    return prefix + "Length of begin " + begin + " must match the rank of the array (" + rank + ").";
  });
  assert(rank === size.length, function() {
    return prefix + "Length of size " + size + " must match the rank of the array (" + rank + ").";
  });
  function checkAxisFits(axis) {
    assert(begin[axis] + size[axis] <= input.shape[axis], function() {
      return prefix + "begin[" + axis + "] + size[" + axis + "] (" + (begin[axis] + size[axis]) + ") would overflow input.shape[" + axis + "] (" + input.shape[axis] + ")";
    });
  }
  for (var a = 0; a < rank; a++) {
    checkAxisFits(a);
  }
}
/**
 * Converts a bit mask into the ascending list of axis indices whose bit is set.
 * For example, mask 5 (0b101) yields [0, 2]. Non-positive masks yield [].
 *
 * @param {number} mask non-negative integer bit mask
 * @returns {number[]} axis indices with their bit set
 */
function maskToAxes(mask) {
  var axes = [];
  var axis = 0;
  while (mask > 0) {
    if (mask & 1) {
      axes.push(axis);
    }
    // Integer halving. A plain `mask /= 2` leaves a fractional remainder that keeps
    // the loop spinning (roughly a thousand extra iterations) until it underflows to 0,
    // without ever adding another axis.
    mask = Math.floor(mask / 2);
    axis++;
  }
  return axes;
}
/**
 * Computes the per-axis output length of a strided slice:
 * ceil((end - begin) / stride) for every axis of `begin`.
 *
 * @param begin normalized start indices
 * @param end normalized stop indices
 * @param strides per-axis strides
 * @returns number[] of output lengths, one per entry of `begin`
 */
function computeOutShape(begin, end, strides) {
  return Array.from({length: begin.length}, function(_, axis) {
    return Math.ceil((end[axis] - begin[axis]) / strides[axis]);
  });
}
/**
 * Expands `strides` for a strided slice containing an ellipsis.
 *
 * The strides are first padded with 1s up to the input rank. The entry at the
 * ellipsis position is set to 1, and (numElidedAxes - 1) further 1s are
 * inserted there, dropping the same number of trailing entries so the length
 * stays fixed.
 *
 * @param strides user strides (not mutated)
 * @param ellipsisInsertionIndex position of the ellipsis
 * @param numElidedAxes number of axes the ellipsis expands to
 * @param inputShape input shape (only its length is used)
 * @returns a new strides array
 */
function stridesWithElidedDims(strides, ellipsisInsertionIndex, numElidedAxes, inputShape) {
  var result = strides.slice();
  while (result.length < inputShape.length) {
    result.push(1);
  }
  if (numElidedAxes > 0) {
    result[ellipsisInsertionIndex] = 1;
  }
  for (var extra = 1; extra < numElidedAxes; extra++) {
    result.splice(ellipsisInsertionIndex, 0, 1);
    result.pop();
  }
  return result;
}
/**
 * Maps an axis of the ellipsis-expanded shape back to the user's original axis
 * numbering. Axes up to the ellipsis are unchanged; later axes shift left by
 * (numElidedAxes - 1).
 */
function unnormalizeAxis(ellipsisInsertionIndex, numElidedAxes, normalizedAxis) {
  return normalizedAxis <= ellipsisInsertionIndex ? normalizedAxis : normalizedAxis - (numElidedAxes - 1);
}
/**
 * Lists the axes covered by an ellipsis: numElidedAxes consecutive indices
 * starting at ellipsisInsertionIndex.
 */
function getElidedAxes(numElidedAxes, ellipsisInsertionIndex) {
  var axes = [];
  while (axes.length < numElidedAxes) {
    axes.push(ellipsisInsertionIndex + axes.length);
  }
  return axes;
}
/**
 * Normalizes strided-slice begin/end/strides to one entry per input axis.
 *
 * If an ellipsis expands to more than one axis (numInterpolatedAxes > 0), the
 * elided-dims helpers rebuild all three arrays around the ellipsis position.
 * Otherwise each axis is resolved independently with startForAxis, stopForAxis
 * and stridesForAxis.
 *
 * @returns {{begin: number[], end: number[], strides: number[]}}
 */
function getNormalizedAxes(inputShape, ellipsisAxes, numInterpolatedAxes, begin, end, strides, beginMask, endMask, ellipsisMask) {
  if (ellipsisAxes.length && numInterpolatedAxes > 0) {
    var ellipsisAt = ellipsisAxes[0];
    var elidedCount = numInterpolatedAxes + 1;
    return {
      begin: startIndicesWithElidedDims(beginMask, ellipsisAt, elidedCount, begin, inputShape),
      end: stopIndicesWithElidedDims(endMask, ellipsisAt, elidedCount, end, inputShape),
      strides: stridesWithElidedDims(strides, ellipsisAt, elidedCount, inputShape)
    };
  }
  var rank = inputShape.length;
  var outBegin = new Array(rank);
  var outEnd = new Array(rank);
  var outStrides = new Array(rank);
  for (var axis = 0; axis < rank; axis++) {
    outBegin[axis] = startForAxis(beginMask, begin, strides, inputShape, axis, ellipsisMask);
    outEnd[axis] = stopForAxis(endMask, end, strides, inputShape, axis, ellipsisMask);
    outStrides[axis] = stridesForAxis(strides, axis, ellipsisMask);
  }
  return {
    begin: outBegin,
    end: outEnd,
    strides: outStrides
  };
}
/**
 * Builds full-rank begin indices for a strided slice containing an ellipsis.
 *
 * Elided axes start at 0. Every other axis takes the user's begin value for its
 * original axis, or 0 when that axis's bit is set in beginMask.
 * NOTE(review): unlike stopIndicesWithElidedDims, negative values are not
 * wrapped or clamped here.
 *
 * @returns a new array of length inputShape.length
 */
function startIndicesWithElidedDims(beginMask, ellipsisInsertionIndex, numElidedAxes, originalBegin, inputShape) {
  var starts = inputShape.slice();
  var elided = getElidedAxes(numElidedAxes, ellipsisInsertionIndex);
  for (var axis = 0; axis < starts.length; axis++) {
    if (elided.indexOf(axis) !== -1) {
      starts[axis] = 0;
      continue;
    }
    var srcAxis = unnormalizeAxis(ellipsisInsertionIndex, numElidedAxes, axis);
    var masked = (beginMask & (1 << srcAxis)) !== 0;
    starts[axis] = masked ? 0 : originalBegin[srcAxis];
  }
  return starts;
}
/**
 * Builds full-rank, clamped stop indices for a strided slice containing an ellipsis.
 *
 * Elided axes, and axes whose original bit is set in endMask, run to the end of
 * the axis. Other axes take the user's end value for their original axis.
 * Negative values wrap around by the axis size, and every result is clamped
 * to [0, inputShape[axis]].
 *
 * @returns a new array of length inputShape.length
 */
function stopIndicesWithElidedDims(endMask, ellipsisInsertionIndex, numElidedAxes, originalEnd, inputShape) {
  var stops = inputShape.slice();
  var elided = getElidedAxes(numElidedAxes, ellipsisInsertionIndex);
  for (var axis = 0; axis < stops.length; axis++) {
    var raw;
    if (elided.indexOf(axis) !== -1) {
      raw = Number.MAX_SAFE_INTEGER;
    } else {
      var srcAxis = unnormalizeAxis(ellipsisInsertionIndex, numElidedAxes, axis);
      raw = (endMask & (1 << srcAxis)) !== 0 ? Number.MAX_SAFE_INTEGER : originalEnd[srcAxis];
    }
    var dim = inputShape[axis];
    if (raw < 0) {
      raw += dim;
    }
    stops[axis] = clamp(0, raw, dim);
  }
  return stops;
}
/**
 * Returns the stride for `axis`. The default of 1 is used when the axis is
 * covered by the ellipsis mask or when no stride was given.
 */
function stridesForAxis(strides, axis, ellipsisMask) {
  var given = strides[axis];
  var useDefault = (ellipsisMask & 1 << axis) || given == null;
  return useDefault ? 1 : given;
}
/**
 * Resolves the begin index of one axis of a strided slice.
 *
 * A masked or missing begin starts at the extreme end in the stride's direction.
 * A negative begin wraps by the axis size. The result is clamped to
 * [0, axisSize - 1].
 */
function startForAxis(beginMask, startIndices, strides, inputShape, axis, ellipsisMask) {
  var bit = 1 << axis;
  var start = startIndices[axis];
  var stride = strides[axis] || 1;
  if ((beginMask & bit) || (ellipsisMask & bit) || start == null) {
    start = stride > 0 ? Number.MIN_SAFE_INTEGER : Number.MAX_SAFE_INTEGER;
  }
  var axisSize = inputShape[axis];
  if (start < 0) {
    start += axisSize;
  }
  return clamp(0, start, axisSize - 1);
}
/**
 * Resolves the (exclusive) stop index of one axis of a strided slice.
 *
 * A masked or missing stop runs to the extreme end in the stride's direction.
 * A negative stop wraps by the axis size. Positive strides clamp to
 * [0, axisSize]; negative strides clamp to [-1, axisSize - 1], so a reverse
 * slice can include index 0.
 */
function stopForAxis(endMask, stopIndices, strides, inputShape, axis, ellipsisMask) {
  var bit = 1 << axis;
  var stop = stopIndices[axis];
  var stride = strides[axis] || 1;
  if ((endMask & bit) || (ellipsisMask & bit) || stop == null) {
    stop = stride > 0 ? Number.MAX_SAFE_INTEGER : Number.MIN_SAFE_INTEGER;
  }
  var axisSize = inputShape[axis];
  if (stop < 0) {
    stop += axisSize;
  }
  return stride > 0 ? clamp(0, stop, axisSize) : clamp(-1, stop, axisSize - 1);
}
/**
 * Reports whether a slice occupies one contiguous run of the flattened input.
 *
 * Every axis after the first one with size > 1 must start at 0 and span its
 * full extent.
 */
function isSliceContinous(shape, begin, size) {
  var firstWide = size.findIndex(function(extent) {
    return extent > 1;
  });
  if (firstWide === -1) {
    firstWide = size.length;
  }
  for (var axis = firstWide + 1; axis < size.length; axis++) {
    var fullAxis = !(begin[axis] > 0) && size[axis] === shape[axis];
    if (!fullAxis) {
      return false;
    }
  }
  return true;
}
/**
 * Computes the flat element offset of `begin` in a row-major buffer.
 *
 * `strides` supplies the strides of every axis except the last, whose stride is
 * implicitly 1.
 * NOTE(review): an empty `begin` yields 1 rather than 0; kept as-is for
 * compatibility. Confirm no caller passes a rank-0 begin.
 */
function computeFlatOffset(begin, strides) {
  var n = begin.length;
  if (n === 0) {
    return 1;
  }
  var offset = begin[n - 1];
  for (var axis = 0; axis + 1 < n; axis++) {
    offset += begin[axis] * strides[axis];
  }
  return offset;
}
/**
 * Normalizes user-supplied slice arguments into full-rank `begin` and `size` arrays.
 *
 * - A numeric `begin`, or one shorter than the rank, is padded with zeros.
 * - A missing, numeric or short `size` is padded with -1.
 * - A size of -1 means "to the end of the axis" and becomes x.shape[i] - begin[i].
 *
 * @param x tensor-like object; only `x.shape` is read
 * @param begin number or number[] of non-negative start offsets
 * @param size optional number or number[] of extents (-1 = rest of the axis)
 * @returns [begin, size], both of length x.shape.length
 * @throws (via assert) when a begin entry is negative, or a size entry is negative but not -1
 */
function parseSliceParams(x, begin, size) {
  var begin_;
  var xRank = x.shape.length;
  if (typeof begin === "number") {
    begin_ = [begin].concat(new Array(xRank - 1).fill(0));
  } else if (begin.length < xRank) {
    begin_ = begin.concat(new Array(xRank - begin.length).fill(0));
  } else {
    begin_ = begin.slice();
  }
  begin_.forEach(function(d) {
    // Negative begin offsets are unsupported. The previous check rejected only
    // exactly -1, silently letting other negative values (e.g. -2) through.
    assert(d >= 0, function() {
      return "slice() does not support negative begin indexing.";
    });
  });
  var size_;
  if (size == null) {
    size_ = new Array(xRank).fill(-1);
  } else if (typeof size === "number") {
    size_ = [size].concat(new Array(xRank - 1).fill(-1));
  } else if (size.length < xRank) {
    size_ = size.concat(new Array(xRank - size.length).fill(-1));
  } else {
    size_ = size;
  }
  // `map` returns a new array, so a caller-supplied `size` is never mutated.
  size_ = size_.map(function(d, i) {
    if (d >= 0) {
      return d;
    } else {
      assert(d === -1, function() {
        return "Negative size values should be exactly -1 but got " + (d + " for the slice() size at index " + i + ".");
      });
      return x.shape[i] - begin_[i];
    }
  });
  return [begin_, size_];
}
// `slice_util` namespace object (null prototype) exposing the slice / strided-slice helpers.
// `isSliceContinous` keeps its historical spelling because it is part of the exported surface.
var slice_util = {
  __proto__: null,
  assertParamsValid,
  maskToAxes,
  computeOutShape,
  stridesWithElidedDims,
  getNormalizedAxes,
  startIndicesWithElidedDims,
  stopIndicesWithElidedDims,
  stridesForAxis,
  startForAxis,
  stopForAxis,
  isSliceContinous,
  computeFlatOffset,
  parseSliceParams
};
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Base constructor for objects that are rebuilt from a config object.
 * Subclasses are expected to define a static `className` string.
 * It stays a plain function (not an ES class) so that transpiled subclasses
 * can invoke it without `new`.
 */
var Serializable = function() {
  var Ctor = function Serializable2() {
  };
  /** Returns the static `className` of this instance's constructor. */
  Ctor.prototype.getClassName = function() {
    return this.constructor.className;
  };
  /** Default factory: constructs `cls` with `config` as its only argument. */
  Ctor.fromConfig = function(cls, config) {
    return new cls(config);
  };
  return Ctor;
}();
/**
 * Lazily created singleton registry mapping a class's `className` to a
 * [constructor, fromConfig] pair.
 */
var SerializationMap = function() {
  var Registry = function SerializationMap2() {
    this.classNameMap = {};
  };
  /** Returns the shared registry instance, creating it on first use. */
  Registry.getMap = function() {
    if (Registry.instance == null) {
      Registry.instance = new Registry();
    }
    return Registry.instance;
  };
  /** Records `cls` and its static `fromConfig` under `cls.className`. */
  Registry.register = function(cls) {
    var table = Registry.getMap().classNameMap;
    table[cls.className] = [cls, cls.fromConfig];
  };
  return Registry;
}();
/**
 * Registers `cls` in the shared SerializationMap under its static `className`.
 *
 * @param cls constructor with a non-empty string `className`
 * @throws (via assert) if `className` is missing, is not a string, or is empty
 */
function registerClass(cls) {
  var name = cls.className;
  assert(name != null, function() {
    return "Class being registered does not have the static className property defined.";
  });
  assert(typeof name === "string", function() {
    return "className is required to be a string, but got type " + typeof name;
  });
  assert(name.length > 0, function() {
    return "Class being registered has an empty-string as its className, which is disallowed.";
  });
  SerializationMap.register(cls);
}
// `serialization` namespace object (null prototype) exposing the serialization base and registry.
var serialization = {
  __proto__: null,
  Serializable,
  SerializationMap,
  registerClass
};
/**
* @license
* Copyright 2017 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/** Comparison tolerance used when the backend reports 32-bit float precision. */
var TEST_EPSILON_FLOAT32 = 0.001;
/** Looser tolerance used when the backend reports any other float precision. */
var TEST_EPSILON_FLOAT16 = 0.1;
/**
 * Asserts that two arrays match element-wise within `epsilon`.
 *
 * @param actual actual values (array, nested array or typed array)
 * @param expected expected values
 * @param epsilon optional tolerance; defaults to testEpsilon()
 */
function expectArraysClose(actual, expected, epsilon) {
  var tolerance = epsilon == null ? testEpsilon() : epsilon;
  return expectArraysPredicate(actual, expected, function(a, b) {
    return areClose(a, b, tolerance);
  });
}
/** Picks the default test tolerance from the active backend's float precision. */
function testEpsilon() {
  var isFloat32 = ENGINE.backend.floatPrecision() === 32;
  return isFloat32 ? TEST_EPSILON_FLOAT32 : TEST_EPSILON_FLOAT16;
}
/**
 * Asserts that `actual` and `expected` agree element-wise under `predicate`.
 *
 * Container classes are compared only when both sides are typed arrays or
 * neither is. Nested plain arrays must also have identical inferred shapes.
 * Both sides are then flattened and compared position by position.
 *
 * @param actual actual values
 * @param expected expected values
 * @param predicate (a, e) => boolean deciding whether two elements match
 * @throws Error describing the first difference found
 */
function expectArraysPredicate(actual, expected, predicate) {
  var actualTyped = !!isTypedArray(actual);
  var expectedTyped = !!isTypedArray(expected);
  if (actualTyped === expectedTyped) {
    var aType = actual.constructor.name;
    var bType = expected.constructor.name;
    if (aType !== bType) {
      throw new Error("Arrays are of different type. Actual: " + aType + ". Expected: " + bType);
    }
  }
  if (Array.isArray(actual) && Array.isArray(expected)) {
    var actualShape = inferShape(actual);
    var expectedShape = inferShape(expected);
    if (!arraysEqual(actualShape, expectedShape)) {
      throw new Error("Arrays have different shapes. Actual: [" + actualShape + "]. Expected: [" + expectedShape + "]");
    }
  }
  var actualFlat = actualTyped ? actual : flatten(actual);
  var expectedFlat = expectedTyped ? expected : flatten(expected);
  var dump = function() {
    return "Actual: " + actualFlat + ".\n" + ("Expected: " + expectedFlat + ".");
  };
  if (actualFlat.length !== expectedFlat.length) {
    throw new Error("Arrays have different lengths actual: " + actualFlat.length + " vs expected: " + expectedFlat.length + ".\n" + dump());
  }
  for (var idx = 0; idx < expectedFlat.length; ++idx) {
    var got = actualFlat[idx];
    var want = expectedFlat[idx];
    if (!predicate(got, want)) {
      throw new Error("Arrays differ: actual[" + idx + "] = " + got + ", expected[" + idx + "] = " + want + ".\n" + dump());
    }
  }
}
/**
 * Jasmine-style helper: calls `done()` if the promise returned by `fn` rejects,
 * and `done.fail()` if it resolves.
 */
function expectPromiseToFail(fn, done) {
  var onResolved = function() {
    return done.fail();
  };
  var onRejected = function() {
    return done();
  };
  fn().then(onResolved, onRejected);
}
/**
 * Asserts that `actual` and `expected` hold equal values.
 *
 * A bare primitive (string, number or boolean) on either side is treated as a
 * one-element array, so `expectArraysEqual([5], 5)` and `expectArraysEqual(5, 5)`
 * both succeed. Previously the wrapped `expected` was only used on the string
 * path, and `expectArraysEqual([5], 5)` failed with a spurious type mismatch.
 * String data is compared with loose equality (`==`). Everything else is
 * compared with `areClose(a, b, 0)`.
 *
 * @throws Error describing the first mismatch (from expectArraysPredicate)
 */
function expectArraysEqual(actual, expected) {
  var asArray = function(value) {
    return typeof value === "string" || typeof value === "number" || typeof value === "boolean" ? [value] : value;
  };
  var act2 = asArray(actual);
  var exp2 = asArray(expected);
  if (isString(actual) || isString(actual[0]) || isString(expected) || isString(expected[0])) {
    return expectArraysPredicate(act2, exp2, function(a, b) {
      return a == b;
    });
  }
  return expectArraysPredicate(act2, exp2, function(a, b) {
    return areClose(a, b, 0);
  });
}
/**
 * Asserts that two numbers are within `epsilon` of each other.
 *
 * @param a actual number
 * @param e expected number
 * @param epsilon optional tolerance; defaults to testEpsilon()
 * @throws Error when the numbers differ
 */
function expectNumbersClose(a, e, epsilon) {
  if (epsilon == null) {
    epsilon = testEpsilon();
  }
  if (!areClose(a, e, epsilon)) {
    throw new Error("Numbers differ: actual === " + a + ", expected === " + e);
  }
}
/**
 * Returns true when `a` and `e` are within `epsilon` of each other.
 *
 * Two non-finite values match only if they are the same value: both NaN, or the
 * same signed infinity. Previously any pair of non-finite values matched, so
 * Infinity vs -Infinity or NaN vs Infinity were wrongly reported as close.
 * A NaN compared with a finite number never matches.
 */
function areClose(a, e, epsilon) {
  if (!isFinite(a) && !isFinite(e)) {
    // Coerce the same way the global isFinite above does.
    var an = Number(a);
    var en = Number(e);
    return an === en || Number.isNaN(an) && Number.isNaN(en);
  }
  if (isNaN(a) || isNaN(e) || Math.abs(a - e) > epsilon) {
    return false;
  }
  return true;
}
/**
 * Asserts that every element of `actual` lies in the closed range [low, high].
 *
 * @throws Error naming the first out-of-range value
 */
function expectValuesInRange(actual, low, high) {
  for (var idx = 0; idx < actual.length; idx++) {
    var v = actual[idx];
    var outOfRange = v < low || v > high;
    if (outOfRange) {
      throw new Error("Value out of range:" + v + " low: " + low + ", high: " + high);
    }
  }
}
/**
 * Asserts that two ArrayBuffers are equal when viewed as Float32Arrays.
 * Relies on a test-framework global `expect` (Jasmine-style `toEqual`).
 */
function expectArrayBuffersEqual(actual, expected) {
  var matcher = expect(new Float32Array(actual));
  matcher.toEqual(new Float32Array(expected));
}
// `test_util` namespace object (null prototype) collecting the test helpers.
// TEST_EPSILON_FLOAT32 and expectArraysPredicate are not listed here.
var test_util = {
  __proto__: null,
  TEST_EPSILON_FLOAT16,
  expectArraysClose,
  testEpsilon,
  expectPromiseToFail,
  expectArraysEqual,
  expectNumbersClose,
  expectValuesInRange,
  expectArrayBuffersEqual
};
/** @license See the LICENSE file. */
// Version string of this bundled library build (presumably the tfjs-core release it came from).
var version = "2.7.0";
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function enableProdMode() {
env().set("PROD", true);
}
function enableDebugMode() {
env().set("DEBUG", true);
}
/**
 * Silences deprecation warnings by setting the environment flag
 * "DEPRECATION_WARNINGS_ENABLED" to false, then logs a notice saying so.
 */
function disableDeprecationWarnings() {
  const environment = env();
  environment.set("DEPRECATION_WARNINGS_ENABLED", false);
  console.warn("TensorFlow.js deprecation warnings have been disabled.");
}
/**
 * Logs `msg` via console.warn, followed by a hint on how to silence it.
 * Nothing is logged when the "DEPRECATION_WARNINGS_ENABLED" flag is off.
 *
 * @param {string} msg - The deprecation message.
 */
function deprecationWarn(msg) {
  const warningsEnabled = env().getBool("DEPRECATION_WARNINGS_ENABLED");
  if (!warningsEnabled) {
    return;
  }
  console.warn(msg + " You can disable deprecation warnings with tf.disableDeprecationWarnings().");
}
/** Delegates to `ENGINE.disposeVariables()`; returns nothing. */
function disposeVariables() {
  const activeEngine = ENGINE;
  activeEngine.disposeVariables();
}
/** @returns The module-level `ENGINE` instance. */
function engine() {
  const activeEngine = ENGINE;
  return activeEngine;
}
/** @returns Whatever `ENGINE.memory()` reports. */
function memory() {
  const report = ENGINE.memory();
  return report;
}
/**
 * Forwards `f` to `ENGINE.profile` and returns its result unchanged.
 * @param {Function} f - The function to profile.
 */
function profile(f) {
  const result = ENGINE.profile(f);
  return result;
}
/**
 * Forwards both arguments, as given, to `ENGINE.tidy` and returns its result.
 * @param {string|Function} nameOrFn - A scope name, or the function to run.
 * @param {Function} [fn] - The function to run when a name is supplied.
 */
function tidy(nameOrFn, fn) {
  const result = ENGINE.tidy(nameOrFn, fn);
  return result;
}
/**
 * Calls `.dispose()` on every tensor that `getTensorsInContainer` finds in
 * `container`.
 *
 * @param container - Anything `getTensorsInContainer` accepts (a tensor or a
 *     structure holding tensors).
 */
function dispose(container) {
  const found = getTensorsInContainer(container);
  found.forEach((t) => {
    t.dispose();
  });
}
/**
 * Forwards `result` to `ENGINE.keep` and returns what it returns.
 * @param result - The value to keep.
 */
function keep(result) {
  const kept = ENGINE.keep(result);
  return kept;
}
/**
 * Forwards `f` to `ENGINE.time` and returns its result unchanged.
 * @param {Function} f - The function to time.
 */
function time(f) {
  const timing = ENGINE.time(f);
  return timing;
}
/**
 * Asks the engine to switch to `backendName`; returns `ENGINE.setBackend`'s result.
 * @param {string} backendName - Name of a registered backend.
 */
function setBackend(backendName) {
  const outcome = ENGINE.setBackend(backendName);
  return outcome;
}
/** @returns The value of `ENGINE.ready()`. */
function ready() {
  const readiness = ENGINE.ready();
  return readiness;
}
/** @returns {string} The engine's current `backendName` property. */
function getBackend() {
  const currentName = ENGINE.backendName;
  return currentName;
}
/**
 * Asks the engine to remove the backend registered under `name`; returns nothing.
 * @param {string} name - Backend name.
 */
function removeBackend(name) {
  const activeEngine = ENGINE;
  activeEngine.removeBackend(name);
}
/**
 * Looks up a backend by name via `ENGINE.findBackend`.
 * @param {string} name - Backend name.
 */
function findBackend(name) {
  const found = ENGINE.findBackend(name);
  return found;
}
/**
 * Looks up a backend factory by name via `ENGINE.findBackendFactory`.
 * @param {string} name - Backend name.
 */
function findBackendFactory(name) {
  const factory = ENGINE.findBackendFactory(name);
  return factory;
}
/**
 * Registers a backend factory with the engine.
 *
 * @param {string} name - Backend name.
 * @param {Function} factory - Factory forwarded to `ENGINE.registerBackend`.
 * @param {number} [priority=1] - Used only when explicitly provided (undefined -> 1).
 * @returns Whatever `ENGINE.registerBackend` returns.
 */
function registerBackend(name, factory, priority) {
  const effectivePriority = priority === void 0 ? 1 : priority;
  return ENGINE.registerBackend(name, factory, effectivePriority);
}
/** @returns The engine's current `backend` property. */
function backend() {
  const currentBackend = ENGINE.backend;
  return currentBackend;
}
/**
 * Installs `platform` under `platformName` on the environment returned by `env()`.
 * @param {string} platformName - Platform identifier.
 * @param platform - Platform implementation object.
 */
function setPlatform(platformName, platform) {
  const environment = env();
  environment.setPlatform(platformName, platform);
}
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Element-wise addition of `a` and `b`.
 *
 * Both inputs go through `convertToTensor`, then `makeTypesMatch`. The
 * fallback forward function calls `backend.add` and passes both inputs to
 * `save`. The call is dispatched via `ENGINE.runKernelFunc` under kernel `Add`.
 */
function add_(a, b) {
  const matched = makeTypesMatch(convertToTensor(a, "a", "add"), convertToTensor(b, "b", "add"));
  const lhs = matched[0];
  const rhs = matched[1];
  const forward = (backend2, save) => {
    const sum = backend2.add(lhs, rhs);
    save([lhs, rhs]);
    return sum;
  };
  return ENGINE.runKernelFunc(forward, {a: lhs, b: rhs}, null, Add);
}
var add$1 = op({add_});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Element-wise floor division of `a` by `b`.
 *
 * The inputs are converted and dtype-matched. The fallback forward function
 * calls `backend.floorDiv` and passes both inputs to `save`. Dispatched under
 * kernel `FloorDiv`.
 */
function floorDiv_(a, b) {
  const matched = makeTypesMatch(convertToTensor(a, "a", "floorDiv"), convertToTensor(b, "b", "floorDiv"));
  const numerator = matched[0];
  const denominator = matched[1];
  const forward = (backend2, save) => {
    const quotient = backend2.floorDiv(numerator, denominator);
    save([numerator, denominator]);
    return quotient;
  };
  return ENGINE.runKernelFunc(forward, {a: numerator, b: denominator}, null, FloorDiv);
}
var floorDiv = op({floorDiv_});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Element-wise division of `a` by `b`.
 *
 * After dtype matching, two int32 operands are routed to `floorDiv`.
 * Otherwise the fallback forward function calls `backend.realDivide` and
 * passes both inputs to `save`. Dispatched under kernel `Div` with empty attrs.
 */
function div_(a, b) {
  const matched = makeTypesMatch(convertToTensor(a, "a", "div"), convertToTensor(b, "b", "div"));
  const numerator = matched[0];
  const denominator = matched[1];
  const bothInt32 = numerator.dtype === "int32" && denominator.dtype === "int32";
  if (bothInt32) {
    return floorDiv(numerator, denominator);
  }
  const forward = (backend2, save) => {
    const quotient = backend2.realDivide(numerator, denominator);
    save([numerator, denominator]);
    return quotient;
  };
  return ENGINE.runKernelFunc(forward, {a: numerator, b: denominator}, null, Div, {});
}
var div = op({div_});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Element-wise multiplication of `a` and `b`.
 *
 * The inputs are converted and dtype-matched. The fallback forward function
 * calls `backend.multiply` and passes both inputs to `save`. Dispatched under
 * kernel `Multiply`.
 */
function mul_(a, b) {
  const matched = makeTypesMatch(convertToTensor(a, "a", "mul"), convertToTensor(b, "b", "mul"));
  const lhs = matched[0];
  const rhs = matched[1];
  const forward = (backend2, save) => {
    const product = backend2.multiply(lhs, rhs);
    save([lhs, rhs]);
    return product;
  };
  return ENGINE.runKernelFunc(forward, {a: lhs, b: rhs}, null, Multiply);
}
var mul = op({mul_});
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Element-wise absolute value.
 *
 * The fallback forward function saves the input first, then uses
 * `backend.complexAbs` for complex64 tensors and `backend.abs` for everything
 * else. Dispatched under kernel `Abs`.
 */
function abs_(x) {
  const input = convertToTensor(x, "x", "abs");
  const forward = (backend2, save) => {
    save([input]);
    return input.dtype === "complex64" ? backend2.complexAbs(input) : backend2.abs(input);
  };
  return ENGINE.runKernelFunc(forward, {x: input}, null, Abs);
}
var abs = op({abs_});
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Element-wise arccosine. The fallback forward function calls `backend.acos`
 * and then saves the input. Dispatched under kernel `Acos`.
 */
function acos_(x) {
  const input = convertToTensor(x, "x", "acos");
  const forward = (backend2, save) => {
    const output = backend2.acos(input);
    save([input]);
    return output;
  };
  return ENGINE.runKernelFunc(forward, {x: input}, null, Acos);
}
var acos = op({acos_});
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Element-wise inverse hyperbolic cosine. The fallback forward function calls
 * `backend.acosh` and then saves the input. Dispatched under kernel `Acosh`.
 */
function acosh_(x) {
  const input = convertToTensor(x, "x", "acosh");
  const forward = (backend2, save) => {
    const output = backend2.acosh(input);
    save([input]);
    return output;
  };
  return ENGINE.runKernelFunc(forward, {x: input}, null, Acosh);
}
var acosh = op({acosh_});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Adds a list of tensors element-wise.
 *
 * `tensors` must be a non-empty array. Every entry is converted with
 * `convertToTensor`. All entries are checked for a dtype matching the first
 * (across the whole list), and only then for a matching shape. The fallback
 * forward function calls `backend.addN` and saves all inputs. Dispatched under
 * kernel `AddN`, with the converted array itself as the inputs.
 *
 * @throws {Error} On mixed dtypes or mixed shapes.
 */
function addN_(tensors) {
  assert(Array.isArray(tensors), () => "The argument passed to tf.addN() must be a list of tensors");
  assert(tensors.length >= 1, () => "Must pass at least one tensor to tf.addN(), but got " + tensors.length);
  const converted = tensors.map((t, i) => convertToTensor(t, "tensors" + i, "addN"));
  const first = converted[0];
  // The dtype pass runs to completion before the shape pass, as in the original.
  if (converted.some((t) => t.dtype !== first.dtype)) {
    throw new Error("All tensors passed to tf.addN() must have the same dtype");
  }
  if (converted.some((t) => !arraysEqual(t.shape, first.shape))) {
    throw new Error("All tensors passed to tf.addN() must have the same shape");
  }
  const forward = (backend2, save) => {
    const total = backend2.addN(converted);
    save(converted);
    return total;
  };
  return ENGINE.runKernelFunc(forward, converted, null, AddN);
}
var addN = op({addN_});
/**
* @license
* Copyright 2017 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Reports whether `axes` is exactly the trailing run of dimensions of a
 * rank-`rank` tensor, in ascending order, i.e.
 * `[rank - n, ..., rank - 1]` for n = axes.length. An empty list qualifies.
 *
 * @param {number[]} axes
 * @param {number} rank
 * @returns {boolean}
 */
function axesAreInnerMostDims(axes, rank) {
  const n = axes.length;
  for (let j = 0; j < n; j++) {
    if (axes[j] !== rank - n + j) {
      return false;
    }
  }
  return true;
}
/**
 * Interleaves two coordinate lists into a full-rank location.
 *
 * Positions listed in `axes` take successive entries from `reduceLoc`; all
 * other positions take successive entries from `outputLoc`.
 *
 * @param {number[]} outputLoc - Coordinates for the non-reduced dimensions.
 * @param {number[]} reduceLoc - Coordinates for the reduced dimensions.
 * @param {number[]} axes - Which full-rank positions are reduced.
 * @returns {number[]} Location of length outputLoc.length + reduceLoc.length.
 */
function combineLocations(outputLoc, reduceLoc, axes) {
  const totalRank = outputLoc.length + reduceLoc.length;
  const combined = [];
  let nextOut = 0;
  let nextReduce = 0;
  for (let d = 0; d < totalRank; d++) {
    combined.push(axes.includes(d) ? reduceLoc[nextReduce++] : outputLoc[nextOut++]);
  }
  return combined;
}
/**
 * Splits a shape into the dimensions kept by a reduction and the ones reduced.
 *
 * @param {number[]} aShape - Full input shape.
 * @param {number[]} axes - Dimensions being reduced.
 * @returns {[number[], number[]]} `[outShape, reduceShape]`. `outShape` keeps
 *     the non-reduced sizes in original order; `reduceShape` lists the
 *     reduced sizes in the order given by `axes`.
 */
function computeOutAndReduceShapes(aShape, axes) {
  const keptDims = [];
  for (let d = 0; d < aShape.length; d++) {
    if (!axes.includes(d)) {
      keptDims.push(aShape[d]);
    }
  }
  const reducedDims = axes.map((d) => aShape[d]);
  return [keptDims, reducedDims];
}
/**
 * Re-inserts size-1 dimensions at the positions in `axes`, turning a reduced
 * shape back into a keepDims-style shape.
 *
 * @param {number[]} shape - The reduced (output) shape.
 * @param {number[]} axes - Positions of the reduced dimensions in the full rank.
 * @returns {number[]}
 */
function expandShapeToKeepDim(shape, axes) {
  const ones = axes.map(() => 1);
  return combineLocations(shape, ones, axes);
}
/**
 * Fails via `assert` unless `axes` are the inner-most dimensions for `rank`
 * (as decided by `axesAreInnerMostDims`).
 *
 * @param {string} msg - Op name used as the prefix of the failure message.
 * @param {number[]} axes
 * @param {number} rank
 */
function assertAxesAreInnerMostDims(msg, axes, rank) {
  const innerMost = axesAreInnerMostDims(axes, rank);
  assert(innerMost, () => msg + " supports only inner-most axes for now. " + ("Got axes " + axes + " and rank-" + rank + " input."));
}
/**
 * Builds a transpose permutation that moves `axes` to the end.
 *
 * @param {number[]} axes - Dimensions to move.
 * @param {number} rank - Rank of the tensor.
 * @returns {number[]|null} `null` when `axes` are already the inner-most
 *     dimensions. Otherwise the other dimensions in ascending order, followed
 *     by `axes` in the given order.
 */
function getAxesPermutation(axes, rank) {
  if (axesAreInnerMostDims(axes, rank)) {
    return null;
  }
  const permutation = [];
  for (let d = 0; d < rank; ++d) {
    if (!axes.includes(d)) {
      permutation.push(d);
    }
  }
  axes.forEach((axis) => {
    permutation.push(axis);
  });
  return permutation;
}
/**
 * Returns the positions of `axes` ordered by ascending axis value. For a
 * permutation this is its inverse, i.e. the permutation that undoes it.
 *
 * @param {number[]} axes
 * @returns {number[]}
 */
function getUndoAxesPermutation(axes) {
  const pairs = axes.map((axis, position) => [position, axis]);
  pairs.sort((p, q) => p[1] - q[1]);
  return pairs.map((pair) => pair[0]);
}
/**
 * Lists the last `numAxes` dimension indices of a rank-`rank` tensor,
 * i.e. `[rank - numAxes, ..., rank - 1]`.
 *
 * @param {number} numAxes
 * @param {number} rank
 * @returns {number[]}
 */
function getInnerMostAxes(numAxes, rank) {
  const axes = [];
  let next = rank - numAxes;
  while (next < rank) {
    axes.push(next);
    next++;
  }
  return axes;
}
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Logical AND reduction.
 *
 * `x` is converted with a "bool" dtype hint. In the fallback forward function
 * the axes are parsed with `parseAxisParam`. If they are not already
 * inner-most, a transposed copy of the input is reduced over the equivalent
 * trailing axes instead. With `keepDims`, the result is reshaped to
 * re-insert size-1 dimensions at the original axes. Dispatched under kernel
 * `All` with attrs `{axis, keepDims}`.
 *
 * @param x - Input.
 * @param {number|number[]|null} [axis=null] - Axes to reduce; null is handed to `parseAxisParam` as-is.
 * @param {boolean} [keepDims=false]
 */
function all_(x, axis = null, keepDims = false) {
  const input = convertToTensor(x, "x", "all", "bool");
  const forward = (backend2) => {
    const origAxes = parseAxisParam(axis, input.shape);
    let reduceAxes = origAxes;
    let operand = input;
    const perm = getAxesPermutation(reduceAxes, operand.rank);
    if (perm != null) {
      operand = transpose(operand, perm);
      reduceAxes = getInnerMostAxes(reduceAxes.length, operand.rank);
    }
    const reduced = backend2.all(operand, reduceAxes);
    if (!keepDims) {
      return reduced;
    }
    return reshape(reduced, expandShapeToKeepDim(reduced.shape, origAxes));
  };
  return ENGINE.runKernelFunc(forward, {x: input}, null, All, {axis, keepDims});
}
var all = op({all_});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Logical OR reduction.
 *
 * Same flow as `all_`: `x` is converted with a "bool" dtype hint, and
 * non-inner-most axes are handled by transposing and reducing the
 * equivalent trailing axes. With `keepDims`, the result is reshaped back to
 * the full rank. Dispatched under kernel `Any` with attrs `{axis, keepDims}`.
 *
 * @param x - Input.
 * @param {number|number[]|null} [axis=null] - Axes to reduce; null is handed to `parseAxisParam` as-is.
 * @param {boolean} [keepDims=false]
 */
function any_(x, axis = null, keepDims = false) {
  const input = convertToTensor(x, "x", "any", "bool");
  const forward = (backend2) => {
    const origAxes = parseAxisParam(axis, input.shape);
    let reduceAxes = origAxes;
    let operand = input;
    const perm = getAxesPermutation(reduceAxes, operand.rank);
    if (perm != null) {
      operand = transpose(operand, perm);
      reduceAxes = getInnerMostAxes(reduceAxes.length, operand.rank);
    }
    const reduced = backend2.any(operand, reduceAxes);
    if (!keepDims) {
      return reduced;
    }
    return reshape(reduced, expandShapeToKeepDim(reduced.shape, origAxes));
  };
  return ENGINE.runKernelFunc(forward, {x: input}, null, Any, {axis, keepDims});
}
var any = op({any_});
/**
* @license
* Copyright 2020 Google Inc. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Index of the maximum value along one axis.
 *
 * The fallback forward function saves the original input, parses the axis,
 * and, if that axis is not inner-most, transposes so that it becomes the
 * last dimension before calling `backend.argMax`. Dispatched under kernel
 * `ArgMax` with attrs `{axis}` (the value as passed, after the undefined
 * default).
 *
 * An explicit `null` axis is treated as 0 inside the forward function, the
 * same as `argMin_` does. Previously `null` went straight to `parseAxisParam`
 * here while `argMin_` mapped it to 0, so the two ops disagreed on the same
 * input.
 *
 * @param x - Input.
 * @param {number} [axis=0] - Axis to reduce.
 */
function argMax_(x, axis) {
  if (axis === void 0) {
    axis = 0;
  }
  var $x = convertToTensor(x, "x", "argMax");
  var forward = function(backend2, save) {
    save([$x]);
    // Keep null handling consistent with argMin_.
    var effectiveAxis = axis == null ? 0 : axis;
    var axes = parseAxisParam(effectiveAxis, $x.shape);
    var operand = $x;
    var permutedAxes = getAxesPermutation(axes, operand.rank);
    if (permutedAxes != null) {
      operand = transpose(operand, permutedAxes);
      axes = getInnerMostAxes(axes.length, operand.rank);
    }
    return backend2.argMax(operand, axes[0]);
  };
  var inputs = {x: $x};
  var attrs = {axis};
  return ENGINE.runKernelFunc(forward, inputs, null, ArgMax, attrs);
}
var argMax = op({argMax_});
/**
* @license
* Copyright 2020 Google Inc. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Index of the minimum value along one axis.
 *
 * The fallback forward function saves the original input and treats an
 * explicit `null` axis as 0. It then transposes the reduced axis to the end
 * if needed and calls `backend.argMin`. Dispatched under kernel `ArgMin` with
 * attrs `{axis}` (captured before the forward function runs).
 *
 * @param x - Input.
 * @param {number} [axis=0] - Axis to reduce.
 */
function argMin_(x, axis = 0) {
  const input = convertToTensor(x, "x", "argMin");
  const forward = (backend2, save) => {
    save([input]);
    const effectiveAxis = axis == null ? 0 : axis;
    let axes = parseAxisParam(effectiveAxis, input.shape);
    let operand = input;
    const perm = getAxesPermutation(axes, operand.rank);
    if (perm != null) {
      operand = transpose(operand, perm);
      axes = getInnerMostAxes(axes.length, operand.rank);
    }
    return backend2.argMin(operand, axes[0]);
  };
  return ENGINE.runKernelFunc(forward, {x: input}, null, ArgMin, {axis});
}
var argMin = op({argMin_});
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Element-wise arcsine. The fallback forward function calls `backend.asin`
 * and then saves the input. Dispatched under kernel `Asin`.
 */
function asin_(x) {
  const input = convertToTensor(x, "x", "asin");
  const forward = (backend2, save) => {
    const output = backend2.asin(input);
    save([input]);
    return output;
  };
  return ENGINE.runKernelFunc(forward, {x: input}, null, Asin);
}
var asin = op({asin_});
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Element-wise inverse hyperbolic sine. The fallback forward function calls
 * `backend.asinh` and then saves the input. Dispatched under kernel `Asinh`.
 */
function asinh_(x) {
  const input = convertToTensor(x, "x", "asinh");
  const forward = (backend2, save) => {
    const output = backend2.asinh(input);
    save([input]);
    return output;
  };
  return ENGINE.runKernelFunc(forward, {x: input}, null, Asinh);
}
var asinh = op({asinh_});
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Element-wise arctangent. The fallback forward function calls `backend.atan`
 * and then saves the input. Dispatched under kernel `Atan`.
 */
function atan_(x) {
  const input = convertToTensor(x, "x", "atan");
  const forward = (backend2, save) => {
    const output = backend2.atan(input);
    save([input]);
    return output;
  };
  return ENGINE.runKernelFunc(forward, {x: input}, null, Atan);
}
var atan = op({atan_});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Element-wise two-argument arctangent of `a` and `b`.
 *
 * The inputs are converted and dtype-matched. The fallback forward function
 * calls `backend.atan2` and passes both inputs to `save`. Dispatched under
 * kernel `Atan2`.
 */
function atan2_(a, b) {
  const matched = makeTypesMatch(convertToTensor(a, "a", "atan2"), convertToTensor(b, "b", "atan2"));
  const first = matched[0];
  const second = matched[1];
  const forward = (backend2, save) => {
    const angle = backend2.atan2(first, second);
    save([first, second]);
    return angle;
  };
  return ENGINE.runKernelFunc(forward, {a: first, b: second}, null, Atan2);
}
var atan2 = op({atan2_});
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Element-wise inverse hyperbolic tangent. The fallback forward function calls
 * `backend.atanh` and then saves the input. Dispatched under kernel `Atanh`.
 */
function atanh_(x) {
  const input = convertToTensor(x, "x", "atanh");
  const forward = (backend2, save) => {
    const output = backend2.atanh(input);
    save([input]);
    return output;
  };
  return ENGINE.runKernelFunc(forward, {x: input}, null, Atanh);
}
var atanh = op({atanh_});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Shape info for a 2D dilation op, built on `computeConv2DInfo`.
 *
 * The filter shape is extended with the input's channel count
 * (`inputShape[3]`), and `dataFormat` is translated with
 * `convertConv2DDataFormat`. Rounding mode and depthwise are passed as null.
 *
 * @param {number[]} inputShape - Input shape; index 3 is read as the channel count.
 * @param {number[]} filterShape
 * @param strides
 * @param pad2 - Padding spec forwarded to computeConv2DInfo.
 * @param {string} [dataFormat="NHWC"]
 * @param dilations
 */
function computeDilation2DInfo(inputShape, filterShape, strides, pad2, dataFormat, dilations) {
  const format = dataFormat === void 0 ? "NHWC" : dataFormat;
  const channelCount = inputShape[3];
  const extendedFilterShape = filterShape.concat([channelCount]);
  const convFormat = convertConv2DDataFormat(format);
  return computeConv2DInfo(inputShape, extendedFilterShape, strides, dilations, pad2, null, null, convFormat);
}
/**
 * Shape info for a 2D pooling op, expressed as a conv2D with a
 * `[filterHeight, filterWidth, C, C]` filter, where C is the channel count
 * read from the position `dataFormat` dictates.
 *
 * @param {number[]} inShape
 * @param {number|number[]} filterSize - Parsed with `parseTupleParam`.
 * @param strides
 * @param dilations
 * @param pad2
 * @param roundingMode
 * @param {string} [dataFormat="channelsLast"] - "channelsLast" or "channelsFirst".
 * @throws {Error} For any other dataFormat.
 */
function computePool2DInfo(inShape, filterSize, strides, dilations, pad2, roundingMode, dataFormat = "channelsLast") {
  const window = parseTupleParam(filterSize);
  const filterHeight = window[0];
  const filterWidth = window[1];
  let channels;
  switch (dataFormat) {
    case "channelsLast":
      channels = inShape[3];
      break;
    case "channelsFirst":
      channels = inShape[1];
      break;
    default:
      throw new Error("Unknown dataFormat " + dataFormat);
  }
  const filterShape = [filterHeight, filterWidth, channels, channels];
  return computeConv2DInfo(inShape, filterShape, strides, dilations, pad2, roundingMode, false, dataFormat);
}
/**
 * Shape info for a 3D pooling op, expressed as a conv3D with a
 * `[filterDepth, filterHeight, filterWidth, C, C]` filter.
 *
 * "NDHWC" maps to "channelsLast" (C = inShape[4]). "NCDHW" maps to
 * "channelsFirst" (C = inShape[1]).
 *
 * @param {number[]} inShape
 * @param {number|number[]} filterSize - Parsed with `parse3TupleParam`.
 * @param strides
 * @param dilations
 * @param pad2
 * @param roundingMode
 * @param {string} [dataFormat="NDHWC"]
 * @throws {Error} For any other dataFormat.
 */
function computePool3DInfo(inShape, filterSize, strides, dilations, pad2, roundingMode, dataFormat = "NDHWC") {
  const window = parse3TupleParam(filterSize);
  const filterDepth = window[0];
  const filterHeight = window[1];
  const filterWidth = window[2];
  let convFormat;
  let channels;
  switch (dataFormat) {
    case "NDHWC":
      convFormat = "channelsLast";
      channels = inShape[4];
      break;
    case "NCDHW":
      convFormat = "channelsFirst";
      channels = inShape[1];
      break;
    default:
      throw new Error("Unknown dataFormat " + dataFormat);
  }
  const filterShape = [filterDepth, filterHeight, filterWidth, channels, channels];
  return computeConv3DInfo(inShape, filterShape, strides, dilations, pad2, false, convFormat, roundingMode);
}
/**
 * Collects the geometry of a 2D convolution into one info object.
 *
 * The input dimensions are read according to `dataFormat`. Strides and
 * dilations go through `parseTupleParam`. Dilation is folded into the
 * effective filter size, and padding/output size come from
 * `getPadAndOutInfo`. For depthwise convs the output channel count is
 * filterShape[3] * inChannels; otherwise it is filterShape[3].
 *
 * @param {number[]} inShape - 4D input shape.
 * @param {number[]} filterShape - Indices 0, 1 and 3 are read as height, width and channel multiplier/out channels.
 * @param strides
 * @param dilations
 * @param pad2 - Padding spec passed to getPadAndOutInfo.
 * @param roundingMode
 * @param {boolean} [depthwise=false]
 * @param {string} [dataFormat="channelsLast"] - "channelsLast" or "channelsFirst".
 * @throws {Error} For any other dataFormat.
 */
function computeConv2DInfo(inShape, filterShape, strides, dilations, pad2, roundingMode, depthwise = false, dataFormat = "channelsLast") {
  let batchSize;
  let inHeight;
  let inWidth;
  let inChannels;
  switch (dataFormat) {
    case "channelsLast":
      batchSize = inShape[0];
      inHeight = inShape[1];
      inWidth = inShape[2];
      inChannels = inShape[3];
      break;
    case "channelsFirst":
      batchSize = inShape[0];
      inChannels = inShape[1];
      inHeight = inShape[2];
      inWidth = inShape[3];
      break;
    default:
      throw new Error("Unknown dataFormat " + dataFormat);
  }
  const filterHeight = filterShape[0];
  const filterWidth = filterShape[1];
  const filterChannels = filterShape[3];
  const strideTuple = parseTupleParam(strides);
  const strideHeight = strideTuple[0];
  const strideWidth = strideTuple[1];
  const dilationTuple = parseTupleParam(dilations);
  const dilationHeight = dilationTuple[0];
  const dilationWidth = dilationTuple[1];
  const effectiveFilterHeight = getEffectiveFilterSize(filterHeight, dilationHeight);
  const effectiveFilterWidth = getEffectiveFilterSize(filterWidth, dilationWidth);
  const {padInfo, outHeight, outWidth} = getPadAndOutInfo(pad2, inHeight, inWidth, strideHeight, strideWidth, effectiveFilterHeight, effectiveFilterWidth, roundingMode, dataFormat);
  const outChannels = depthwise ? filterChannels * inChannels : filterChannels;
  // dataFormat is one of the two valid values here; anything else threw above.
  const outShape = dataFormat === "channelsFirst"
    ? [batchSize, outChannels, outHeight, outWidth]
    : [batchSize, outHeight, outWidth, outChannels];
  return {
    batchSize,
    dataFormat,
    inHeight,
    inWidth,
    inChannels,
    outHeight,
    outWidth,
    outChannels,
    padInfo,
    strideHeight,
    strideWidth,
    filterHeight,
    filterWidth,
    effectiveFilterHeight,
    effectiveFilterWidth,
    dilationHeight,
    dilationWidth,
    inShape,
    outShape,
    filterShape
  };
}
/**
 * Collects the geometry of a 3D convolution into one info object.
 *
 * The input dimensions are read according to `dataFormat`. Strides and
 * dilations go through `parse3TupleParam`. Dilation is folded into the
 * effective filter sizes, and padding/output size come from
 * `get3DPadAndOutInfo`. For depthwise convs the output channel count is
 * filterShape[4] * inChannels; otherwise it is filterShape[4].
 *
 * @param {number[]} inShape - 5D input shape.
 * @param {number[]} filterShape - Indices 0, 1, 2 and 4 are read as depth, height, width and channel count.
 * @param strides
 * @param dilations
 * @param pad2 - Padding spec passed to get3DPadAndOutInfo.
 * @param {boolean} [depthwise=false]
 * @param {string} [dataFormat="channelsLast"] - "channelsLast" or "channelsFirst".
 * @param roundingMode
 * @throws {Error} For any other dataFormat.
 */
function computeConv3DInfo(inShape, filterShape, strides, dilations, pad2, depthwise = false, dataFormat = "channelsLast", roundingMode) {
  let batchSize;
  let inDepth;
  let inHeight;
  let inWidth;
  let inChannels;
  switch (dataFormat) {
    case "channelsLast":
      batchSize = inShape[0];
      inDepth = inShape[1];
      inHeight = inShape[2];
      inWidth = inShape[3];
      inChannels = inShape[4];
      break;
    case "channelsFirst":
      batchSize = inShape[0];
      inChannels = inShape[1];
      inDepth = inShape[2];
      inHeight = inShape[3];
      inWidth = inShape[4];
      break;
    default:
      throw new Error("Unknown dataFormat " + dataFormat);
  }
  const filterDepth = filterShape[0];
  const filterHeight = filterShape[1];
  const filterWidth = filterShape[2];
  const filterChannels = filterShape[4];
  const strideTriple = parse3TupleParam(strides);
  const strideDepth = strideTriple[0];
  const strideHeight = strideTriple[1];
  const strideWidth = strideTriple[2];
  const dilationTriple = parse3TupleParam(dilations);
  const dilationDepth = dilationTriple[0];
  const dilationHeight = dilationTriple[1];
  const dilationWidth = dilationTriple[2];
  const effectiveFilterDepth = getEffectiveFilterSize(filterDepth, dilationDepth);
  const effectiveFilterHeight = getEffectiveFilterSize(filterHeight, dilationHeight);
  const effectiveFilterWidth = getEffectiveFilterSize(filterWidth, dilationWidth);
  const {padInfo, outDepth, outHeight, outWidth} = get3DPadAndOutInfo(pad2, inDepth, inHeight, inWidth, strideDepth, strideHeight, strideWidth, effectiveFilterDepth, effectiveFilterHeight, effectiveFilterWidth, roundingMode);
  const outChannels = depthwise ? filterChannels * inChannels : filterChannels;
  // dataFormat is one of the two valid values here; anything else threw above.
  const outShape = dataFormat === "channelsFirst"
    ? [batchSize, outChannels, outDepth, outHeight, outWidth]
    : [batchSize, outDepth, outHeight, outWidth, outChannels];
  return {
    batchSize,
    dataFormat,
    inDepth,
    inHeight,
    inWidth,
    inChannels,
    outDepth,
    outHeight,
    outWidth,
    outChannels,
    padInfo,
    strideDepth,
    strideHeight,
    strideWidth,
    filterDepth,
    filterHeight,
    filterWidth,
    effectiveFilterDepth,
    effectiveFilterHeight,
    effectiveFilterWidth,
    dilationDepth,
    dilationHeight,
    dilationWidth,
    inShape,
    outShape,
    filterShape
  };
}
/**
 * Output [rows, cols] of a 2D windowed op with symmetric zero padding:
 * `round((in - field + 2 * pad) / stride + 1)` per dimension, using
 * `conditionalRound` with `roundingMode`.
 *
 * `fieldSize` and `stride` may each be a single number, applied to both
 * dimensions exactly as before, or a `[rows, cols]` pair for non-square
 * windows/strides. Note that the numeric-pad branch of `getPadAndOutInfo`
 * currently passes only the height filter and stride, so callers that need
 * different column parameters must pass pairs.
 *
 * @param {number[]} inShape - `[inputRows, inputCols]`.
 * @param {number|number[]} fieldSize - Window size, or `[rows, cols]`.
 * @param {number|number[]} stride - Stride, or `[rows, cols]`.
 * @param {number} [zeroPad] - Padding per side. When null/undefined it is
 *     taken from `computeDefaultPad` using the row window and row stride.
 * @param {string} [roundingMode] - Forwarded to `conditionalRound`.
 * @returns {number[]} `[outputRows, outputCols]`.
 * @throws If either output size is not an integer (via `assert`).
 */
function computeOutputShape2D(inShape, fieldSize, stride, zeroPad, roundingMode) {
  const fieldPair = Array.isArray(fieldSize) ? fieldSize : [fieldSize, fieldSize];
  const stridePair = Array.isArray(stride) ? stride : [stride, stride];
  const pad = zeroPad == null ? computeDefaultPad(inShape, fieldPair[0], stridePair[0]) : zeroPad;
  const inputRows = inShape[0];
  const inputCols = inShape[1];
  const outputRows = conditionalRound((inputRows - fieldPair[0] + 2 * pad) / stridePair[0] + 1, roundingMode);
  assert(isInt(outputRows), function() {
    return "The output # of rows (" + outputRows + ") must be an integer. Change the stride and/or zero pad parameters";
  });
  const outputCols = conditionalRound((inputCols - fieldPair[1] + 2 * pad) / stridePair[1] + 1, roundingMode);
  assert(isInt(outputCols), function() {
    return "The output # of columns (" + outputCols + ") must be an integer. Change the stride and/or zero pad parameters";
  });
  return [outputRows, outputCols];
}
/**
 * Output [depth, rows, cols, outChannels] of a 3D windowed op with a cubic
 * window and a single stride: each spatial size is
 * `round((in - field + 2 * pad) / stride + 1)` via `conditionalRound`.
 *
 * @param {number[]} inShape - `[inputDepth, inputRows, inputCols]`.
 * @param {number} fieldSize
 * @param {number} outChannels - Copied into the last slot of the result.
 * @param {number} stride
 * @param {number} [zeroPad] - Defaults to `computeDefaultPad(inShape, fieldSize, stride)` when null/undefined.
 * @param {string} [roundingMode]
 * @throws If any spatial output size is not an integer (via `assert`).
 */
function computeOutputShape4D(inShape, fieldSize, outChannels, stride, zeroPad, roundingMode) {
  const pad = zeroPad == null ? computeDefaultPad(inShape, fieldSize, stride) : zeroPad;
  const outSize = (extent) => conditionalRound((extent - fieldSize + 2 * pad) / stride + 1, roundingMode);
  const outputDepths = outSize(inShape[0]);
  assert(isInt(outputDepths), () => "The output # of depths (" + outputDepths + ") must be an integer. Change the stride and/or zero pad parameters");
  const outputRows = outSize(inShape[1]);
  assert(isInt(outputRows), () => "The output # of rows (" + outputRows + ") must be an integer. Change the stride and/or zero pad parameters");
  const outputCols = outSize(inShape[2]);
  assert(isInt(outputCols), () => "The output # of columns (" + outputCols + ") must be an integer. Change the stride and/or zero pad parameters");
  return [outputDepths, outputRows, outputCols, outChannels];
}
/**
 * Default symmetric padding derived from the first input dimension:
 * `floor((in * (stride - 1) - stride + effectiveField) / 2)`, where
 * effectiveField accounts for `dilation` via `getEffectiveFilterSize`.
 *
 * @param {number[]} inputShape - Only index 0 is read.
 * @param {number} fieldSize
 * @param {number} stride
 * @param {number} [dilation=1]
 * @returns {number}
 */
function computeDefaultPad(inputShape, fieldSize, stride, dilation = 1) {
  const span = getEffectiveFilterSize(fieldSize, dilation);
  const numerator = inputShape[0] * (stride - 1) - stride + span;
  return Math.floor(numerator / 2);
}
/**
 * Normalizes a size/stride/dilation parameter.
 *
 * A number n becomes `[n, n, n]`. A length-2 array `[h, w]` becomes
 * `[h, w, 1]`. Anything else is returned unchanged (the same reference).
 *
 * @param {number|number[]} param
 * @returns {number[]}
 */
function parseTupleParam(param) {
  if (typeof param === "number") {
    return [param, param, param];
  }
  return param.length === 2 ? [param[0], param[1], 1] : param;
}
/**
 * Expands a number n to `[n, n, n]`; any other value is returned unchanged.
 *
 * @param {number|number[]} param
 * @returns {number[]}
 */
function parse3TupleParam(param) {
  if (typeof param !== "number") {
    return param;
  }
  return [param, param, param];
}
/**
 * Span covered by a dilated filter: a dilation d inserts (d - 1) gaps between
 * each of the (filterSize - 1) adjacent taps. Dilations <= 1 leave the size
 * unchanged.
 */
function getEffectiveFilterSize(filterSize, dilation) {
  return dilation <= 1 ? filterSize : filterSize + (filterSize - 1) * (dilation - 1);
}
/**
 * Resolves a 2D padding spec into per-side padding plus the output height and
 * width of a windowed (conv/pool) operation.
 *
 * @param {number|string|Array} pad2 - symmetric number, "same", "valid", or
 *   explicit [before, after] pairs per dimension (spatial dims at indices 1-2
 *   for "channelsLast", 2-3 otherwise).
 * @param {number} inHeight - input height.
 * @param {number} inWidth - input width.
 * @param {number} strideHeight - vertical stride.
 * @param {number} strideWidth - horizontal stride.
 * @param {number} filterHeight - window height.
 * @param {number} filterWidth - window width.
 * @param {string} [roundingMode] - forwarded to conditionalRound for numeric and
 *   explicit padding.
 * @param {string} [dataFormat] - "channelsLast" or another value; only read for
 *   explicit padding.
 * @returns {{padInfo: Object, outHeight: number, outWidth: number}}
 * @throws {Error} for an unknown padding parameter; assertion failure when numeric
 *   padding gives a non-integer output size.
 */
function getPadAndOutInfo(pad2, inHeight, inWidth, strideHeight, strideWidth, filterHeight, filterWidth, roundingMode, dataFormat) {
  var padInfo;
  var outHeight;
  var outWidth;
  if (typeof pad2 === "number") {
    var padType = pad2 === 0 ? "VALID" : "NUMBER";
    padInfo = {top: pad2, bottom: pad2, left: pad2, right: pad2, type: padType};
    // Height and width each use their own filter size and stride so that
    // non-square windows or strides get a correct output width.
    outHeight = conditionalRound((inHeight - filterHeight + 2 * pad2) / strideHeight + 1, roundingMode);
    assert(isInt(outHeight), function() {
      return "The output # of rows (" + outHeight + ") must be an integer. Change the stride and/or zero pad parameters";
    });
    outWidth = conditionalRound((inWidth - filterWidth + 2 * pad2) / strideWidth + 1, roundingMode);
    assert(isInt(outWidth), function() {
      return "The output # of columns (" + outWidth + ") must be an integer. Change the stride and/or zero pad parameters";
    });
  } else if (pad2 === "same") {
    outHeight = Math.ceil(inHeight / strideHeight);
    outWidth = Math.ceil(inWidth / strideWidth);
    // Total padding needed, clamped at zero; any odd remainder goes after.
    var padAlongHeight = Math.max(0, (outHeight - 1) * strideHeight + filterHeight - inHeight);
    var padAlongWidth = Math.max(0, (outWidth - 1) * strideWidth + filterWidth - inWidth);
    var sameTop = Math.floor(padAlongHeight / 2);
    var sameLeft = Math.floor(padAlongWidth / 2);
    padInfo = {
      top: sameTop,
      bottom: padAlongHeight - sameTop,
      left: sameLeft,
      right: padAlongWidth - sameLeft,
      type: "SAME"
    };
  } else if (pad2 === "valid") {
    padInfo = {top: 0, bottom: 0, left: 0, right: 0, type: "VALID"};
    outHeight = Math.ceil((inHeight - filterHeight + 1) / strideHeight);
    outWidth = Math.ceil((inWidth - filterWidth + 1) / strideWidth);
  } else if (typeof pad2 === "object") {
    var channelsLast = dataFormat === "channelsLast";
    var explicitTop = channelsLast ? pad2[1][0] : pad2[2][0];
    var explicitBottom = channelsLast ? pad2[1][1] : pad2[2][1];
    var explicitLeft = channelsLast ? pad2[2][0] : pad2[3][0];
    var explicitRight = channelsLast ? pad2[2][1] : pad2[3][1];
    var explicitType = explicitTop === 0 && explicitBottom === 0 && explicitLeft === 0 && explicitRight === 0 ? "VALID" : "EXPLICIT";
    padInfo = {top: explicitTop, bottom: explicitBottom, left: explicitLeft, right: explicitRight, type: explicitType};
    outHeight = conditionalRound((inHeight - filterHeight + explicitTop + explicitBottom) / strideHeight + 1, roundingMode);
    outWidth = conditionalRound((inWidth - filterWidth + explicitLeft + explicitRight) / strideWidth + 1, roundingMode);
  } else {
    throw Error("Unknown padding parameter: " + pad2);
  }
  return {padInfo, outHeight, outWidth};
}
/**
 * Resolves a 3D padding spec into per-side padding plus the output depth,
 * height and width of a windowed (conv/pool) operation.
 *
 * @param {number|string} pad2 - symmetric number, "same" or "valid".
 * @param {number} inDepth - input depth.
 * @param {number} inHeight - input height.
 * @param {number} inWidth - input width.
 * @param {number} strideDepth - depth stride.
 * @param {number} strideHeight - vertical stride.
 * @param {number} strideWidth - horizontal stride.
 * @param {number} filterDepth - window depth.
 * @param {number} filterHeight - window height.
 * @param {number} filterWidth - window width.
 * @param {string} [roundingMode] - forwarded to conditionalRound for numeric padding.
 * @returns {{padInfo: Object, outDepth: number, outHeight: number, outWidth: number}}
 * @throws {Error} for an unknown padding parameter; assertion failure when numeric
 *   padding gives a non-integer output size.
 */
function get3DPadAndOutInfo(pad2, inDepth, inHeight, inWidth, strideDepth, strideHeight, strideWidth, filterDepth, filterHeight, filterWidth, roundingMode) {
  var padInfo;
  var outDepth;
  var outHeight;
  var outWidth;
  if (typeof pad2 === "number") {
    var padType = pad2 === 0 ? "VALID" : "NUMBER";
    padInfo = {
      top: pad2,
      bottom: pad2,
      left: pad2,
      right: pad2,
      front: pad2,
      back: pad2,
      type: padType
    };
    // Each spatial dimension uses its own filter size and stride so that
    // non-cubic windows or strides get correct output sizes.
    outDepth = conditionalRound((inDepth - filterDepth + 2 * pad2) / strideDepth + 1, roundingMode);
    assert(isInt(outDepth), function() {
      return "The output # of depths (" + outDepth + ") must be an integer. Change the stride and/or zero pad parameters";
    });
    outHeight = conditionalRound((inHeight - filterHeight + 2 * pad2) / strideHeight + 1, roundingMode);
    assert(isInt(outHeight), function() {
      return "The output # of rows (" + outHeight + ") must be an integer. Change the stride and/or zero pad parameters";
    });
    outWidth = conditionalRound((inWidth - filterWidth + 2 * pad2) / strideWidth + 1, roundingMode);
    assert(isInt(outWidth), function() {
      return "The output # of columns (" + outWidth + ") must be an integer. Change the stride and/or zero pad parameters";
    });
  } else if (pad2 === "same") {
    outDepth = Math.ceil(inDepth / strideDepth);
    outHeight = Math.ceil(inHeight / strideHeight);
    outWidth = Math.ceil(inWidth / strideWidth);
    // Total padding needed per dimension, clamped at zero like the 2D
    // getPadAndOutInfo: when the filter is smaller than the stride the raw
    // value can be negative, which would otherwise produce negative pads.
    var padAlongDepth = Math.max(0, (outDepth - 1) * strideDepth + filterDepth - inDepth);
    var padAlongHeight = Math.max(0, (outHeight - 1) * strideHeight + filterHeight - inHeight);
    var padAlongWidth = Math.max(0, (outWidth - 1) * strideWidth + filterWidth - inWidth);
    var front = Math.floor(padAlongDepth / 2);
    var sameTop = Math.floor(padAlongHeight / 2);
    var sameLeft = Math.floor(padAlongWidth / 2);
    padInfo = {
      top: sameTop,
      bottom: padAlongHeight - sameTop,
      left: sameLeft,
      right: padAlongWidth - sameLeft,
      front: front,
      back: padAlongDepth - front,
      type: "SAME"
    };
  } else if (pad2 === "valid") {
    padInfo = {
      top: 0,
      bottom: 0,
      left: 0,
      right: 0,
      front: 0,
      back: 0,
      type: "VALID"
    };
    outDepth = Math.ceil((inDepth - filterDepth + 1) / strideDepth);
    outHeight = Math.ceil((inHeight - filterHeight + 1) / strideHeight);
    outWidth = Math.ceil((inWidth - filterWidth + 1) / strideWidth);
  } else {
    throw Error("Unknown padding parameter: " + pad2);
  }
  return {padInfo, outDepth, outHeight, outWidth};
}
/**
 * Rounds `value` according to `roundingMode` ("round", "ceil" or "floor").
 * A falsy mode returns `value` untouched.
 *
 * @throws {Error} for any other non-empty mode.
 */
function conditionalRound(value, roundingMode) {
  if (!roundingMode) return value;
  if (roundingMode === "round") return Math.round(value);
  if (roundingMode === "ceil") return Math.ceil(value);
  if (roundingMode === "floor") return Math.floor(value);
  throw new Error("Unknown roundingMode " + roundingMode);
}
/**
 * True when the first three entries of parseTupleParam(param) are all 1.
 */
function tupleValuesAreOne(param) {
  const dims = parseTupleParam(param);
  return [0, 1, 2].every((index) => dims[index] === 1);
}
/**
 * True when the strides are all ones or, failing that, the dilations are.
 * The dilations are only inspected when the strides check fails.
 */
function eitherStridesOrDilationsAreOne(strides, dilations) {
  if (tupleValuesAreOne(strides)) {
    return true;
  }
  return tupleValuesAreOne(dilations);
}
/**
 * Maps a TF-style 2D data format string to the internal layout name.
 * "NHWC" -> "channelsLast", "NCHW" -> "channelsFirst".
 *
 * @throws {Error} for any other format.
 */
function convertConv2DDataFormat(dataFormat) {
  switch (dataFormat) {
    case "NHWC":
      return "channelsLast";
    case "NCHW":
      return "channelsFirst";
    default:
      throw new Error("Unknown dataFormat " + dataFormat);
  }
}
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * 2D average pooling of `x` (rank 3 or rank 4), run as the AvgPool kernel.
 *
 * A rank-3 input is reshaped to a batch of one for the kernel and the result is
 * reshaped back, so the output rank matches the input rank. The kernel result
 * is cast back to the converted input's dtype.
 *
 * @param x - input tensor or tensor-like (converted with "float32" as the
 *   fourth convertToTensor argument; presumably the parse dtype, confirm).
 * @param filterSize - pooling window, forwarded to computePool2DInfo.
 * @param strides - pooling strides, forwarded to computePool2DInfo.
 * @param pad2 - padding spec; must be an integer when dimRoundingMode is set.
 * @param dimRoundingMode - optional rounding mode for output dimensions.
 * @returns the pooled tensor.
 */
function avgPool_(x, filterSize, strides, pad2, dimRoundingMode) {
var $x = convertToTensor(x, "x", "avgPool", "float32");
// avgPool takes no dilation argument; dilations is fixed at 1 here, so the
// check below can only fail through the strides argument.
var dilations = 1;
assert(eitherStridesOrDilationsAreOne(strides, dilations), function() {
return "Error in avgPool: Either strides or dilations must be 1. " + ("Got strides " + strides + " and dilations '" + dilations + "'");
});
var x4D = $x;
var reshapedTo4D = false;
if ($x.rank === 3) {
reshapedTo4D = true;
x4D = reshape($x, [1, $x.shape[0], $x.shape[1], $x.shape[2]]);
}
assert(x4D.rank === 4, function() {
return "Error in avgPool: x must be rank 4 but got rank " + x4D.rank + ".";
});
if (dimRoundingMode != null) {
assert(isInt(pad2), function() {
return "Error in avgPool: pad must be an integer when using, " + ("dimRoundingMode " + dimRoundingMode + " but got pad " + pad2 + ".");
});
}
var forward = function(backend2, save) {
var convInfo = computePool2DInfo(x4D.shape, filterSize, strides, 1, pad2, dimRoundingMode);
save([x4D]);
// A 1x1 window whose output shape equals the input shape is an identity:
// return a clone instead of calling the backend.
if (convInfo.filterWidth === 1 && convInfo.filterHeight === 1 && arraysEqual(convInfo.inShape, convInfo.outShape)) {
return x4D.clone();
}
return backend2.avgPool(x4D, convInfo);
};
var inputs = {x: x4D};
var attrs = {filterSize, strides, pad: pad2, dimRoundingMode};
var res = ENGINE.runKernelFunc(forward, inputs, null, AvgPool, attrs);
res = cast(res, $x.dtype);
if (reshapedTo4D) {
// Drop the batch dimension added above.
return reshape(res, [res.shape[1], res.shape[2], res.shape[3]]);
}
return res;
}
var avgPool = op({avgPool_});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * 3D average pooling of `x` (rank 4 or rank 5), run as the AvgPool3D kernel.
 * Only the "NDHWC" data format is accepted.
 *
 * A rank-4 input is reshaped to a batch of one for the kernel and the result is
 * reshaped back, so the output rank matches the input rank.
 *
 * @param x - input tensor or tensor-like, rank 4 or 5.
 * @param filterSize - pooling window, forwarded to computePool3DInfo.
 * @param strides - pooling strides; strides or dilations must be all ones.
 * @param pad2 - padding spec; must be an integer when dimRoundingMode is set.
 * @param dimRoundingMode - optional rounding mode for output dimensions.
 * @param dataFormat - defaults to "NDHWC"; any other value fails an assertion.
 * @param dilations - deprecated; defaults to [1, 1, 1] when null/undefined,
 *   otherwise deprecationWarn is called.
 * @returns the pooled tensor, cast to the input's dtype.
 */
function avgPool3d_(x, filterSize, strides, pad2, dimRoundingMode, dataFormat, dilations) {
  if (dataFormat === void 0) {
    dataFormat = "NDHWC";
  }
  if (dilations == null) {
    dilations = [1, 1, 1];
  } else {
    deprecationWarn("dilations is deprecated, this field will be gone in v3.0.0.");
  }
  var $x = convertToTensor(x, "x", "avgPool3d", "float32");
  var x5D = $x;
  var reshapedTo5D = false;
  if ($x.rank === 4) {
    reshapedTo5D = true;
    x5D = reshape($x, [1, $x.shape[0], $x.shape[1], $x.shape[2], $x.shape[3]]);
  }
  assert(x5D.rank === 5, function() {
    return "Error in avgPool3d: x must be rank 5 but got rank " + x5D.rank + ".";
  });
  assert(dataFormat === "NDHWC", function() {
    return "Error in avgPool3d: Only NDHWC is currently supported, " + ("but got dataFormat of " + dataFormat);
  });
  assert(eitherStridesOrDilationsAreOne(strides, dilations), function() {
    return "Error in avgPool3d: Either strides or dilations must be 1. " + ("Got strides " + strides + " and dilations '" + dilations + "'");
  });
  if (dimRoundingMode != null) {
    assert(isInt(pad2), function() {
      return "Error in avgPool3d: pad must be an integer when using, " + ("dimRoundingMode " + dimRoundingMode + " but got pad " + pad2 + ".");
    });
  }
  var forward = function(backend2, save) {
    // dilations is already defaulted at the top of the function, so it is
    // never null here.
    var convInfo = computePool3DInfo(x5D.shape, filterSize, strides, dilations, pad2, dimRoundingMode, dataFormat);
    save([x5D]);
    return backend2.avgPool3d(x5D, convInfo);
  };
  var inputs = {x: x5D};
  var attrs = {filterSize, strides, pad: pad2, dimRoundingMode, dataFormat, dilations};
  var res = ENGINE.runKernelFunc(forward, inputs, null, AvgPool3D, attrs);
  res = cast(res, x5D.dtype);
  if (reshapedTo5D) {
    // Drop the batch dimension added above.
    return reshape(res, [res.shape[1], res.shape[2], res.shape[3], res.shape[4]]);
  }
  return res;
}
var avgPool3d = op({avgPool3d_});
/**
* @license
* Copyright 2017 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Validates the shapes passed to concat. It checks three things:
 * - every shape has the rank of the first;
 * - `axis` lies in [0, rank);
 * - every dimension other than `axis` matches the first shape.
 *
 * @param {number[][]} shapes - non-empty list of tensor shapes.
 * @param {number} axis - concatenation axis, must be in [0, rank).
 * @throws via assert when any check fails.
 */
function assertParamsConsistent(shapes, axis) {
  var rank = shapes[0].length;
  shapes.forEach(function(shape, i) {
    assert(shape.length === rank, function() {
      return "Error in concat" + rank + "D: rank of tensors[" + i + "] must be the same " + ("as the rank of the rest (" + rank + ")");
    });
  });
  assert(axis >= 0 && axis < rank, function() {
    return "Error in concat" + rank + "D: axis must be between 0 and " + (rank - 1) + ".";
  });
  var firstShape = shapes[0];
  shapes.forEach(function(shape, i) {
    for (let r = 0; r < rank; r++) {
      // Report the mismatching dimension `r`, not the tensor index `i`.
      assert(r === axis || shape[r] === firstShape[r], function() {
        return "Error in concat" + rank + "D: Shape of tensors[" + i + "] (" + shape + ") " + ("does not match the shape of the rest (" + firstShape + ") ") + ("along the non-concatenated axis " + r + ".");
      });
    }
  });
}
/**
 * Output shape of concatenating `shapes` along `axis`: a copy of the first
 * shape with the `axis` entry increased by the `axis` entry of every other
 * shape. The input shapes are not modified.
 */
function computeOutShape$1(shapes, axis) {
  const outputShape = shapes[0].slice();
  for (const shape of shapes.slice(1)) {
    outputShape[axis] += shape[axis];
  }
  return outputShape;
}
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Concatenates `tensors` along `axis` (default 0) via the Concat kernel.
 *
 * Shape handling happens inside the kernel's forward function:
 * - the axis is normalised against the first tensor's shape (parseAxisParam);
 * - an empty result short-circuits to an empty tensor of the computed shape;
 * - zero-size inputs are dropped before the shapes are checked and the backend
 *   concat runs.
 *
 * @param tensors - non-empty array of tensors or tensor-likes.
 * @param axis - concatenation axis, forwarded to parseAxisParam.
 * @returns the concatenated tensor.
 * @throws assertion failure when `tensors` is empty; Error when the first
 *   tensor is complex64 and another one is not.
 */
function concat_(tensors, axis) {
if (axis === void 0) {
axis = 0;
}
assert(tensors.length >= 1, function() {
return "Pass at least one tensor to concat";
});
var $tensors = convertToTensorArray(tensors, "tensors", "concat");
// Only checked when the FIRST tensor is complex64; a complex64 tensor later in
// the list is not rejected here. NOTE(review): confirm this is intended.
if ($tensors[0].dtype === "complex64") {
$tensors.forEach(function(tensor2) {
if (tensor2.dtype !== "complex64") {
throw new Error("Cannot concatenate complex64 tensors with a tensor\n with dtype " + tensor2.dtype + ". ");
}
});
}
var forward = function(backend2, save) {
var $axis = parseAxisParam(axis, $tensors[0].shape)[0];
var outShape = computeOutShape$1($tensors.map(function(t) {
return t.shape;
}), $axis);
if (sizeFromShape(outShape) === 0) {
return tensor([], outShape);
}
// Reassigns the enclosing $tensors (filter returns a new array), so save()
// below records only the non-empty tensors; `inputs` keeps the original array.
$tensors = $tensors.filter(function(t) {
return t.size > 0;
});
// A single non-empty tensor is returned as-is, without calling the backend.
if ($tensors.length === 1) {
return $tensors[0];
}
var shapes = $tensors.map(function(t) {
return t.shape;
});
assertParamsConsistent(shapes, $axis);
var res = backend2.concat($tensors, $axis);
save($tensors);
return res;
};
var inputs = $tensors;
var attr = {axis};
return ENGINE.runKernelFunc(forward, inputs, null, Concat, attr);
}
var concat = op({concat_});
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Element-wise logistic sigmoid of `x`, run as the Sigmoid kernel.
 * The kernel's output tensor (not the input) is passed to save().
 */
function sigmoid_(x) {
  const input = convertToTensor(x, "x", "sigmoid");
  function sigmoidKernel(backend2, save) {
    const output = backend2.sigmoid(input);
    save([output]);
    return output;
  }
  return ENGINE.runKernelFunc(sigmoidKernel, {x: input}, null, Sigmoid);
}
var sigmoid = op({sigmoid_});
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Extracts the sub-tensor of `x` starting at `begin` with extent `size`, via
 * the Slice kernel. Inside the kernel function, `begin`/`size` are resolved
 * by parseSliceParams and checked by assertParamsValid.
 *
 * @throws {Error} when `x` is a scalar (rank 0).
 */
function slice_(x, begin, size) {
  const $x = convertToTensor(x, "x", "slice");
  if ($x.rank === 0) {
    throw new Error("Slicing scalar is not possible");
  }
  const sliceKernel = (backend2, save) => {
    const resolved = parseSliceParams($x, begin, size);
    const begin_ = resolved[0];
    const size_ = resolved[1];
    assertParamsValid($x, begin_, size_);
    save([$x]);
    return backend2.slice($x, begin_, size_);
  };
  return ENGINE.runKernelFunc(sliceKernel, {x: $x}, null, Slice, {begin, size});
}
var slice = op({slice_});
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Element-wise hyperbolic tangent of `x`, run as the Tanh kernel.
 * The kernel's output tensor is what gets passed to save().
 */
function tanh_(x) {
  const input = convertToTensor(x, "x", "tanh");
  const tanhKernel = (backend2, save) => {
    const output = backend2.tanh(input);
    save([output]);
    return output;
  };
  return ENGINE.runKernelFunc(tanhKernel, {x: input}, null, Tanh);
}
var tanh$1 = op({tanh_});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * One step of a basic LSTM cell.
 *
 * [data, h] is concatenated along axis 1, multiplied by lstmKernel and offset
 * by lstmBias. The result is split column-wise into four equal slices i, j, f
 * and o. Then:
 *   newC = sigmoid(i) * tanh(j) + c * sigmoid(forgetBias + f)
 *   newH = tanh(newC) * sigmoid(o)
 *
 * NOTE(review): assumes res.shape[1] is divisible by 4; this is not checked here.
 *
 * @returns {Array} [newC, newH].
 */
function basicLSTMCell_(forgetBias, lstmKernel, lstmBias, data, c, h) {
  const opName = "basicLSTMCell";
  const $forgetBias = convertToTensor(forgetBias, "forgetBias", opName);
  const $lstmKernel = convertToTensor(lstmKernel, "lstmKernel", opName);
  const $lstmBias = convertToTensor(lstmBias, "lstmBias", opName);
  const $data = convertToTensor(data, "data", opName);
  const $c = convertToTensor(c, "c", opName);
  const $h = convertToTensor(h, "h", opName);

  const gates = add$1(matMul(concat([$data, $h], 1), $lstmKernel), $lstmBias);
  const gateWidth = gates.shape[1] / 4;
  const gateSize = [gates.shape[0], gateWidth];
  const gateI = slice(gates, [0, 0], gateSize);
  const gateJ = slice(gates, [0, gateWidth], gateSize);
  const gateF = slice(gates, [0, gateWidth * 2], gateSize);
  const gateO = slice(gates, [0, gateWidth * 3], gateSize);

  const newC = add$1(
    mul(sigmoid(gateI), tanh$1(gateJ)),
    mul($c, sigmoid(add$1($forgetBias, gateF)))
  );
  const newH = mul(tanh$1(newC), sigmoid(gateO));
  return [newC, newH];
}
var basicLSTMCell = op({basicLSTMCell_});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Validates arguments and runs the BatchToSpaceND kernel.
 *
 * Checks, in order:
 * - rank(x) >= 1 + blockShape.length;
 * - crops.length === blockShape.length;
 * - x.shape[0] is divisible by the product of blockShape.
 *
 * @throws TypeError for an empty blockShape (reduce has no initial value);
 *   assertion failure otherwise when a check fails.
 */
function batchToSpaceND_(x, blockShape, crops) {
  const $x = convertToTensor(x, "x", "batchToSpaceND");
  const prod2 = blockShape.reduce((acc, dim) => acc * dim);
  assert($x.rank >= 1 + blockShape.length, () =>
    "input rank is " + $x.rank + " but should be > than blockShape.length " + blockShape.length);
  assert(crops.length === blockShape.length, () =>
    "crops.length is " + crops.length + " but should be equal to blockShape.length " + blockShape.length);
  assert($x.shape[0] % prod2 === 0, () =>
    "input tensor batch is " + $x.shape[0] + " but is not divisible by the product of " + ("the elements of blockShape " + blockShape.join(" * ") + " === " + prod2));
  const kernel = (backend2) => backend2.batchToSpaceND($x, blockShape, crops);
  return ENGINE.runKernelFunc(kernel, {x: $x}, null, BatchToSpaceND, {blockShape, crops});
}
var batchToSpaceND = op({batchToSpaceND_});
/**
 * Reshapes a tensor of rank < 4 to rank 4 by prepending 1s.
 * Ranks 0 and 1 become [1, 1, 1, size]. Any other rank (4 or more) is
 * returned as-is.
 */
function xAs4D(x) {
  switch (x.rank) {
    case 0:
    case 1:
      return reshape(x, [1, 1, 1, x.size]);
    case 2:
      return reshape(x, [1, 1, x.shape[0], x.shape[1]]);
    case 3:
      return reshape(x, [1, x.shape[0], x.shape[1], x.shape[2]]);
    default:
      return x;
  }
}
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Batch normalisation of `x` with the given mean and variance and optional
 * offset and scale, run as the FusedBatchNorm kernel.
 *
 * `x` is reshaped to rank 4 via xAs4D for the kernel. mean, variance, offset
 * and scale are passed to the backend through as1DOr4D. The result is
 * reshaped back to `x`'s original shape.
 *
 * @param x - input tensor or tensor-like.
 * @param mean2 - mean; must have the same rank as variance, and as
 *   offset/scale when they are given.
 * @param variance - variance tensor.
 * @param offset - optional; skipped when null/undefined.
 * @param scale - optional; skipped when null/undefined.
 * @param varianceEpsilon - passed to the backend; defaults to 1e-3 when
 *   null/undefined.
 * @returns the normalised tensor, with the same shape as `x`.
 */
function batchNorm_(x, mean2, variance, offset, scale, varianceEpsilon) {
if (varianceEpsilon == null) {
varianceEpsilon = 1e-3;
}
var $x = convertToTensor(x, "x", "batchNorm");
var $mean = convertToTensor(mean2, "mean", "batchNorm");
var $variance = convertToTensor(variance, "variance", "batchNorm");
var $scale;
if (scale != null) {
$scale = convertToTensor(scale, "scale", "batchNorm");
}
var $offset;
if (offset != null) {
$offset = convertToTensor(offset, "offset", "batchNorm");
}
assert($mean.rank === $variance.rank, function() {
return "Batch normalization gradient requires mean and variance to have equal ranks.";
});
assert($offset == null || $mean.rank === $offset.rank, function() {
return "Batch normalization gradient requires mean and offset to have equal ranks.";
});
assert($scale == null || $mean.rank === $scale.rank, function() {
return "Batch normalization gradient requires mean and scale to have equal ranks.";
});
var x4D = xAs4D($x);
var forward = function(backend2, save) {
// Offset is not saved; $scale may be undefined when no scale was given.
save([x4D, $mean, $variance, $scale]);
return backend2.batchNorm(x4D, as1DOr4D($mean), as1DOr4D($variance), as1DOr4D($offset), as1DOr4D($scale), varianceEpsilon);
};
var inputs = {
x: x4D,
scale: $scale,
offset: $offset,
mean: $mean,
variance: $variance
};
var attrs = {varianceEpsilon};
var res = ENGINE.runKernelFunc(forward, inputs, null, FusedBatchNorm, attrs);
return reshape(res, $x.shape);
}
/**
 * Prepares a batchNorm parameter tensor for the backend.
 * - null/undefined -> null.
 * - rank 0 -> rank 1 of length size.
 * - rank 2 or 3 -> rank 4 by prepending 1s.
 * - any other rank (1, or 4 and above) -> returned unchanged.
 */
function as1DOr4D(x) {
  if (x == null) return null;
  switch (x.rank) {
    case 0:
      return reshape(x, [x.size]);
    case 2:
      return reshape(x, [1, 1, x.shape[0], x.shape[1]]);
    case 3:
      return reshape(x, [1, x.shape[0], x.shape[1], x.shape[2]]);
    default:
      return x;
  }
}
var batchNorm = op({batchNorm_});
/**
 * Shared body of batchNorm2d/3d/4d.
 *
 * Converts the arguments, then asserts:
 * - `x` has exactly `rank` dimensions;
 * - mean and variance have `rank` or 1 dimensions;
 * - scale and offset (when given) have `rank` or 1 dimensions.
 * It then delegates to batchNorm.
 *
 * @param {number} rank - required rank of `x` (2, 3 or 4).
 * @returns whatever batchNorm returns.
 * @throws assertion failure naming the offending argument and its rank.
 */
function batchNormWithRankChecks(rank, x, mean2, variance, offset, scale, varianceEpsilon) {
  var $x = convertToTensor(x, "x", "batchNorm");
  var $mean = convertToTensor(mean2, "mean", "batchNorm");
  var $variance = convertToTensor(variance, "variance", "batchNorm");
  var $scale;
  if (scale != null) {
    $scale = convertToTensor(scale, "scale", "batchNorm");
  }
  var $offset;
  if (offset != null) {
    $offset = convertToTensor(offset, "offset", "batchNorm");
  }
  var prefix = "Error in batchNorm" + rank + "D: ";
  // Accepts either the full `rank` or rank 1 for a parameter tensor.
  var assertRankOrRank1 = function(name, t) {
    assert(t.rank === rank || t.rank === 1, function() {
      return prefix + name + " must be rank " + rank + " or rank 1 but got rank " + t.rank + ".";
    });
  };
  assert($x.rank === rank, function() {
    return prefix + "x must be rank " + rank + " but got rank " + ($x.rank + ".");
  });
  assertRankOrRank1("mean", $mean);
  assertRankOrRank1("variance", $variance);
  if ($scale != null) {
    assertRankOrRank1("scale", $scale);
  }
  if ($offset != null) {
    assertRankOrRank1("offset", $offset);
  }
  return batchNorm($x, $mean, $variance, $offset, $scale, varianceEpsilon);
}
/** Batch normalisation for a rank-2 `x`; see batchNormWithRankChecks. */
function batchNorm2d_(x, mean2, variance, offset, scale, varianceEpsilon) {
  return batchNormWithRankChecks(2, x, mean2, variance, offset, scale, varianceEpsilon);
}
var batchNorm2d = op({batchNorm2d_});
/** Batch normalisation for a rank-3 `x`; see batchNormWithRankChecks. */
function batchNorm3d_(x, mean2, variance, offset, scale, varianceEpsilon) {
  return batchNormWithRankChecks(3, x, mean2, variance, offset, scale, varianceEpsilon);
}
var batchNorm3d = op({batchNorm3d_});
/** Batch normalisation for a rank-4 `x`; see batchNormWithRankChecks. */
function batchNorm4d_(x, mean2, variance, offset, scale, varianceEpsilon) {
  return batchNormWithRankChecks(4, x, mean2, variance, offset, scale, varianceEpsilon);
}
var batchNorm4d = op({batchNorm4d_});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Broadcasts `x` to `shape` by tiling along dimensions where `x` has size 1.
 *
 * If `shape` has more dimensions than `x`, `x` is first reshaped with leading
 * 1s. When no dimension needs tiling, the (possibly reshaped) input is cloned
 * instead of running the BroadcastTo kernel.
 *
 * @param x - tensor or tensor-like to broadcast.
 * @param {number[]} shape - target shape; every entry must be a positive integer.
 * @returns the broadcast tensor.
 * @throws {Error} if `shape` has a non-positive or non-integer entry, has fewer
 *   dimensions than `x`, or has a dimension where `x` is neither equal nor 1.
 */
function broadcastTo_(x, shape) {
  // convertToTensor takes (value, argName, functionName), as at every other
  // call site in this file; the last two arguments were swapped here.
  var input = convertToTensor(x, "x", "broadcastTo");
  var xShape = input.shape;
  if (shape.some(function(d) {
    return !(d > 0) || d % 1 !== 0;
  })) {
    throw new Error("broadcastTo(): Invalid broadcast shape [" + shape + "].");
  }
  if (shape.length < input.rank) {
    throw new Error("broadcastTo(): shape.length=" + shape.length + " < input.rank=" + input.rank + ".");
  }
  if (shape.length > input.rank) {
    var newShape = input.shape.slice();
    while (newShape.length < shape.length) {
      newShape.unshift(1);
    }
    input = reshape(input, newShape);
  }
  var inputShape = input.shape;
  var reps = Array.from(shape);
  for (var i = shape.length - 1; i >= 0; i--) {
    if (inputShape[i] === shape[i]) {
      reps[i] = 1;
    } else if (inputShape[i] !== 1) {
      throw new Error("broadcastTo(): [" + xShape + "] cannot be broadcast to [" + shape + "].");
    }
  }
  var axes = reps.map(function(n, i2) {
    return n > 1 ? i2 : -1;
  }).filter(function(i2) {
    return i2 >= 0;
  });
  if (axes.length === 0) {
    return clone(input);
  }
  var forward = function(backend2) {
    return backend2.tile(input, reps);
  };
  var inputs = {x: input};
  var attrs = {shape, inputShape};
  return ENGINE.runKernelFunc(forward, inputs, null, BroadcastTo, attrs);
}
var broadcastTo = op({broadcastTo_});
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Element-wise ceiling of `x`, run as the Ceil kernel. The kernel function
 * does not call save().
 */
function ceil_(x) {
  const input = convertToTensor(x, "x", "ceil");
  const inputs = {x: input};
  return ENGINE.runKernelFunc((backend2) => backend2.ceil(input), inputs, null, Ceil);
}
var ceil = op({ceil_});
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Clamps every element of `x` into [clipValueMin, clipValueMax] via the
 * ClipByValue kernel. The input tensor is passed to save().
 *
 * @throws assertion failure when clipValueMin > clipValueMax.
 */
function clipByValue_(x, clipValueMin, clipValueMax) {
  const $x = convertToTensor(x, "x", "clipByValue");
  assert(clipValueMin <= clipValueMax, () =>
    "Error in clip: min (" + clipValueMin + ") must be " + ("less than or equal to max (" + clipValueMax + ")."));
  const clipKernel = (backend2, save) => {
    const clipped = backend2.clip($x, clipValueMin, clipValueMax);
    save([$x]);
    return clipped;
  };
  return ENGINE.runKernelFunc(clipKernel, {x: $x}, null, ClipByValue, {clipValueMin, clipValueMax});
}
var clipByValue = op({clipByValue_});
/**
 * Rank-named aliases of `concat`. concat1d always joins along axis 0; the
 * 2D/3D/4D variants forward `axis` unchanged. No rank checks are done here.
 */
function concat1d_(tensors) {
  const onlyAxis = 0;
  return concat(tensors, onlyAxis);
}
var concat1d = op({concat1d_: concat1d_});
function concat2d_(tensors, axis) {
  const joined = concat(tensors, axis);
  return joined;
}
var concat2d = op({concat2d_: concat2d_});
function concat3d_(tensors, axis) {
  const joined = concat(tensors, axis);
  return joined;
}
var concat3d = op({concat3d_: concat3d_});
function concat4d_(tensors, axis) {
  const joined = concat(tensors, axis);
  return joined;
}
var concat4d = op({concat4d_: concat4d_});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * 2D convolution of `x` with `filter`, run as the Conv2D kernel.
 *
 * A rank-3 `x` is reshaped to a batch of one for the kernel and the result is
 * reshaped back to rank 3.
 *
 * @param x - input, rank 3 or 4.
 * @param filter - rank-4 filter; its dimension 2 must equal the input depth.
 * @param strides - convolution strides.
 * @param pad2 - padding spec; must be an integer when dimRoundingMode is set.
 * @param dataFormat - "NHWC" (default) or "NCHW". It selects which dimension of x
 *   is the depth; any other value makes convertConv2DDataFormat throw inside
 *   the kernel function.
 * @param dilations - defaults to [1, 1]. Either strides or dilations must be
 *   all ones.
 * @param dimRoundingMode - optional rounding mode for output dimensions.
 * @returns the convolved tensor, with the same rank as `x`.
 */
function conv2d_(x, filter, strides, pad2, dataFormat, dilations, dimRoundingMode) {
if (dataFormat === void 0) {
dataFormat = "NHWC";
}
if (dilations === void 0) {
dilations = [1, 1];
}
var $x = convertToTensor(x, "x", "conv2d");
var $filter = convertToTensor(filter, "filter", "conv2d");
var x4D = $x;
var reshapedTo4D = false;
if ($x.rank === 3) {
reshapedTo4D = true;
x4D = reshape($x, [1, $x.shape[0], $x.shape[1], $x.shape[2]]);
}
assert(x4D.rank === 4, function() {
return "Error in conv2d: input must be rank 4, but got rank " + x4D.rank + ".";
});
assert($filter.rank === 4, function() {
return "Error in conv2d: filter must be rank 4, but got rank " + ($filter.rank + ".");
});
if (dimRoundingMode != null) {
assert(isInt(pad2), function() {
return "Error in conv2d: pad must be an integer when using, " + ("dimRoundingMode " + dimRoundingMode + " but got pad " + pad2 + ".");
});
}
// Depth is the last dimension for "NHWC" and dimension 1 otherwise.
var inDepth = dataFormat === "NHWC" ? x4D.shape[3] : x4D.shape[1];
assert(inDepth === $filter.shape[2], function() {
return "Error in conv2d: depth of input (" + inDepth + ") must match " + ("input depth for filter " + $filter.shape[2] + ".");
});
assert(eitherStridesOrDilationsAreOne(strides, dilations), function() {
return "Error in conv2D: Either strides or dilations must be 1. " + ("Got strides " + strides + " and dilations '" + dilations + "'");
});
var forward = function(backend2, save) {
var $dataFormat = convertConv2DDataFormat(dataFormat);
var convInfo = computeConv2DInfo(x4D.shape, $filter.shape, strides, dilations, pad2, dimRoundingMode, false, $dataFormat);
var res2 = backend2.conv2d(x4D, $filter, convInfo);
save([x4D, $filter]);
return res2;
};
var inputs = {x: x4D, filter: $filter};
var attrs = {strides, pad: pad2, dataFormat, dilations, dimRoundingMode};
var res = ENGINE.runKernelFunc(forward, inputs, null, Conv2D, attrs);
if (reshapedTo4D) {
// Drop the batch dimension added above.
return reshape(res, [res.shape[1], res.shape[2], res.shape[3]]);
}
return res;
}
var conv2d = op({conv2d_});
/**
 * 1-D convolution implemented by lifting the problem to `conv2d` with a unit height.
 *
 * A rank-2 `x` is reshaped to rank 3 with a leading dimension of 1, and that
 * dimension is removed from the result again. Only "NWC" is accepted.
 *
 * @param x Input tensor; rank 2 or 3.
 * @param filter Rank-3 filter; filter.shape[1] must equal x's last dimension.
 * @param stride Scalar stride along the width.
 * @param pad2 Padding spec; must be an integer when dimRoundingMode is set.
 * @param dataFormat Must be "NWC" (the default).
 * @param dilation Scalar dilation, default 1.
 * @param dimRoundingMode Optional; forwarded to conv2d.
 * @returns The convolution output, rank 2 or 3 matching the input.
 */
function conv1d_(x, filter, stride, pad2, dataFormat = "NWC", dilation = 1, dimRoundingMode) {
  const $x = convertToTensor(x, "x", "conv1d");
  const $filter = convertToTensor(filter, "filter", "conv1d");
  const reshapedTo3D = $x.rank === 2;
  const x3D = reshapedTo3D ? reshape($x, [1, $x.shape[0], $x.shape[1]]) : $x;
  assert(x3D.rank === 3, () => `Error in conv1d: input must be rank 3, but got rank ${x3D.rank}.`);
  assert($filter.rank === 3, () => `Error in conv1d: filter must be rank 3, but got rank ${$filter.rank}.`);
  if (dimRoundingMode != null) {
    assert(isInt(pad2), () => `Error in conv1d: pad must be an integer when using, dimRoundingMode ${dimRoundingMode} but got pad ${pad2}.`);
  }
  assert(x3D.shape[2] === $filter.shape[1], () => `Error in conv1d: depth of input (${x3D.shape[2]}) must match input depth for filter ${$filter.shape[1]}.`);
  assert(eitherStridesOrDilationsAreOne(stride, dilation), () => `Error in conv1D: Either stride or dilation must be 1. Got stride ${stride} and dilation '${dilation}'`);
  assert(dataFormat === "NWC", () => `Error in conv1d: got dataFormat of ${dataFormat} but only NWC is currently supported.`);
  // Insert a height axis of size 1 into both the filter and the input so conv2d can run it.
  const filter4D = reshape($filter, [1, $filter.shape[0], $filter.shape[1], $filter.shape[2]]);
  const input4D = reshape(x3D, [x3D.shape[0], 1, x3D.shape[1], x3D.shape[2]]);
  const out4D = conv2d(input4D, filter4D, [1, stride], pad2, "NHWC", [1, dilation], dimRoundingMode);
  const [batch, , outWidth, outChannels] = out4D.shape;
  return reshapedTo3D ? reshape(out4D, [outWidth, outChannels]) : reshape(out4D, [batch, outWidth, outChannels]);
}
var conv1d = op({conv1d_});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Gradient of conv2d with respect to its input, dispatched under the
 * `Conv2DBackpropInput` kernel name.
 *
 * A rank-3 `dy` (with a length-3 `xShape`) is lifted to rank 4 with a leading
 * dimension of 1, and that dimension is removed from the result again.
 * `dy` and `filter` are used as given (no convertToTensor here).
 *
 * @param xShape Shape of the original conv input; length must equal dy.rank.
 * @param dy Upstream gradient tensor.
 * @param filter Rank-4 filter tensor.
 * @param strides Convolution strides.
 * @param pad2 Padding spec; must be an integer when dimRoundingMode is set.
 * @param dataFormat Defaults to "NHWC"; selects which axis is the depth.
 * @param dimRoundingMode Optional; forwarded to computeConv2DInfo.
 * @returns The kernel output.
 */
function conv2DBackpropInput_(xShape, dy, filter, strides, pad2, dataFormat = "NHWC", dimRoundingMode) {
  assert(xShape.length === dy.rank, () => `Length of inShape (${xShape.length}) and rank of dy (${dy.rank}) must match`);
  const reshapedTo4D = dy.rank === 3;
  const dy4D = reshapedTo4D ? reshape(dy, [1, dy.shape[0], dy.shape[1], dy.shape[2]]) : dy;
  const xShape4D = reshapedTo4D ? [1, xShape[0], xShape[1], xShape[2]] : xShape;
  assert(xShape4D.length === 4, () => `Error in conv2dDerInput: inShape must be length 4, but got length ${xShape4D.length}.`);
  assert(dy4D.rank === 4, () => `Error in conv2dDerInput: dy must be rank 4, but got rank ${dy4D.rank}`);
  assert(filter.rank === 4, () => `Error in conv2dDerInput: filter must be rank 4, but got rank ${filter.rank}`);
  const isChannelsLast = dataFormat === "NHWC";
  const inDepth = isChannelsLast ? xShape4D[3] : xShape4D[1];
  const outDepth = isChannelsLast ? dy4D.shape[3] : dy4D.shape[1];
  assert(inDepth === filter.shape[2], () => `Error in conv2dDerInput: depth of input (${inDepth}) must match input depth for filter ${filter.shape[2]}.`);
  assert(outDepth === filter.shape[3], () => `Error in conv2dDerInput: depth of output (${outDepth}) must match output depth for filter ${filter.shape[3]}.`);
  if (dimRoundingMode != null) {
    assert(isInt(pad2), () => `Error in conv2dDerInput: pad must be an integer when using, dimRoundingMode ${dimRoundingMode} but got pad ${pad2}.`);
  }
  const forward = (backend2, save) => {
    const backendFormat = convertConv2DDataFormat(dataFormat);
    // Dilation is fixed at 1 for this op.
    const convInfo = computeConv2DInfo(xShape4D, filter.shape, strides, 1, pad2, dimRoundingMode, false, backendFormat);
    const output = backend2.conv2dDerInput(dy4D, filter, convInfo);
    save([dy4D, filter]);
    return output;
  };
  const inputs = {dy: dy4D, filter};
  const attrs = {strides, pad: pad2, dataFormat, dimRoundingMode, inputShape: xShape4D};
  const result = ENGINE.runKernelFunc(forward, inputs, null, Conv2DBackpropInput, attrs);
  if (!reshapedTo4D) {
    return result;
  }
  const [, outHeight, outWidth, outChannels] = result.shape;
  return reshape(result, [outHeight, outWidth, outChannels]);
}
var conv2DBackpropInput = op({conv2DBackpropInput_});
/**
 * Transposed 2-D convolution, expressed as conv2DBackpropInput with an NHWC layout.
 *
 * @param x Input tensor or TensorLike (plays the role of `dy`).
 * @param filter Filter tensor or TensorLike.
 * @param outputShape Shape of the result (plays the role of `xShape`).
 * @param strides Convolution strides.
 * @param pad2 Padding spec.
 * @param dimRoundingMode Optional rounding mode.
 * @returns Result of conv2DBackpropInput.
 */
function conv2dTranspose_(x, filter, outputShape, strides, pad2, dimRoundingMode) {
  const dyTensor = convertToTensor(x, "x", "conv2dTranspose");
  const filterTensor = convertToTensor(filter, "filter", "conv2dTranspose");
  const layout = "NHWC";
  return conv2DBackpropInput(outputShape, dyTensor, filterTensor, strides, pad2, layout, dimRoundingMode);
}
var conv2dTranspose = op({conv2dTranspose_});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * 3-D convolution of `x` with `filter`, dispatched under the `Conv3D` kernel name.
 *
 * A rank-4 `x` is reshaped to rank 5 with a leading dimension of 1, and that
 * dimension is removed from the result again. Only "NDHWC" is accepted.
 *
 * @param x Input tensor; rank 4 or 5.
 * @param filter Rank-5 filter; filter.shape[3] must equal x's last dimension.
 * @param strides Convolution strides.
 * @param pad2 Padding spec.
 * @param dataFormat Must be "NDHWC" (the default).
 * @param dilations Dilation rates, default [1, 1, 1].
 * @returns The convolution output.
 */
function conv3d_(x, filter, strides, pad2, dataFormat = "NDHWC", dilations = [1, 1, 1]) {
  const $x = convertToTensor(x, "x", "conv3d");
  const $filter = convertToTensor(filter, "filter", "conv3d");
  const reshapedTo5D = $x.rank === 4;
  const x5D = reshapedTo5D ? reshape($x, [1, $x.shape[0], $x.shape[1], $x.shape[2], $x.shape[3]]) : $x;
  assert(x5D.rank === 5, () => `Error in conv3d: input must be rank 5, but got rank ${x5D.rank}.`);
  assert($filter.rank === 5, () => `Error in conv3d: filter must be rank 5, but got rank ${$filter.rank}.`);
  assert(x5D.shape[4] === $filter.shape[3], () => `Error in conv3d: depth of input (${x5D.shape[4]}) must match input depth for filter ${$filter.shape[3]}.`);
  assert(eitherStridesOrDilationsAreOne(strides, dilations), () => `Error in conv3D: Either strides or dilations must be 1. Got strides ${strides} and dilations '${dilations}'`);
  assert(dataFormat === "NDHWC", () => `Error in conv3d: got dataFormat of ${dataFormat} but only NDHWC is currently supported.`);
  const forward = (backend2, save) => {
    const convInfo = computeConv3DInfo(x5D.shape, $filter.shape, strides, dilations, pad2);
    const output = backend2.conv3d(x5D, $filter, convInfo);
    save([x5D, $filter]);
    return output;
  };
  const inputs = {x: x5D, filter: $filter};
  const attrs = {strides, pad: pad2, dataFormat, dilations};
  const result = ENGINE.runKernelFunc(forward, inputs, null, Conv3D, attrs);
  if (!reshapedTo5D) {
    return result;
  }
  const [, outDepth, outHeight, outWidth, outChannels] = result.shape;
  return reshape(result, [outDepth, outHeight, outWidth, outChannels]);
}
var conv3d = op({conv3d_});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Gradient of conv3d with respect to its input, dispatched under the
 * `Conv3DBackpropInputV2` kernel name.
 *
 * A rank-4 `dy` (with a length-4 `xShape`) is lifted to rank 5 with a leading
 * dimension of 1, and that dimension is removed from the result again.
 * `dy` and `filter` are used as given (no convertToTensor here).
 *
 * @param xShape Shape of the original conv input; length must equal dy.rank.
 * @param dy Upstream gradient tensor.
 * @param filter Rank-5 filter tensor.
 * @param strides Convolution strides.
 * @param pad2 Padding spec.
 * @returns The kernel output.
 */
function conv3DBackpropInput_(xShape, dy, filter, strides, pad2) {
  assert(xShape.length === dy.rank, () => `Length of inShape (${xShape.length}) and rank of dy (${dy.rank}) must match`);
  const reshapedTo5D = dy.rank === 4;
  const dy5D = reshapedTo5D ? reshape(dy, [1, dy.shape[0], dy.shape[1], dy.shape[2], dy.shape[3]]) : dy;
  const xShape5D = reshapedTo5D ? [1, xShape[0], xShape[1], xShape[2], xShape[3]] : xShape;
  const inDepth = xShape5D[4];
  const outDepth = dy5D.shape[4];
  assert(xShape5D.length === 5, () => `Error in conv3dDerInput: inShape must be length 5, but got length ${xShape5D.length}.`);
  assert(dy5D.rank === 5, () => `Error in conv3dDerInput: dy must be rank 5, but got rank ${dy5D.rank}`);
  assert(filter.rank === 5, () => `Error in conv3dDerInput: filter must be rank 5, but got rank ${filter.rank}`);
  assert(inDepth === filter.shape[3], () => `Error in conv3dDerInput: depth of input (${inDepth}) must match input depth for filter ${filter.shape[3]}.`);
  assert(outDepth === filter.shape[4], () => `Error in conv3dDerInput: depth of output (${outDepth}) must match output depth for filter ${filter.shape[4]}.`);
  // Dilation is fixed at 1 for this op; the kernel func does not call save.
  const forward = (backend2) => backend2.conv3dDerInput(dy5D, filter, computeConv3DInfo(xShape5D, filter.shape, strides, 1, pad2));
  const inputs = {dy: dy5D, filter};
  const attrs = {pad: pad2, strides, inputShape: xShape5D};
  const result = ENGINE.runKernelFunc(forward, inputs, null, Conv3DBackpropInputV2, attrs);
  if (!reshapedTo5D) {
    return result;
  }
  const [, d, h, w, c] = result.shape;
  return reshape(result, [d, h, w, c]);
}
var conv3DBackpropInput = op({conv3DBackpropInput_});
/**
 * Transposed 3-D convolution, expressed as conv3DBackpropInput.
 *
 * @param x Input tensor or TensorLike (plays the role of `dy`).
 * @param filter Filter tensor or TensorLike.
 * @param outputShape Shape of the result (plays the role of `xShape`).
 * @param strides Convolution strides.
 * @param pad2 Padding spec.
 * @returns Result of conv3DBackpropInput.
 */
function conv3dTranspose_(x, filter, outputShape, strides, pad2) {
  const dyTensor = convertToTensor(x, "x", "conv3dTranspose");
  const filterTensor = convertToTensor(filter, "filter", "conv3dTranspose");
  return conv3DBackpropInput(outputShape, dyTensor, filterTensor, strides, pad2);
}
var conv3dTranspose = op({conv3dTranspose_});
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Element-wise cosine via the backend `cos` kernel (kernel name `Cos`).
 *
 * @param x Tensor or TensorLike input.
 * @returns The kernel output.
 */
function cos_(x) {
  const $x = convertToTensor(x, "x", "cos");
  const kernelFunc = (backend2, save) => {
    const out = backend2.cos($x);
    save([$x]);
    return out;
  };
  return ENGINE.runKernelFunc(kernelFunc, {x: $x}, null, Cos);
}
var cos = op({cos_});
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Element-wise hyperbolic cosine via the backend `cosh` kernel (kernel name `Cosh`).
 *
 * @param x Tensor or TensorLike input.
 * @returns The kernel output.
 */
function cosh_(x) {
  const $x = convertToTensor(x, "x", "cosh");
  const kernelFunc = (backend2, save) => {
    const out = backend2.cosh($x);
    save([$x]);
    return out;
  };
  return ENGINE.runKernelFunc(kernelFunc, {x: $x}, null, Cosh);
}
var cosh = op({cosh_});
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Cumulative sum of `x` along `axis`, dispatched under the `Cumsum` kernel name.
 *
 * Inside the kernel func, if getAxesPermutation returns a permutation, `x` is
 * transposed so the backend kernel runs on the innermost axis. The result is
 * then transposed back with the undo permutation.
 *
 * @param x Tensor or TensorLike input.
 * @param axis Axis to accumulate over, default 0.
 * @param exclusive Forwarded to the backend kernel, default false.
 * @param reverse2 Forwarded to the backend kernel, default false.
 * @returns The accumulated tensor.
 */
function cumsum_(x, axis = 0, exclusive = false, reverse2 = false) {
  const $x = convertToTensor(x, "x", "cumsum");
  const forward = (backend2, save) => {
    const perm = getAxesPermutation([axis], $x.rank);
    const hasPerm = perm != null;
    const permutedX = hasPerm ? transpose($x, perm) : $x;
    const innerAxis = getInnerMostAxes(1, $x.rank)[0];
    const summed = backend2.cumsum(permutedX, innerAxis, exclusive, reverse2);
    save([$x]);
    return hasPerm ? transpose(summed, getUndoAxesPermutation(perm)) : summed;
  };
  const attrs = {axis, exclusive, reverse: reverse2};
  return ENGINE.runKernelFunc(forward, {x: $x}, null, Cumsum, attrs);
}
var cumsum = op({cumsum_});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Validates `x` against `blockSize` and dispatches the backend `depthToSpace`
 * kernel (kernel name `DepthToSpace`).
 *
 * For "NHWC" (the default), height/width/depth are read from axes 1/2/3.
 * For any other format they are read from axes 2/3/1. The depth must be
 * divisible by blockSize * blockSize.
 *
 * @param x Tensor or TensorLike input.
 * @param blockSize Block size.
 * @param dataFormat Layout string, default "NHWC".
 * @returns The kernel output.
 */
function depthToSpace_(x, blockSize, dataFormat = "NHWC") {
  const $x = convertToTensor(x, "x", "depthToSpace");
  const dims = $x.shape;
  const [inputHeight, inputWidth, inputDepth] = dataFormat === "NHWC" ? [dims[1], dims[2], dims[3]] : [dims[2], dims[3], dims[1]];
  assert(inputHeight * blockSize >= 0, () => `Negative dimension size caused by overflow when multiplying\n ${inputHeight} and ${blockSize} for depthToSpace with input shape\n ${$x.shape}`);
  assert(inputWidth * blockSize >= 0, () => `Negative dimension size caused by overflow when multiplying\n ${inputWidth} and ${blockSize} for depthToSpace with input shape\n ${$x.shape}`);
  const blockArea = blockSize * blockSize;
  assert(inputDepth % blockArea === 0, () => `Dimension size must be evenly divisible by ${blockArea} but is ${inputDepth} for depthToSpace with input shape ${$x.shape}`);
  const forward = (backend2) => backend2.depthToSpace($x, blockSize, dataFormat);
  return ENGINE.runKernelFunc(forward, {x: $x}, null, DepthToSpace, {blockSize, dataFormat});
}
var depthToSpace = op({depthToSpace_});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Depthwise 2-D convolution, dispatched under the `DepthwiseConv2dNative` kernel name.
 *
 * A rank-3 `x` is reshaped to rank 4 with a leading dimension of 1, and that
 * dimension is removed from the result again. The strides/dilations check and
 * the conv-info computation happen only inside the kernel func.
 *
 * @param x Input tensor; rank 3 or 4.
 * @param filter Rank-4 filter; filter.shape[2] must equal x4D.shape[3].
 * @param strides Convolution strides.
 * @param pad2 Padding spec; must be an integer when dimRoundingMode is set.
 * @param dataFormat Default "NHWC"; recorded in attrs only.
 * @param dilations Dilation rates, default [1, 1]; null is treated as [1, 1] in the kernel func.
 * @param dimRoundingMode Optional; forwarded to computeConv2DInfo.
 * @returns The convolution output.
 */
function depthwiseConv2d_(x, filter, strides, pad2, dataFormat = "NHWC", dilations = [1, 1], dimRoundingMode) {
  const $x = convertToTensor(x, "x", "depthwiseConv2d");
  const $filter = convertToTensor(filter, "filter", "depthwiseConv2d");
  const reshapedTo4D = $x.rank === 3;
  const x4D = reshapedTo4D ? reshape($x, [1, $x.shape[0], $x.shape[1], $x.shape[2]]) : $x;
  assert(x4D.rank === 4, () => `Error in depthwiseConv2d: input must be rank 4, but got rank ${x4D.rank}.`);
  assert($filter.rank === 4, () => `Error in depthwiseConv2d: filter must be rank 4, but got rank ${$filter.rank}.`);
  assert(x4D.shape[3] === $filter.shape[2], () => `Error in depthwiseConv2d: number of input channels (${x4D.shape[3]}) must match the inChannels dimension in filter ${$filter.shape[2]}.`);
  if (dimRoundingMode != null) {
    assert(isInt(pad2), () => `Error in depthwiseConv2d: pad must be an integer when using, dimRoundingMode ${dimRoundingMode} but got pad ${pad2}.`);
  }
  const forward = (backend2, save) => {
    const effectiveDilations = dilations == null ? [1, 1] : dilations;
    assert(eitherStridesOrDilationsAreOne(strides, effectiveDilations), () => `Error in depthwiseConv2d: Either strides or dilations must be 1. Got strides ${strides} and dilations '${effectiveDilations}'`);
    const convInfo = computeConv2DInfo(x4D.shape, $filter.shape, strides, effectiveDilations, pad2, dimRoundingMode, true);
    const output = backend2.depthwiseConv2D(x4D, $filter, convInfo);
    save([x4D, $filter]);
    return output;
  };
  const inputs = {x: x4D, filter: $filter};
  const attrs = {strides, pad: pad2, dataFormat, dilations, dimRoundingMode};
  const result = ENGINE.runKernelFunc(forward, inputs, null, DepthwiseConv2dNative, attrs);
  if (!reshapedTo4D) {
    return result;
  }
  const [, outHeight, outWidth, outChannels] = result.shape;
  return reshape(result, [outHeight, outWidth, outChannels]);
}
var depthwiseConv2d = op({depthwiseConv2d_});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Runs the backend `diag` kernel (kernel name `Diag`) on a flattened copy of `x`.
 * The result is reshaped to `$x.shape` concatenated with itself, so the rank doubles.
 *
 * @param x Tensor or TensorLike input.
 * @returns Tensor of shape [...$x.shape, ...$x.shape].
 */
function diag_(x) {
  const $x = convertToTensor(x, "x", "diag");
  const forward = (backend2) => {
    const flat = reshape($x, [$x.size]);
    const result = backend2.diag(flat);
    // Use the converted tensor's shape: `x` may be a TensorLike (e.g. a nested
    // array) that has no `.shape`, which previously crashed here.
    const outShape = $x.shape.concat($x.shape);
    return reshape(result, outShape);
  };
  const inputs = {x: $x};
  return ENGINE.runKernelFunc(forward, inputs, null, Diag);
}
var diag = op({diag_});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * 2-D dilation op, dispatched through ENGINE.runKernel under the `Dilation2D` kernel name.
 *
 * A rank-3 `x` is reshaped to rank 4 with a leading dimension of 1, and that
 * dimension is removed from the result again. Only "NHWC" is accepted.
 *
 * @param x Input tensor; rank 3 or 4.
 * @param filter Rank-3 filter.
 * @param strides Strides, forwarded as an attr.
 * @param pad2 Padding spec, forwarded as the `pad` attr.
 * @param dilations Dilation rates, default [1, 1].
 * @param dataFormat Must be "NHWC" (the default).
 * @returns The kernel output.
 */
function dilation2d_(x, filter, strides, pad2, dilations = [1, 1], dataFormat = "NHWC") {
  const $x = convertToTensor(x, "x", "dilation2d");
  const $filter = convertToTensor(filter, "filter", "dilation2d");
  assert($x.rank === 3 || $x.rank === 4, () => `Error in dilation2d: input must be rank 3 or 4, but got rank ${$x.rank}.`);
  assert($filter.rank === 3, () => `Error in dilation2d: filter must be rank 3, but got rank ${$filter.rank}.`);
  assert(dataFormat === "NHWC", () => `Error in dilation2d: Only NHWC is currently supported, but got dataFormat of ${dataFormat}`);
  const reshapedTo4D = $x.rank === 3;
  const x4D = reshapedTo4D ? reshape($x, [1, $x.shape[0], $x.shape[1], $x.shape[2]]) : $x;
  const result = ENGINE.runKernel(Dilation2D, {x: x4D, filter: $filter}, {strides, pad: pad2, dilations});
  if (!reshapedTo4D) {
    return result;
  }
  const [, outHeight, outWidth, outChannels] = result.shape;
  return reshape(result, [outHeight, outWidth, outChannels]);
}
var dilation2d = op({dilation2d_});
/**
* @license
* Copyright 2017 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Lists, in ascending order, the axes of `inShape` that get stretched when it is
 * right-aligned against `outShape`. An axis qualifies when its size is 1 and the
 * matching output size is greater than 1.
 *
 * Missing or falsy sizes (including 0) on either side are treated as 1.
 *
 * @param inShape Input shape array.
 * @param outShape Broadcast output shape array.
 * @returns Array of axis indices into `inShape`.
 */
function getBroadcastDims(inShape, outShape) {
  const inRank = inShape.length;
  // Offset that right-aligns inShape against outShape.
  const rankGap = outShape.length - inRank;
  const broadcastDims = [];
  for (let dim = 0; dim < inRank; dim++) {
    const inSize = inShape[dim] || 1;
    const outSize = outShape[rankGap + dim] || 1;
    if (outSize > 1 && inSize === 1) {
      broadcastDims.push(dim);
    }
  }
  return broadcastDims;
}
/**
 * Lists, in ascending order, the axes of `outShape` produced by broadcasting
 * `inShape` into it (both right-aligned). An axis qualifies when the input has
 * no dimension there, or has size 1 while the output size is greater than 1.
 *
 * @param inShape Input shape array.
 * @param outShape Broadcast output shape array.
 * @returns Array of axis indices into `outShape`.
 */
function getReductionAxes(inShape, outShape) {
  const rankGap = outShape.length - inShape.length;
  const axes = [];
  for (let outAxis = 0; outAxis < outShape.length; outAxis++) {
    const inDim = inShape[outAxis - rankGap];
    const outDim = outShape[outAxis];
    const absentInInput = inDim == null;
    const stretched = inDim === 1 && outDim > 1;
    if (absentInInput || stretched) {
      axes.push(outAxis);
    }
  }
  return axes;
}
/**
 * Computes the broadcast shape of two right-aligned shapes, throwing when they
 * are incompatible.
 *
 * Missing leading dimensions count as 1. A size-1 dimension takes the other
 * side's size. Otherwise the two sizes must be equal.
 *
 * @param shapeA First shape array.
 * @param shapeB Second shape array.
 * @returns The broadcast shape, with length max(shapeA.length, shapeB.length).
 * @throws {Error} When a pair of dimensions is neither equal nor contains a 1.
 */
function assertAndGetBroadcastShape(shapeA, shapeB) {
  const rank = Math.max(shapeA.length, shapeB.length);
  const merged = new Array(rank);
  for (let fromEnd = 1; fromEnd <= rank; fromEnd++) {
    const rawA = shapeA[shapeA.length - fromEnd];
    const rawB = shapeB[shapeB.length - fromEnd];
    const dimA = rawA == null ? 1 : rawA;
    const dimB = rawB == null ? 1 : rawB;
    let dim;
    if (dimA === 1) {
      dim = dimB;
    } else if (dimB === 1) {
      dim = dimA;
    } else if (dimA === dimB) {
      dim = dimA;
    } else {
      throw Error(`Operands could not be broadcast together with shapes ${shapeA} and ${shapeB}.`);
    }
    merged[rank - fromEnd] = dim;
  }
  return merged;
}
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Element-wise equality of `a` and `b`, dispatched under the `Equal` kernel name.
 * The inputs are first passed through makeTypesMatch, and their shapes are
 * checked for broadcast compatibility.
 *
 * @param a First operand (Tensor or TensorLike).
 * @param b Second operand (Tensor or TensorLike).
 * @returns The kernel output.
 */
function equal_(a, b) {
  let $a = convertToTensor(a, "a", "equal");
  let $b = convertToTensor(b, "b", "equal");
  [$a, $b] = makeTypesMatch($a, $b);
  // Called only for its throwing check; the resulting shape is not needed here.
  assertAndGetBroadcastShape($a.shape, $b.shape);
  const forward = (backend2) => backend2.equal($a, $b);
  return ENGINE.runKernelFunc(forward, {a: $a, b: $b}, null, Equal);
}
var equal = op({equal_});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Dispatches the backend `select` kernel (kernel name `SelectV2`) on a boolean
 * `condition` and on `a`/`b` broadcast to a common shape.
 *
 * `condition` must either match the broadcast shape exactly, or be rank 1 with
 * a length equal to the broadcast shape's first dimension.
 *
 * @param condition Tensor or TensorLike, converted with dtype "bool".
 * @param a Values used where the condition holds (Tensor or TensorLike).
 * @param b Values used elsewhere (Tensor or TensorLike).
 * @returns The kernel output.
 */
function where_(condition, a, b) {
  const $a = convertToTensor(a, "a", "where");
  const $b = convertToTensor(b, "b", "where");
  const $condition = convertToTensor(condition, "condition", "where", "bool");
  const broadcastShape = assertAndGetBroadcastShape($a.shape, $b.shape);
  const $broadcastedA = broadcastTo($a, broadcastShape);
  const $broadcastedB = broadcastTo($b, broadcastShape);
  if ($condition.rank === 1) {
    // Validate against the broadcast shape that the kernel actually receives.
    // Checking the raw `$a.shape[0]` rejected valid calls (e.g. a scalar `a`)
    // and accepted mismatched ones (e.g. `a` [3], `b` [2,3], condition [3]).
    assert($condition.shape[0] === broadcastShape[0], () => "The first dimension of `a` must match the size of `condition`.");
  } else {
    assertShapesMatch($condition.shape, $broadcastedB.shape, "Error in where: ");
  }
  const forward = (backend2, save) => {
    const res = backend2.select($condition, $broadcastedA, $broadcastedB);
    save([$condition]);
    return res;
  };
  const inputs = {
    condition: $condition,
    t: $broadcastedA,
    e: $broadcastedB
  };
  return ENGINE.runKernelFunc(forward, inputs, null, SelectV2);
}
var where = op({where_});
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Dispatches the backend `zerosLike` kernel (kernel name `ZerosLike`) on `x`.
 *
 * @param x Tensor or TensorLike input.
 * @returns The kernel output.
 */
function zerosLike_(x) {
  const $x = convertToTensor(x, "x", "zerosLike");
  const kernelFunc = (backend2) => backend2.zerosLike($x);
  return ENGINE.runKernelFunc(kernelFunc, {x: $x}, null, ZerosLike);
}
var zerosLike = op({zerosLike_});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Divides `a` by `b`, replacing the quotient with 0 wherever `b` equals 0.
 * It is built from the `div`, `zerosLike`, `equal` and `where` ops.
 *
 * @param a Dividend (Tensor or TensorLike).
 * @param b Divisor (Tensor or TensorLike).
 * @returns The masked quotient.
 */
function divNoNan_(a, b) {
  let $a = convertToTensor(a, "a", "div");
  let $b = convertToTensor(b, "b", "div");
  [$a, $b] = makeTypesMatch($a, $b);
  const quotient = div($a, $b);
  const zeros2 = zerosLike(quotient);
  const divisorIsZero = equal($b, zeros2);
  return where(divisorIsZero, zeros2, quotient);
}
var divNoNan = op({divNoNan_});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Dot product of rank-1 or rank-2 tensors, implemented with `matMul`.
 *
 * Rank-1 operands are reshaped to a row ([1, -1]) on the left or a column
 * ([-1, 1]) on the right. The product is then reshaped back: to a scalar
 * (vector·vector) or to a flat vector (vector·matrix, matrix·vector).
 *
 * @param t1 Left operand (Tensor or TensorLike), rank 1 or 2.
 * @param t2 Right operand (Tensor or TensorLike), rank 1 or 2.
 * @returns The product tensor.
 */
function dot_(t1, t2) {
  const $t1 = convertToTensor(t1, "t1", "dot");
  const $t2 = convertToTensor(t2, "t2", "dot");
  const isVectorOrMatrix = (t) => t.rank === 1 || t.rank === 2;
  assert(isVectorOrMatrix($t1) && isVectorOrMatrix($t2), () => `Error in dot: inputs must all be rank 1 or 2, but got ranks ${$t1.rank} and ${$t2.rank}.`);
  const t1Inner = $t1.rank === 1 ? $t1.size : $t1.shape[1];
  const t2Inner = $t2.rank === 1 ? $t2.size : $t2.shape[0];
  assert(t1Inner === t2Inner, () => `Error in dot: inner dimensions of inputs must match, but got ${t1Inner} and ${t2Inner}.`);
  const leftIsVector = $t1.rank === 1;
  const rightIsVector = $t2.rank === 1;
  if (leftIsVector && rightIsVector) {
    const row = reshape($t1, [1, -1]);
    const col = reshape($t2, [-1, 1]);
    return reshape(matMul(row, col), []);
  }
  if (leftIsVector) {
    const row = reshape($t1, [1, -1]);
    const rhs = reshape($t2, [$t2.shape[0], $t2.shape[1]]);
    const product = matMul(row, rhs);
    return reshape(product, [product.size]);
  }
  if (rightIsVector) {
    const col = reshape($t2, [-1, 1]);
    const product = matMul($t1, col);
    return reshape(product, [product.size]);
  }
  const rhs = reshape($t2, [$t2.shape[0], $t2.shape[1]]);
  return matMul($t1, rhs);
}
var dot = op({dot_});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Elementwise ELU via the backend's `elu` kernel. The activation output
 * (not the input) is saved for the backward pass.
 */
function elu_(x) {
  const $x = convertToTensor(x, "x", "elu");
  const inputs = {x: $x};
  return ENGINE.runKernelFunc((backend2, save) => {
    const activated = backend2.elu($x);
    save([activated]);
    return activated;
  }, inputs, null, Elu);
}
var elu = op({elu_});
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Elementwise Gauss error function. Only int32 and float32 inputs are
 * accepted; int32 input is cast to float32 before the kernel runs.
 */
function erf_(x) {
  let $x = convertToTensor(x, "x", "erf");
  const dtype = $x.dtype;
  assert(dtype === "int32" || dtype === "float32", () => "Input dtype must be `int32` or `float32`.");
  if (dtype === "int32") {
    $x = cast($x, "float32");
  }
  const inputs = {x: $x};
  const forward = (backend2, save) => {
    const out = backend2.erf($x);
    save([$x]);
    return out;
  };
  return ENGINE.runKernelFunc(forward, inputs, null, Erf);
}
var erf = op({erf_});
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Elementwise natural exponential. The kernel's output is saved for the
 * backward pass.
 */
function exp_(x) {
  const $x = convertToTensor(x, "x", "exp");
  const forward = (backend2, save) => {
    const out = backend2.exp($x);
    save([out]);
    return out;
  };
  return ENGINE.runKernelFunc(forward, {x: $x}, null, Exp);
}
var exp = op({exp_});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Inserts a size-1 dimension into the shape of `x` at `axis` (default 0).
 * A negative axis counts from the end and must lie in [-(rank + 1), -1].
 * Implemented as a reshape.
 */
function expandDims_(x, axis = 0) {
  const $x = convertToTensor(x, "x", "expandDims", null);
  const rank = $x.rank;
  assert(axis <= rank, () => "Axis must be <= rank of the tensor");
  let insertAt = axis;
  if (insertAt < 0) {
    assert(-(rank + 1) <= insertAt, () => `Axis must be in the interval [${-(rank + 1)}, ${rank}]`);
    insertAt = rank + insertAt + 1;
  }
  const newShape = $x.shape.slice();
  newShape.splice(insertAt, 0, 1);
  return reshape($x, newShape);
}
var expandDims = op({expandDims_});
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Elementwise exp(x) - 1 via the backend's `expm1` kernel. The input is
 * saved for the backward pass.
 */
function expm1_(x) {
  const $x = convertToTensor(x, "x", "expm1");
  const forward = (backend2, save) => {
    const out = backend2.expm1($x);
    save([$x]);
    return out;
  };
  return ENGINE.runKernelFunc(forward, {x: $x}, null, Expm1);
}
var expm1 = op({expm1_});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Constructs a tensor by repeating `x` along each axis.
 *
 * @param x Tensor (or tensor-like) to tile.
 * @param reps Number of repetitions per axis. Its length must equal the
 *     rank of `x`.
 * @returns The result of the backend's `tile` kernel.
 * @throws If `reps.length` differs from the rank of `x`.
 */
function tile_(x, reps) {
  const $x = convertToTensor(x, "x", "tile", null);
  // This message used to say "Error in transpose", naming the wrong op.
  assert($x.rank === reps.length, () => `Error in tile: rank of input ${$x.rank} must match length of reps ${reps}.`);
  const forward = (backend2, save) => {
    const res = backend2.tile($x, reps);
    // Keep $x for the backward pass; it is also listed in inputsToSave below.
    save([$x]);
    return res;
  };
  const inputsToSave = [$x];
  const inputs = {x: $x};
  const attrs = {reps};
  return ENGINE.runKernelFunc(forward, inputs, null, Tile, attrs, inputsToSave);
}
var tile = op({tile_});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Creates an identity-like matrix (ones on the main diagonal), optionally
 * tiled over leading batch dimensions.
 *
 * @param numRows Number of rows.
 * @param numColumns Number of columns. Defaults to `numRows` when null or
 *     undefined.
 * @param batchShape Optional leading batch shape of length 1, 2 or 3. When
 *     given, the matrix is tiled to shape [...batchShape, numRows, numColumns].
 * @param dtype Element dtype, "float32" by default.
 * @throws If `batchShape` has a length other than 1, 2 or 3.
 */
function eye_(numRows, numColumns, batchShape, dtype = "float32") {
  const cols = numColumns == null ? numRows : numColumns;
  const buff = buffer([numRows, cols], dtype);
  const diagonalLength = numRows <= cols ? numRows : cols;
  for (let i = 0; i < diagonalLength; ++i) {
    buff.set(1, i, i);
  }
  const out = reshape(buff.toTensor(), [numRows, cols]);
  if (batchShape == null) {
    return out;
  }
  const batchRank = batchShape.length;
  if (batchRank !== 1 && batchRank !== 2 && batchRank !== 3) {
    // The old message claimed only 1D and 2D support, although 3D works too.
    throw new Error(`eye() currently supports only 1D, 2D and 3D batchShapes, but received ${batchRank}D.`);
  }
  // Prepend one singleton axis per batch dimension, then tile along only those axes.
  let batched = out;
  for (let i = 0; i < batchRank; ++i) {
    batched = expandDims(batched, 0);
  }
  return tile(batched, [...batchShape, 1, 1]);
}
var eye = op({eye_});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Creates a tensor of the given shape with every element set to `value`,
 * via the backend's `fill` kernel (no tensor inputs).
 */
function fill(shape, value, dtype) {
  const forward = (backend2) => backend2.fill(shape, value, dtype);
  return ENGINE.runKernelFunc(forward, {}, null, Fill, {shape, value, dtype});
}
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/** Elementwise floor via the backend's `floor` kernel; nothing is saved. */
function floor_(x) {
  const $x = convertToTensor(x, "x", "floor");
  const forward = (backend2) => backend2.floor($x);
  return ENGINE.runKernelFunc(forward, {x: $x}, null, Floor);
}
var floor = op({floor_});
/**
* @license
* Copyright 2017 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/** Inputs at or below this size use a single window (see computeOptimalWindowSize). */
var PARALLELIZE_THRESHOLD = 30;
/**
 * Chooses a window size for an input of `inSize` elements. Small inputs use
 * `inSize` itself. Larger inputs use nearestDivisor(inSize, floor(sqrt(inSize))).
 */
function computeOptimalWindowSize(inSize) {
  return inSize <= PARALLELIZE_THRESHOLD
    ? inSize
    : nearestDivisor(inSize, Math.floor(Math.sqrt(inSize)));
}
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Window-size choice for segment ops. Inputs at or below
 * PARALLELIZE_THRESHOLD use `inSize`. Otherwise the search starts at
 * nearestDivisor(inSize, floor(sqrt(inSize))) and keeps stepping to the next
 * divisor until it exceeds `numSegments` or reaches `inSize`.
 */
function segOpComputeOptimalWindowSize(inSize, numSegments) {
  if (inSize <= PARALLELIZE_THRESHOLD) {
    return inSize;
  }
  let windowSize = nearestDivisor(inSize, Math.floor(Math.sqrt(inSize)));
  // Written as !(a > b) rather than a <= b so that NaN comparisons behave
  // exactly like the original stop condition.
  while (!(windowSize > numSegments) && windowSize !== inSize) {
    windowSize = nearestDivisor(inSize, windowSize + 1);
  }
  return windowSize;
}
/**
 * Returns a new plain array equal to `aShape`, except that the entry at index
 * `axis` is replaced by `numSegments`. Exposed as segment_util.computeOutShape.
 */
function computeOutShape$2(aShape, axis, numSegments) {
  return Array.from(aShape, (size, dim) => (dim === axis ? numSegments : size));
}
/**
 * Shape bookkeeping for gather along `axis`.
 *
 * @returns {{batchSize, sliceSize, dimSize, outputShape}} where batchSize is
 *     the product of x's dimensions before `axis`, sliceSize is the product of
 *     those after it, dimSize is x.shape[axis], and outputShape is
 *     x.shape[:axis] ++ indices.shape ++ x.shape[axis+1:].
 */
function collectGatherOpShapeInfo(x, indices, axis) {
  const dimSize = x.shape[axis];
  const outputShape = [];
  let batchSize = 1;
  let sliceSize = 1;
  for (let d = 0; d < axis; d++) {
    const extent = x.shape[d];
    outputShape.push(extent);
    batchSize *= extent;
  }
  // The gathered axis is replaced by the full shape of `indices`.
  for (let d = 0; d < indices.rank; d++) {
    outputShape.push(indices.shape[d]);
  }
  for (let d = axis + 1; d < x.rank; d++) {
    const extent = x.shape[d];
    outputShape.push(extent);
    sliceSize *= extent;
  }
  return {batchSize, sliceSize, dimSize, outputShape};
}
// Namespace object for the segment helpers, built with a null prototype.
var segment_util = Object.assign(Object.create(null), {
  segOpComputeOptimalWindowSize,
  computeOutShape: computeOutShape$2,
  collectGatherOpShapeInfo
});
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Gathers slices of `x` along `axis` (default 0) at the given indices, which
 * are converted as int32. The backend receives a flattened index list; the
 * result is reshaped to x.shape[:axis] ++ indices.shape ++ x.shape[axis+1:].
 */
function gather_(x, indices, axis = 0) {
  const $x = convertToTensor(x, "x", "gather");
  const $indices = convertToTensor(indices, "indices", "gather", "int32");
  const inputs = {x: $x, indices: $indices};
  const attrs = {axis};
  const forward = (backend2, save) => {
    const parsedAxis = parseAxisParam(axis, $x.shape)[0];
    const {outputShape} = collectGatherOpShapeInfo($x, $indices, parsedAxis);
    const flatIndices = reshape($indices, [$indices.size]);
    const gathered = backend2.gather($x, flatIndices, parsedAxis);
    save([$x, $indices]);
    return reshape(gathered, outputShape);
  };
  return ENGINE.runKernelFunc(forward, inputs, null, GatherV2, attrs);
}
var gather = op({gather_});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Elementwise a > b. Dtypes are unified with makeTypesMatch, and the shapes
 * are validated as broadcast-compatible first.
 */
function greater_(a, b) {
  let $a = convertToTensor(a, "a", "greater");
  let $b = convertToTensor(b, "b", "greater");
  [$a, $b] = makeTypesMatch($a, $b);
  assertAndGetBroadcastShape($a.shape, $b.shape);
  const inputs = {a: $a, b: $b};
  return ENGINE.runKernelFunc((backend2) => backend2.greater($a, $b), inputs, null, Greater);
}
var greater = op({greater_});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Elementwise a >= b. Dtypes are unified and the shapes are validated as
 * broadcast-compatible. Both inputs are saved for the backward pass.
 */
function greaterEqual_(a, b) {
  let $a = convertToTensor(a, "a", "greaterEqual");
  let $b = convertToTensor(b, "b", "greaterEqual");
  [$a, $b] = makeTypesMatch($a, $b);
  assertAndGetBroadcastShape($a.shape, $b.shape);
  const inputs = {a: $a, b: $b};
  return ENGINE.runKernelFunc((backend2, save) => {
    const mask = backend2.greaterEqual($a, $b);
    save([$a, $b]);
    return mask;
  }, inputs, null, GreaterEqual);
}
var greaterEqual = op({greaterEqual_});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/** Returns the imaginary part of `input` via the backend's `imag` kernel. */
function imag_(input) {
  const $input = convertToTensor(input, "input", "imag");
  const inputs = {input: $input};
  return ENGINE.runKernelFunc((backend2) => backend2.imag($input), inputs, null, Imag);
}
var imag = op({imag_});
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/** Elementwise finiteness test via the backend's `isFinite` kernel. */
function isFinite_(x) {
  const $x = convertToTensor(x, "x", "isFinite");
  const forward = (backend2) => backend2.isFinite($x);
  return ENGINE.runKernelFunc(forward, {x: $x}, null, IsFinite);
}
var isFinite$1 = op({isFinite_});
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/** Elementwise infinity test via the backend's `isInf` kernel. */
function isInf_(x) {
  const $x = convertToTensor(x, "x", "isInf");
  const forward = (backend2) => backend2.isInf($x);
  return ENGINE.runKernelFunc(forward, {x: $x}, null, IsInf);
}
var isInf = op({isInf_});
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/** Elementwise NaN test via the backend's `isNaN` kernel. */
function isNaN_(x) {
  const $x = convertToTensor(x, "x", "isNaN");
  const forward = (backend2) => backend2.isNaN($x);
  return ENGINE.runKernelFunc(forward, {x: $x}, null, IsNan);
}
var isNaN$1 = op({isNaN_});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Elementwise maximum of a and b. Dtypes are unified first, and bool
 * operands are cast to int32. Shapes must be broadcast-compatible. Both
 * (possibly cast) inputs are saved for the backward pass.
 */
function maximum_(a, b) {
  let $a = convertToTensor(a, "a", "maximum");
  let $b = convertToTensor(b, "b", "maximum");
  [$a, $b] = makeTypesMatch($a, $b);
  if ($a.dtype === "bool") {
    $a = cast($a, "int32");
    $b = cast($b, "int32");
  }
  assertAndGetBroadcastShape($a.shape, $b.shape);
  const inputs = {a: $a, b: $b};
  return ENGINE.runKernelFunc((backend2, save) => {
    const out = backend2.maximum($a, $b);
    save([$a, $b]);
    return out;
  }, inputs, null, Maximum);
}
var maximum = op({maximum_});
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Creates a rank-0 tensor from a primitive value.
 *
 * Arrays, and typed arrays whose dtype is not "string", are rejected unless
 * dtype is "complex64". For dtype "string", a typed-array value must be a
 * Uint8Array (an encoded string).
 */
function scalar(value, dtype) {
  const looksLikeArray = (isTypedArray(value) && dtype !== "string") || Array.isArray(value);
  if (looksLikeArray && dtype !== "complex64") {
    throw new Error("Error creating a new Scalar: value must be a primitive (number|boolean|string)");
  }
  const isBadEncodedString = dtype === "string" && isTypedArray(value) && !(value instanceof Uint8Array);
  if (isBadEncodedString) {
    throw new Error("When making a scalar from encoded string, the value must be `Uint8Array`.");
  }
  // Both the target shape and the inferred shape are empty for a scalar.
  return makeTensor(value, [], [], dtype);
}
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Leaky ReLU computed as max(alpha * x, x), with `alpha` defaulting to 0.2.
 */
function leakyRelu_(x, alpha = 0.2) {
  const $x = convertToTensor(x, "x", "leakyRelu");
  const scaled = mul(scalar(alpha), $x);
  return maximum(scaled, $x);
}
var leakyRelu = op({leakyRelu_});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Elementwise a < b. Dtypes are unified with makeTypesMatch, and the shapes
 * are validated as broadcast-compatible first.
 */
function less_(a, b) {
  let $a = convertToTensor(a, "a", "less");
  let $b = convertToTensor(b, "b", "less");
  [$a, $b] = makeTypesMatch($a, $b);
  assertAndGetBroadcastShape($a.shape, $b.shape);
  const inputs = {a: $a, b: $b};
  return ENGINE.runKernelFunc((backend2) => backend2.less($a, $b), inputs, null, Less);
}
var less = op({less_});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Elementwise a <= b. Dtypes are unified and the shapes are validated as
 * broadcast-compatible. Both inputs are saved for the backward pass.
 */
function lessEqual_(a, b) {
  let $a = convertToTensor(a, "a", "lessEqual");
  let $b = convertToTensor(b, "b", "lessEqual");
  [$a, $b] = makeTypesMatch($a, $b);
  assertAndGetBroadcastShape($a.shape, $b.shape);
  const inputs = {a: $a, b: $b};
  return ENGINE.runKernelFunc((backend2, save) => {
    const mask = backend2.lessEqual($a, $b);
    save([$a, $b]);
    return mask;
  }, inputs, null, LessEqual);
}
var lessEqual = op({lessEqual_});
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Returns `num` evenly spaced values from `start` to `stop`, via the
 * backend's `linspace` kernel. Throws when `num` is not positive.
 */
function linspace(start, stop, num) {
  if (!(num > 0) && num <= 0) {
    throw new Error("The number of values should be positive.");
  }
  const forward = (backend2) => backend2.linspace(start, stop, num);
  return ENGINE.runKernelFunc(forward, {}, null, LinSpace, {start, stop, num});
}
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Local response normalization over a rank-3 or rank-4 tensor.
 *
 * Rank-3 input is temporarily given a leading dimension of 1 so that the 4-D
 * kernel can run, and that dimension is removed from the result.
 *
 * @param x Rank-3 or rank-4 tensor (or tensor-like).
 * @param depthRadius Integer radius, default 5.
 * @param bias Default 1.
 * @param alpha Default 1.
 * @param beta Default 0.5.
 * @throws If the rank is not 3 or 4, or if depthRadius is not an integer.
 */
function localResponseNormalization_(x, depthRadius = 5, bias = 1, alpha = 1, beta = 0.5) {
  const $x = convertToTensor(x, "x", "localResponseNormalization");
  // Previously a template-literal line break leaked "\n" plus indentation
  // into this message.
  assert($x.rank === 4 || $x.rank === 3, () => `Error in localResponseNormalization: x must be rank 3 or 4 but got rank ${$x.rank}.`);
  assert(isInt(depthRadius), () => `Error in localResponseNormalization: depthRadius must be an integer but got depthRadius ${depthRadius}.`);
  const reshapedTo4D = $x.rank === 3;
  const x4D = reshapedTo4D ? reshape($x, [1, $x.shape[0], $x.shape[1], $x.shape[2]]) : $x;
  const forward = (backend2, save) => {
    const y = backend2.localResponseNormalization4D(x4D, depthRadius, bias, alpha, beta);
    save([x4D, y]);
    return y;
  };
  const inputs = {x: x4D};
  const attrs = {depthRadius, bias, alpha, beta};
  const res = ENGINE.runKernelFunc(forward, inputs, null, LRN, attrs);
  if (reshapedTo4D) {
    return reshape(res, [res.shape[1], res.shape[2], res.shape[3]]);
  }
  return res;
}
var localResponseNormalization = op({localResponseNormalization_});
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Elementwise natural logarithm via the backend's `log` kernel. The input is
 * saved for the backward pass.
 */
function log_(x) {
  const $x = convertToTensor(x, "x", "log");
  const forward = (backend2, save) => {
    const out = backend2.log($x);
    save([$x]);
    return out;
  };
  return ENGINE.runKernelFunc(forward, {x: $x}, null, Log);
}
var log = op({log_});
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Elementwise log(1 + x) via the backend's `log1p` kernel. The input is
 * saved for the backward pass.
 */
function log1p_(x) {
  const $x = convertToTensor(x, "x", "log1p");
  const forward = (backend2, save) => {
    const out = backend2.log1p($x);
    save([$x]);
    return out;
  };
  return ENGINE.runKernelFunc(forward, {x: $x}, null, Log1p);
}
var log1p = op({log1p_});
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Wraps a single-input function `f` so that calling the result with (x, dy?)
 * returns the gradient of f at x. The work runs inside ENGINE.tidy.
 *
 * When `dy` is given, its shape must match the shape of f(x).
 */
function grad(f) {
  assert(isFunction(f), () => "The f passed in grad(f) must be a function");
  return (x, dy) => {
    const $x = convertToTensor(x, "x", "tf.grad", null);
    const $dy = dy == null ? null : convertToTensor(dy, "dy", "tf.grad");
    return ENGINE.tidy(() => {
      const {value, grads: gradList} = ENGINE.gradients(() => f($x), [$x], $dy);
      if ($dy != null) {
        assertShapesMatch(value.shape, $dy.shape, "The shape of dy passed in grad(f)(x, dy) must match the shape returned by f(x)");
      }
      checkGrads(gradList);
      return gradList[0];
    });
  };
}
/**
 * Wraps a multi-input function `f` so that calling the result with
 * (args, dy?) returns the array of gradients of f with respect to each
 * element of `args`. The work runs inside ENGINE.tidy.
 *
 * `args` must be an array. When `dy` is given, its shape must match the
 * shape of f(...args).
 */
function grads(f) {
  assert(isFunction(f), () => "The f passed in grads(f) must be a function");
  return (args, dy) => {
    assert(Array.isArray(args), () => "The args passed in grads(f)(args) must be an array of `Tensor`s or `TensorLike`s");
    const $args = convertToTensorArray(args, "args", "tf.grads", null);
    const $dy = dy == null ? null : convertToTensor(dy, "dy", "tf.grads");
    return ENGINE.tidy(() => {
      const {value, grads: gradList} = ENGINE.gradients(() => f(...$args), $args, $dy);
      if ($dy != null) {
        assertShapesMatch(value.shape, $dy.shape, "The shape of dy passed in grads(f)([x1,...], dy) must match the shape returned by f([x1,...])");
      }
      checkGrads(gradList);
      return gradList;
    });
  };
}
/**
 * Wraps a single-input function `f` so that calling the result with (x, dy?)
 * returns {grad, value}: `value` is f(x) and `grad` is the gradient with
 * respect to x.
 *
 * `x` must be a Tensor, and `dy`, if given, must be a Tensor whose shape
 * matches f(x). This shape check was missing here, although grad, grads and
 * valueAndGrads all perform it.
 */
function valueAndGrad(f) {
  assert(isFunction(f), () => "The f passed in valueAndGrad(f) must be a function");
  return (x, dy) => {
    assert(x instanceof Tensor, () => "The x passed in valueAndGrad(f)(x) must be a tensor");
    assert(dy == null || dy instanceof Tensor, () => "The dy passed in valueAndGrad(f)(x, dy) must be a tensor");
    const {grads: gradList, value} = ENGINE.gradients(() => f(x), [x], dy);
    if (dy != null) {
      assertShapesMatch(value.shape, dy.shape, "The shape of dy passed in valueAndGrad(f)(x, dy) must match the shape returned by f(x)");
    }
    checkGrads(gradList);
    return {grad: gradList[0], value};
  };
}
/**
 * Builds a function that returns `f(...args)` together with the gradient for every argument.
 *
 * `args` must be an array of Tensor instances, and `dy` (optional) a Tensor whose shape matches
 * the value of `f`. The object produced by ENGINE.gradients (with `value` and `grads`) is
 * returned as-is.
 *
 * @param f Function of several tensors; asserted to be a function.
 * @returns `(args, dy) => {value, grads}`.
 */
function valueAndGrads(f) {
  assert(isFunction(f), () => "The f passed in valueAndGrads(f) must be a function");
  return (args, dy) => {
    const allTensors = Array.isArray(args) && args.every((arg) => arg instanceof Tensor);
    assert(allTensors, () => "The args passed in valueAndGrads(f)(args) must be array of tensors");
    assert(dy == null || dy instanceof Tensor, () => "The dy passed in valueAndGrads(f)(args, dy) must be a tensor");
    const outcome = ENGINE.gradients(() => f.apply(undefined, args), args, dy);
    if (dy != null) {
      assertShapesMatch(outcome.value.shape, dy.shape, "The shape of dy passed in valueAndGrads(f)([x1,...], dy) must match the shape returned by f([x1,...])");
    }
    checkGrads(outcome.grads);
    return outcome;
  };
}
/**
 * Computes the value of `f()` and its gradients with respect to trainable variables.
 *
 * When `varList` is omitted, every entry of ENGINE.registeredVariables is considered.
 * Only variables whose `trainable` flag is set are differentiated. Variables the caller listed
 * explicitly that are not trainable appear in the result with a `null` gradient. Variables with
 * no gradient are left out of the result.
 *
 * @param f Zero-argument function; must return a rank-0 tensor.
 * @param varList Optional array of Variable instances.
 * @returns `{value, grads}`, where `grads` maps variable name to gradient (or null).
 * @throws Error (via assert) if no variable is trainable, no gradient connects to the
 *     result, or the result is not a scalar.
 */
function variableGrads(f, varList) {
  assert(isFunction(f), () => "The f passed in variableGrads(f) must be a function");
  const validList = varList == null || (Array.isArray(varList) && varList.every((v) => v instanceof Variable));
  assert(validList, () => "The varList passed in variableGrads(f, varList) must be an array of variables");

  const callerProvided = varList != null;
  let candidates = varList;
  if (!callerProvided) {
    candidates = [];
    for (const key in ENGINE.registeredVariables) {
      candidates.push(ENGINE.registeredVariables[key]);
    }
  }
  // Non-trainable vars the caller asked about are reported back with a null gradient.
  const skipped = callerProvided ? candidates.filter((v) => !v.trainable) : null;
  const totalCount = candidates.length;
  const trainable = candidates.filter((v) => v.trainable);
  assert(trainable.length > 0, () => `variableGrads() expects at least one of the input variables to be trainable, but none of the ${totalCount} variables is trainable.`);

  const allowNoGradients = true;
  const { value, grads: gradList } = ENGINE.gradients(f, trainable, null, allowNoGradients);
  assert(gradList.some((g) => g != null), () => "Cannot find a connection between any variable and the result of the loss function y=f(x). Please make sure the operations that use variables are inside the function f passed to minimize().");
  assert(value.rank === 0, () => `The f passed in variableGrads(f) must return a scalar, but it returned a rank-${value.rank} tensor`);

  const namedGrads = {};
  trainable.forEach((v, idx) => {
    if (gradList[idx] != null) {
      namedGrads[v.name] = gradList[idx];
    }
  });
  if (skipped != null) {
    for (const v of skipped) {
      namedGrads[v.name] = null;
    }
  }
  return { value, grads: namedGrads };
}
/**
 * Wraps `f` with a user-supplied gradient by delegating to ENGINE.customGrad.
 *
 * @param f Function returning `{value, gradFunc}` (as used by logSigmoid_ and mean_ in this file).
 * @returns Whatever ENGINE.customGrad returns; callers in this file invoke it with the input tensor.
 */
function customGrad(f) {
  const wrapped = ENGINE.customGrad(f);
  return wrapped;
}
/**
 * Throws if any entry of `grads2` is null or undefined.
 *
 * @param grads2 Array of gradients produced by ENGINE.gradients.
 * @throws Error when at least one gradient is missing.
 */
function checkGrads(grads2) {
  const anyMissing = grads2.some((g) => g == null);
  if (anyMissing) {
    throw new Error("Cannot compute gradient of y=f(x) with respect to x. Make sure that\n the f you passed encloses all operations that lead from x to y.");
  }
}
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Element-wise negation, dispatched to the Negate kernel (falling back to backend.neg).
 *
 * @param x Tensor or TensorLike.
 */
function neg_(x) {
  const inputTensor = convertToTensor(x, "x", "neg");
  const forwardFn = (backend2) => backend2.neg(inputTensor);
  return ENGINE.runKernelFunc(forwardFn, { x: inputTensor }, null, Negate);
}
var neg = op({neg_});
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Element-wise softplus via the Softplus kernel. The input is saved for the backward pass.
 *
 * @param x Tensor or TensorLike.
 */
function softplus_(x) {
  const input = convertToTensor(x, "x", "softplus");
  return ENGINE.runKernelFunc((backend2, save) => {
    const out = backend2.softplus(input);
    save([input]);
    return out;
  }, { x: input }, null, Softplus);
}
var softplus = op({softplus_});
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Element-wise log(sigmoid(x)), computed as -softplus(-x), with a custom gradient.
 *
 * @param x Tensor or TensorLike.
 */
function logSigmoid_(x) {
  const input = convertToTensor(x, "x", "logSigmoid");
  const wrapped = customGrad((t) => ({
    value: neg(softplus(neg(t))),
    // d/dt log(sigmoid(t)) = sigmoid(-t), scaled by the upstream gradient dy.
    gradFunc: (dy) => mul(dy, sigmoid(neg(t))),
  }));
  return wrapped(input);
}
var logSigmoid = op({logSigmoid_});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Max-reduction of `x` over `axis` via the Max kernel.
 *
 * When the requested axes are not already the innermost ones (getAxesPermutation returns a
 * permutation), the input is transposed so that the backend reduces over innermost axes.
 *
 * @param x Tensor or TensorLike.
 * @param axis Axis or axes to reduce; defaults to null (interpreted by parseAxisParam).
 * @param keepDims When true, reduced axes are kept with size 1. Defaults to false.
 */
function max_(x, axis, keepDims) {
  if (axis === undefined) {
    axis = null;
  }
  if (keepDims === undefined) {
    keepDims = false;
  }
  const input = convertToTensor(x, "x", "max");
  const forward = (backend2, save) => {
    let reduceAxes = parseAxisParam(axis, input.shape);
    const perm = getAxesPermutation(reduceAxes, input.rank);
    const needsTranspose = perm != null;
    const source = needsTranspose ? transpose(input, perm) : input;
    if (needsTranspose) {
      reduceAxes = getInnerMostAxes(reduceAxes.length, source.rank);
    }
    const reduced = backend2.max(source, reduceAxes);
    if (needsTranspose) {
      source.dispose();
    }
    let out = reduced;
    if (keepDims) {
      out = reshape(reduced, expandShapeToKeepDim(reduced.shape, parseAxisParam(axis, input.shape)));
      reduced.dispose();
    }
    save([input, out]);
    return out;
  };
  return ENGINE.runKernelFunc(forward, { x: input }, null, Max, { reductionIndices: axis, keepDims });
}
var max = op({max_});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Element-wise `a - b` via the Sub kernel, after makeTypesMatch aligns the dtypes.
 *
 * @param a Minuend (Tensor or TensorLike).
 * @param b Subtrahend (Tensor or TensorLike).
 */
function sub_(a, b) {
  const matched = makeTypesMatch(convertToTensor(a, "a", "sub"), convertToTensor(b, "b", "sub"));
  const lhs = matched[0];
  const rhs = matched[1];
  return ENGINE.runKernelFunc((backend2, save) => {
    const diff = backend2.subtract(lhs, rhs);
    save([lhs, rhs]);
    return diff;
  }, { a: lhs, b: rhs }, null, Sub);
}
var sub = op({sub_});
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Sum-reduction of `x` over `axis` via the Sum kernel.
 *
 * Boolean inputs are cast to int32 before reduction. Non-innermost reduction axes are
 * handled by transposing first.
 *
 * @param x Tensor or TensorLike.
 * @param axis Axis or axes to reduce; defaults to null (interpreted by parseAxisParam).
 * @param keepDims When true, reduced axes are kept with size 1. Defaults to false.
 */
function sum_(x, axis, keepDims) {
  if (axis === undefined) {
    axis = null;
  }
  if (keepDims === undefined) {
    keepDims = false;
  }
  const raw = convertToTensor(x, "x", "sum");
  const input = raw.dtype === "bool" ? cast(raw, "int32") : raw;
  const forward = (backend2, save) => {
    save([input]);
    const axes = parseAxisParam(axis, input.shape);
    const perm = getAxesPermutation(axes, input.rank);
    const transposed = perm != null;
    const operand = transposed ? transpose(input, perm) : input;
    const reductionAxes = transposed ? getInnerMostAxes(axes.length, input.rank) : axes;
    const summed = backend2.sum(operand, reductionAxes);
    return keepDims ? reshape(summed, expandShapeToKeepDim(summed.shape, axes)) : summed;
  };
  return ENGINE.runKernelFunc(forward, { x: input }, null, Sum, { axis, keepDims });
}
var sum$1 = op({sum_});
/**
* @license
* Copyright 2020 Google Inc. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Log-softmax along the last axis, via the LogSoftmax kernel.
 *
 * Only the last axis is supported: `axis` defaults to -1, which is mapped to `rank - 1`.
 * Any other axis throws.
 *
 * The fallback forward pass now uses the converted `$logits` tensor. It previously passed
 * the raw `logits` argument, which made `max` and `sub` convert TensorLike input again.
 *
 * @param logits Tensor or TensorLike.
 * @param axis Axis to normalize over; must resolve to the last axis.
 * @throws Error when `axis` is not the last dimension.
 */
function logSoftmax_(logits, axis) {
  if (axis === void 0) {
    axis = -1;
  }
  var $logits = convertToTensor(logits, "logits", "logSoftmax");
  if (axis === -1) {
    axis = $logits.rank - 1;
  }
  if (axis !== $logits.rank - 1) {
    throw Error("Log Softmax along a non-last dimension is not yet supported. " + ("Logits was rank " + $logits.rank + " and axis was " + axis));
  }
  var forward = function(backend2, save) {
    var keepDims = true;
    // Subtract the max along `axis` before exp() for numerical stability:
    // logSoftmax(x) = (x - max) - log(sum(exp(x - max))).
    var xMax = max($logits, axis, keepDims);
    var shifted = sub($logits, xMax);
    var value = sub(cast(shifted, "float32"), log(sum$1(exp(shifted), axis, keepDims)));
    save([value]);
    return value;
  };
  var inputs = {logits: $logits};
  var attrs = {axis};
  return ENGINE.runKernelFunc(forward, inputs, null, LogSoftmax, attrs);
}
var logSoftmax = op({logSoftmax_});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * log(sum(exp(x))) over `axis`, computed as max + log(sum(exp(x - max))).
 *
 * @param x Tensor or TensorLike.
 * @param axis Axis or axes to reduce; defaults to null (interpreted by parseAxisParam).
 * @param keepDims When true, reduced axes are kept with size 1. Defaults to false.
 */
function logSumExp_(x, axis, keepDims) {
  if (axis === undefined) {
    axis = null;
  }
  if (keepDims === undefined) {
    keepDims = false;
  }
  const input = convertToTensor(x, "x", "logSumExp");
  const axes = parseAxisParam(axis, input.shape);
  // Shifting by the max keeps exp() from blowing up on large inputs.
  const maxKeep = max(input, axes, true);
  const logOfSum = log(sum$1(exp(sub(input, maxKeep)), axes));
  const result = add$1(reshape(maxKeep, logOfSum.shape), logOfSum);
  return keepDims ? reshape(result, expandShapeToKeepDim(result.shape, axes)) : result;
}
var logSumExp = op({logSumExp_});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Element-wise logical AND of two bool tensors (shapes must be broadcast-compatible).
 *
 * @param a Tensor or TensorLike, parsed as bool.
 * @param b Tensor or TensorLike, parsed as bool.
 */
function logicalAnd_(a, b) {
  const lhs = convertToTensor(a, "a", "logicalAnd", "bool");
  const rhs = convertToTensor(b, "b", "logicalAnd", "bool");
  assertAndGetBroadcastShape(lhs.shape, rhs.shape);
  const forwardFn = (backend2) => backend2.logicalAnd(lhs, rhs);
  return ENGINE.runKernelFunc(forwardFn, { a: lhs, b: rhs }, null, LogicalAnd);
}
var logicalAnd = op({logicalAnd_});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Element-wise logical NOT of a bool tensor.
 *
 * @param x Tensor or TensorLike, parsed as bool.
 */
function logicalNot_(x) {
  const input = convertToTensor(x, "x", "logicalNot", "bool");
  const forwardFn = (backend2) => backend2.logicalNot(input);
  return ENGINE.runKernelFunc(forwardFn, { x: input }, null, LogicalNot);
}
var logicalNot = op({logicalNot_});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Element-wise logical OR of two bool tensors (shapes must be broadcast-compatible).
 *
 * @param a Tensor or TensorLike, parsed as bool.
 * @param b Tensor or TensorLike, parsed as bool.
 */
function logicalOr_(a, b) {
  const lhs = convertToTensor(a, "a", "logicalOr", "bool");
  const rhs = convertToTensor(b, "b", "logicalOr", "bool");
  assertAndGetBroadcastShape(lhs.shape, rhs.shape);
  const forwardFn = (backend2) => backend2.logicalOr(lhs, rhs);
  return ENGINE.runKernelFunc(forwardFn, { a: lhs, b: rhs }, null, LogicalOr);
}
var logicalOr = op({logicalOr_});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Element-wise logical XOR of two bool tensors, composed as (a | b) & !(a & b).
 *
 * The converted `$a`/`$b` are now passed to the sub-ops. Previously the raw arguments were
 * forwarded, so TensorLike inputs were converted again inside each of the four sub-ops.
 *
 * @param a Tensor or TensorLike, parsed as bool.
 * @param b Tensor or TensorLike, parsed as bool.
 */
function logicalXor_(a, b) {
  var $a = convertToTensor(a, "a", "logicalXor", "bool");
  var $b = convertToTensor(b, "b", "logicalXor", "bool");
  assertAndGetBroadcastShape($a.shape, $b.shape);
  return logicalAnd(logicalOr($a, $b), logicalNot(logicalAnd($a, $b)));
}
var logicalXor = op({logicalXor_});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * 2D max pooling via the MaxPool kernel.
 *
 * A rank-3 input is reshaped to rank 4 with a leading batch of 1, and the result is reshaped
 * back to rank 3. When the pooling window is 1x1 and the output shape equals the input shape,
 * the fallback forward pass returns a clone of the input.
 *
 * @param x Rank-3 or rank-4 Tensor/TensorLike.
 * @param filterSize Pooling window size.
 * @param strides Pooling strides.
 * @param pad2 Padding spec; must be an integer when `dimRoundingMode` is set.
 * @param dimRoundingMode Optional rounding mode forwarded to computePool2DInfo.
 */
function maxPool_(x, filterSize, strides, pad2, dimRoundingMode) {
  const input = convertToTensor(x, "x", "maxPool");
  const dilations = 1;
  const wasRank3 = input.rank === 3;
  const x4D = wasRank3 ? reshape(input, [1, input.shape[0], input.shape[1], input.shape[2]]) : input;
  assert(x4D.rank === 4, () => "Error in maxPool: input must be rank 4 but got rank " + x4D.rank + ".");
  assert(eitherStridesOrDilationsAreOne(strides, dilations), () => "Error in maxPool: Either strides or dilations must be 1. " + ("Got strides " + strides + " and dilations '" + dilations + "'"));
  if (dimRoundingMode != null) {
    assert(isInt(pad2), () => "Error in maxPool: pad must be an integer when using, " + ("dimRoundingMode " + dimRoundingMode + " but got pad " + pad2 + "."));
  }
  const forward = (backend2, save) => {
    const convInfo = computePool2DInfo(x4D.shape, filterSize, strides, 1, pad2, dimRoundingMode);
    const isIdentity = convInfo.filterWidth === 1 && convInfo.filterHeight === 1 && arraysEqual(convInfo.inShape, convInfo.outShape);
    const pooled = isIdentity ? x4D.clone() : backend2.maxPool(x4D, convInfo);
    save([x4D, pooled]);
    return pooled;
  };
  const attrs = { filterSize, strides, pad: pad2, dimRoundingMode };
  const res = ENGINE.runKernelFunc(forward, { x: x4D }, null, MaxPool, attrs);
  return wasRank3 ? reshape(res, [res.shape[1], res.shape[2], res.shape[3]]) : res;
}
var maxPool = op({maxPool_});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * 3D max pooling via the MaxPool3D kernel.
 *
 * A rank-4 input is reshaped to rank 5 with a leading batch of 1, and the result is reshaped
 * back to rank 4. Only the "NDHWC" data format is accepted. Passing `dilations` logs a
 * deprecation warning.
 *
 * The forward pass no longer re-defaults `dilations`; that branch was unreachable because
 * `dilations` is always set at the top of the function.
 *
 * @param x Rank-4 or rank-5 Tensor/TensorLike.
 * @param filterSize Pooling window; defaults to [1, 1, 1].
 * @param strides Pooling strides.
 * @param pad2 Padding spec; must be an integer when `dimRoundingMode` is set.
 * @param dimRoundingMode Optional rounding mode forwarded to computePool3DInfo.
 * @param dataFormat Defaults to "NDHWC"; any other value fails the assertion.
 * @param dilations Deprecated; defaults to [1, 1, 1].
 */
function maxPool3d_(x, filterSize, strides, pad2, dimRoundingMode, dataFormat, dilations) {
  if (filterSize === void 0) {
    filterSize = [1, 1, 1];
  }
  if (dataFormat === void 0) {
    dataFormat = "NDHWC";
  }
  if (dilations == null) {
    dilations = [1, 1, 1];
  } else {
    deprecationWarn("dilations is deprecated, this field will be gone in v3.0.0.");
  }
  var $x = convertToTensor(x, "x", "maxPool3d");
  var x5D = $x;
  var reshapedTo5D = false;
  if ($x.rank === 4) {
    reshapedTo5D = true;
    x5D = reshape($x, [1, $x.shape[0], $x.shape[1], $x.shape[2], $x.shape[3]]);
  }
  assert(x5D.rank === 5, function() {
    return "Error in maxPool3d: x must be rank 5 but got rank " + x5D.rank + ".";
  });
  assert(dataFormat === "NDHWC", function() {
    return "Error in maxPool3d: Only NDHWC is currently supported, " + ("but got dataFormat of " + dataFormat);
  });
  assert(eitherStridesOrDilationsAreOne(strides, dilations), function() {
    return "Error in maxPool3d: Either strides or dilations must be 1. " + ("Got strides " + strides + " and dilations '" + dilations + "'");
  });
  if (dimRoundingMode != null) {
    assert(isInt(pad2), function() {
      return "Error in maxPool3d: pad must be an integer when using, " + ("dimRoundingMode " + dimRoundingMode + " but got pad " + pad2 + ".");
    });
  }
  var forward = function(backend2, save) {
    // `dilations` is guaranteed non-null here (defaulted above).
    var convInfo = computePool3DInfo(x5D.shape, filterSize, strides, dilations, pad2, dimRoundingMode, dataFormat);
    var y = backend2.maxPool3d(x5D, convInfo);
    save([x5D, y]);
    return y;
  };
  var inputs = {x: x5D};
  var attrs = {filterSize, strides, pad: pad2, dimRoundingMode, dataFormat, dilations};
  var res = ENGINE.runKernelFunc(forward, inputs, null, MaxPool3D, attrs);
  if (reshapedTo5D) {
    return reshape(res, [res.shape[1], res.shape[2], res.shape[3], res.shape[4]]);
  }
  return res;
}
var maxPool3d = op({maxPool3d_});
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Max pooling that also returns the argmax indexes, via the MaxPoolWithArgmax kernel.
 *
 * @param x Tensor or TensorLike.
 * @param filterSize Pooling window size.
 * @param strides Pooling strides.
 * @param pad2 Padding spec.
 * @param includeBatchInIndex Forwarded as a kernel attribute; defaults to false.
 * @returns `{result, indexes}`: the kernel's first and second outputs.
 */
function maxPoolWithArgmax_(x, filterSize, strides, pad2, includeBatchInIndex) {
  if (includeBatchInIndex === undefined) {
    includeBatchInIndex = false;
  }
  const input = convertToTensor(x, "x", "maxPoolWithArgmax");
  const outputs = ENGINE.runKernel(MaxPoolWithArgmax, { x: input }, { filterSize, strides, pad: pad2, includeBatchInIndex });
  return { result: outputs[0], indexes: outputs[1] };
}
var maxPoolWithArgmax = op({maxPoolWithArgmax_});
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Creates a tensor of the given shape filled with zeros.
 *
 * @param shape Target shape.
 * @param dtype Defaults to "float32". "complex64" is built from float32 zero real and
 *     imaginary parts.
 */
function zeros(shape, dtype) {
  if (dtype === undefined) {
    dtype = "float32";
  }
  if (dtype !== "complex64") {
    return ENGINE.makeTensor(makeZerosTypedArray(sizeFromShape(shape), dtype), shape, dtype);
  }
  return complex(zeros(shape, "float32"), zeros(shape, "float32"));
}
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Creates a tensor of the given shape filled with ones.
 *
 * @param shape Target shape.
 * @param dtype Defaults to "float32". "complex64" uses a float32 real part of ones and a
 *     float32 imaginary part of zeros.
 */
function ones$1(shape, dtype) {
  if (dtype === undefined) {
    dtype = "float32";
  }
  if (dtype !== "complex64") {
    return ENGINE.makeTensor(makeOnesTypedArray(sizeFromShape(shape), dtype), shape, dtype);
  }
  return complex(ones$1(shape, "float32"), zeros(shape, "float32"));
}
/**
* @license
* Copyright 2020 Google Inc. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Mean of `x` over `axis`, via the Mean kernel, with a custom gradient.
 *
 * The fallback forward pass divides by the number of reduced elements and then sums. The
 * gradient broadcasts `dy` back to `x`'s shape and divides by that same count.
 *
 * @param x Tensor or TensorLike.
 * @param axis Axis or axes to reduce; defaults to null (interpreted by parseAxisParam).
 * @param keepDims When true, reduced axes are kept with size 1. Defaults to false.
 */
function mean_(x, axis, keepDims) {
  if (axis === undefined) {
    axis = null;
  }
  if (keepDims === undefined) {
    keepDims = false;
  }
  const input = convertToTensor(x, "x", "mean");
  const axes = parseAxisParam(axis, input.shape);
  // Element count along the reduced axes (second entry of computeOutAndReduceShapes).
  const reduceSize = sizeFromShape(computeOutAndReduceShapes(input.shape, axes)[1]);
  const inputs = { x: input };
  const attrs = { axis, keepDims };
  const forward = () => {
    const divisor = scalar(reduceSize);
    const numerator = divisor.dtype === input.dtype ? input : cast(input, divisor.dtype);
    return sum$1(div(numerator, divisor), axis, keepDims);
  };
  const wrapped = customGrad((x2) => {
    const value = ENGINE.runKernelFunc(forward, inputs, null, Mean, attrs);
    const gradFunc = (dy) => {
      // Reinsert the reduced axes as size-1 dims so dy broadcasts against x.
      const dyShape = x2.shape.slice();
      for (const reducedAxis of axes) {
        dyShape[reducedAxis] = 1;
      }
      return div(mul(reshape(dy, dyShape), ones$1(x2.shape, "float32")), reduceSize);
    };
    return { value, gradFunc };
  });
  return wrapped(input);
}
var mean = op({mean_});
/**
 * Min-reduction of `x` over `axis` via the Min kernel.
 *
 * Non-innermost reduction axes are handled by transposing the input first.
 *
 * @param x Tensor or TensorLike.
 * @param axis Axis or axes to reduce; defaults to null (interpreted by parseAxisParam).
 * @param keepDims When true, reduced axes are kept with size 1. Defaults to false.
 */
function min_(x, axis, keepDims) {
  if (axis === undefined) {
    axis = null;
  }
  if (keepDims === undefined) {
    keepDims = false;
  }
  const input = convertToTensor(x, "x", "min");
  const forward = (backend2, save) => {
    const requestedAxes = parseAxisParam(axis, input.shape);
    const perm = getAxesPermutation(requestedAxes, input.rank);
    const transposed = perm != null;
    const source = transposed ? transpose(input, perm) : input;
    const kernelAxes = transposed ? getInnerMostAxes(requestedAxes.length, input.rank) : requestedAxes;
    const reduced = backend2.min(source, kernelAxes);
    if (transposed) {
      source.dispose();
    }
    let out = reduced;
    if (keepDims) {
      out = reshape(reduced, expandShapeToKeepDim(reduced.shape, requestedAxes));
      reduced.dispose();
    }
    save([input, out]);
    return out;
  };
  return ENGINE.runKernelFunc(forward, { x: input }, null, Min, { axis, keepDims });
}
var min = op({min_});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Element-wise minimum via the Minimum kernel.
 *
 * Dtypes are aligned with makeTypesMatch. Bool operands are cast to int32. Shapes must be
 * broadcast-compatible.
 *
 * @param a Tensor or TensorLike.
 * @param b Tensor or TensorLike.
 */
function minimum_(a, b) {
  const matched = makeTypesMatch(convertToTensor(a, "a", "minimum"), convertToTensor(b, "b", "minimum"));
  const isBool = matched[0].dtype === "bool";
  const lhs = isBool ? cast(matched[0], "int32") : matched[0];
  const rhs = isBool ? cast(matched[1], "int32") : matched[1];
  assertAndGetBroadcastShape(lhs.shape, rhs.shape);
  return ENGINE.runKernelFunc((backend2, save) => {
    const out = backend2.minimum(lhs, rhs);
    save([lhs, rhs]);
    return out;
  }, { a: lhs, b: rhs }, null, Minimum);
}
var minimum = op({minimum_});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Pads a tensor by mirroring its values, via the MirrorPad kernel.
 *
 * For each dimension i, both pad amounts must lie in [0, shape[i] - offset], where offset
 * is 1 for "reflect" and 0 for "symmetric". The upper bound is inclusive. The error message
 * previously said "greater than or equal to", which contradicted the `<=` check; it now
 * states the bound correctly.
 *
 * @param x Non-scalar Tensor or TensorLike.
 * @param paddings One [before, after] pair per dimension of `x`.
 * @param mode "reflect" or "symmetric".
 * @throws Error for an invalid mode, a scalar input, or out-of-range paddings.
 */
function mirrorPad_(x, paddings, mode) {
  assert(mode === "reflect" || mode === "symmetric", function() {
    return "Invalid mode. Mode must be either reflect or symmetric. " + ("Got " + mode + ".");
  });
  var $x = convertToTensor(x, "x", "mirrorPad");
  if ($x.rank === 0) {
    throw new Error("mirrorPad(scalar) is not defined. Pass non-scalar to mirrorPad");
  }
  assert(paddings.length === $x.rank, function() {
    return "Padding doesn't match input. Must be " + $x.rank + ". " + ("Got " + paddings.length + ".");
  });
  var shapeOffset = mode === "reflect" ? 1 : 0;
  for (let i = 0; i < $x.rank; i++) {
    const pair = paddings[i];
    const maxPad = $x.shape[i] - shapeOffset;
    assert(pair.length === 2, function() {
      return "Invalid number of paddings. Must be length of 2 each.";
    });
    assert(pair[0] >= 0 && pair[0] <= maxPad && pair[1] >= 0 && pair[1] <= maxPad, function() {
      return "Padding in dimension " + i + " cannot be greater than " + maxPad + " or less than 0 for input of shape " + $x.shape;
    });
  }
  var attrs = {paddings, mode};
  var inputs = {x: $x};
  return ENGINE.runKernel(MirrorPad, inputs, attrs);
}
var mirrorPad = op({mirrorPad_});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Element-wise modulus via the Mod kernel, after makeTypesMatch aligns the dtypes.
 *
 * @param a Dividend (Tensor or TensorLike).
 * @param b Divisor (Tensor or TensorLike).
 */
function mod_(a, b) {
  const matched = makeTypesMatch(convertToTensor(a, "a", "mod"), convertToTensor(b, "b", "mod"));
  const lhs = matched[0];
  const rhs = matched[1];
  return ENGINE.runKernelFunc((backend2, save) => {
    const remainder = backend2.mod(lhs, rhs);
    save([lhs, rhs]);
    return remainder;
  }, { a: lhs, b: rhs }, null, Mod);
}
var mod = op({mod_});
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Squares `x` via the backend's `square`, registered under the "Square"
 * kernel name. The converted input is both passed to `save` and listed as the
 * single input to save; no outputs are saved.
 *
 * @param x Tensor or tensor-like input.
 * @returns The tensor produced by the kernel.
 */
function square_(x) {
  var $x = convertToTensor(x, "x", "square");
  var forward = function(backend2, save) {
    save([$x]);
    return backend2.square($x);
  };
  return ENGINE.runKernelFunc(forward, {x: $x}, null, "Square", {}, [$x], []);
}
var square = op({square_});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Computes the mean and variance of `x` over `axis`.
 *
 * @param x Tensor or tensor-like input.
 * @param axis Axis or axes to reduce; defaults to null (resolved by
 *     parseAxisParam).
 * @param keepDims Whether reduced dimensions are kept; defaults to false.
 * @returns {{mean, variance}}. The variance is the mean of the squared
 *     deviation of float32-cast `x` from the mean.
 */
function moments_(x, axis, keepDims) {
  var reduceAxis = axis === void 0 ? null : axis;
  var keep = keepDims === void 0 ? false : keepDims;
  var $x = convertToTensor(x, "x", "moments");
  var axes = parseAxisParam(reduceAxis, $x.shape);
  var xMean = mean($x, axes, keep);
  // When dims were dropped, re-expand the mean's shape so it lines up with x.
  var alignedShape = keep ? xMean.shape : expandShapeToKeepDim(xMean.shape, axes);
  var deviation = sub(cast($x, "float32"), reshape(xMean, alignedShape));
  return {mean: xMean, variance: mean(square(deviation), axes, keep)};
}
var moments = op({moments_});
/**
 * Runs a stack of LSTM cell functions. Layer k receives the previous layer's
 * second output (the first layer receives `data`) together with c[k] and h[k].
 *
 * @param lstmCells Array of cell functions (input, c, h) -> [newC, newH].
 * @param data Input to the first cell.
 * @param c Per-layer first-state inputs.
 * @param h Per-layer second-state inputs.
 * @returns [newC, newH]: arrays holding each cell's first and second outputs.
 */
function multiRNNCell_(lstmCells, data, c, h) {
  var $data = convertToTensor(data, "data", "multiRNNCell");
  var $c = convertToTensorArray(c, "c", "multiRNNCell");
  var $h = convertToTensorArray(h, "h", "multiRNNCell");
  var newC = [];
  var newH = [];
  var layerInput = $data;
  for (var layer = 0; layer < lstmCells.length; layer++) {
    var cellOut = lstmCells[layer](layerInput, $c[layer], $h[layer]);
    newC.push(cellOut[0]);
    newH.push(cellOut[1]);
    // The second output feeds the next layer.
    layerInput = cellOut[1];
  }
  return [newC, newH];
}
var multiRNNCell = op({multiRNNCell_});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Samples from the distribution described by `logits` using the backend's
 * `multinomial`.
 *
 * @param logits Tensor or tensor-like of rank 1 or 2 with at least 2 elements.
 *     A rank-1 input is reshaped to a single row before the kernel runs.
 * @param numSamples Forwarded to the backend as the number of samples.
 * @param seed Optional seed. A random seed is generated only when this is
 *     null/undefined; any supplied value, including 0, is forwarded unchanged.
 * @param normalized Forwarded to the backend; defaults to false.
 * @returns The backend result. It is flattened to 1-D when `logits` was rank 1.
 * @throws Error when `logits` has fewer than 2 elements or rank above 2.
 */
function multinomial_(logits, numSamples, seed, normalized) {
  if (normalized === void 0) {
    normalized = false;
  }
  var $logits = convertToTensor(logits, "logits", "multinomial");
  var numOutcomes = $logits.size;
  var origRank = $logits.rank;
  if (numOutcomes < 2) {
    throw new Error("Error in multinomial: you need at least 2 outcomes, but got " + (numOutcomes + "."));
  }
  if (origRank > 2) {
    throw new Error("Rank of probabilities must be 1 or 2, but is " + origRank);
  }
  // Only fall back to a random seed when none was given. A truthiness check
  // (`seed || Math.random()`) would discard an explicit seed of 0 and make
  // that call non-reproducible.
  if (seed == null) {
    seed = Math.random();
  }
  // The backend is always given a 2-D tensor; promote a single 1-D row.
  var logits2D = origRank === 1 ? reshape($logits, [1, -1]) : $logits;
  var res = ENGINE.runKernelFunc(function(backend2) {
    return backend2.multinomial(logits2D, normalized, numSamples, seed);
  }, {logits2D});
  return origRank === 1 ? reshape(res, [res.size]) : res;
}
var multinomial = op({multinomial_});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Compares `a` and `b` with the backend's `notEqual` (the `NotEqual` kernel).
 * The operands get matching dtypes, and their shapes must be broadcast-compatible.
 *
 * @param a First operand (tensor or tensor-like).
 * @param b Second operand (tensor or tensor-like).
 * @returns The tensor produced by the kernel.
 */
function notEqual_(a, b) {
  var matched = makeTypesMatch(convertToTensor(a, "a", "notEqual"), convertToTensor(b, "b", "notEqual"));
  var $a = matched[0];
  var $b = matched[1];
  // Called for its shape validation; the returned shape is unused here.
  assertAndGetBroadcastShape($a.shape, $b.shape);
  return ENGINE.runKernelFunc(function(backend2) {
    return backend2.notEqual($a, $b);
  }, {a: $a, b: $b}, null, NotEqual);
}
var notEqual = op({notEqual_});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Returns the backend's `real` of `input`, registered as the `Real` kernel.
 *
 * @param input Tensor or tensor-like input.
 * @returns The tensor produced by the kernel.
 */
function real_(input) {
  var $input = convertToTensor(input, "input", "real");
  return ENGINE.runKernelFunc(function(backend2) {
    return backend2.real($input);
  }, {input: $input}, null, Real);
}
var real = op({real_});
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Produces a ones-filled tensor like `x` (the `OnesLike` kernel).
 * complex64 inputs are handled specially: the result is built with
 * `complex(onesLike(real(x)), zerosLike(imag(x)))`.
 *
 * @param x Tensor or tensor-like input.
 * @returns The resulting tensor.
 */
function onesLike_(x) {
  var $x = convertToTensor(x, "x", "onesLike");
  return ENGINE.runKernelFunc(function(backend2) {
    if ($x.dtype !== "complex64") {
      return backend2.onesLike($x);
    }
    return complex(onesLike(real($x)), zerosLike(imag($x)));
  }, {x: $x}, null, OnesLike);
}
var onesLike = op({onesLike_});
/**
 * Outer product of two rank-1 tensors: `v1` is reshaped to [-1, 1] and `v2`
 * to [1, -1], and the two are matrix-multiplied.
 *
 * @param v1 Rank-1 tensor or tensor-like.
 * @param v2 Rank-1 tensor or tensor-like.
 * @returns matMul of the reshaped operands.
 */
function outerProduct_(v1, v2) {
  var $v1 = convertToTensor(v1, "v1", "outerProduct");
  var $v2 = convertToTensor(v2, "v2", "outerProduct");
  var bothRankOne = $v1.rank === 1 && $v2.rank === 1;
  assert(bothRankOne, function() {
    return "Error in outerProduct: inputs must be rank 1, but got ranks " + ($v1.rank + " and " + $v2.rank + ".");
  });
  var column = reshape($v1, [-1, 1]);
  var row = reshape($v2, [1, -1]);
  return matMul(column, row);
}
var outerProduct = op({outerProduct_});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Pads `x` with `constantValue` using the backend's `pad` (the `PadV2` kernel).
 *
 * @param x Non-scalar tensor or tensor-like input.
 * @param paddings Per-dimension padding amounts, forwarded to the backend.
 * @param constantValue Fill value; defaults to 0.
 * @returns The padded tensor.
 * @throws Error if `x` is a scalar.
 */
function pad_(x, paddings, constantValue) {
  var fill = constantValue === void 0 ? 0 : constantValue;
  var $x = convertToTensor(x, "x", "pad");
  if ($x.rank === 0) {
    throw new Error("pad(scalar) is not defined. Pass non-scalar to pad");
  }
  return ENGINE.runKernelFunc(function(backend2, save) {
    save([$x]);
    return backend2.pad($x, paddings, fill);
  }, {x: $x}, null, PadV2, {paddings: paddings, constantValue: fill});
}
var pad = op({pad_});
/**
 * 1-D convenience wrapper around `pad`: `paddings` is a single [before, after]
 * pair.
 *
 * @param x Tensor or tensor-like input.
 * @param paddings Array of length 2.
 * @param constantValue Fill value; defaults to 0.
 * @returns The result of `pad(x, [paddings], constantValue)`.
 */
function pad1d_(x, paddings, constantValue) {
  var fill = constantValue === void 0 ? 0 : constantValue;
  var isPair = paddings.length === 2;
  assert(isPair, function() {
    return "Invalid number of paddings. Must be length of 2.";
  });
  return pad(x, [paddings], fill);
}
var pad1d = op({pad1d_});
/**
 * 2-D convenience wrapper around `pad`, which validates that `paddings` holds
 * two pairs.
 *
 * @param x Tensor or tensor-like input.
 * @param paddings Two arrays of length 2.
 * @param constantValue Fill value; defaults to 0.
 * @returns The result of `pad(x, paddings, constantValue)`.
 */
function pad2d_(x, paddings, constantValue) {
  var fill = constantValue === void 0 ? 0 : constantValue;
  var wellFormed = paddings.length === 2 &&
      paddings[0].length === 2 &&
      paddings[1].length === 2;
  assert(wellFormed, function() {
    return "Invalid number of paddings. Must be length of 2 each.";
  });
  return pad(x, paddings, fill);
}
var pad2d = op({pad2d_});
/**
 * 3-D convenience wrapper around `pad`, which validates that `paddings` holds
 * three pairs.
 *
 * @param x Tensor or tensor-like input.
 * @param paddings Three arrays of length 2.
 * @param constantValue Fill value; defaults to 0.
 * @returns The result of `pad(x, paddings, constantValue)`.
 */
function pad3d_(x, paddings, constantValue) {
  var fill = constantValue === void 0 ? 0 : constantValue;
  var wellFormed = paddings.length === 3 &&
      paddings[0].length === 2 &&
      paddings[1].length === 2 &&
      paddings[2].length === 2;
  assert(wellFormed, function() {
    return "Invalid number of paddings. Must be length of 2 each.";
  });
  return pad(x, paddings, fill);
}
var pad3d = op({pad3d_});
/**
 * 4-D convenience wrapper around `pad`, which validates that `paddings` holds
 * four pairs.
 *
 * @param x Tensor or tensor-like input.
 * @param paddings Four arrays of length 2.
 * @param constantValue Fill value; defaults to 0.
 * @returns The result of `pad(x, paddings, constantValue)`.
 */
function pad4d_(x, paddings, constantValue) {
  var fill = constantValue === void 0 ? 0 : constantValue;
  var wellFormed = paddings.length === 4 &&
      paddings[0].length === 2 &&
      paddings[1].length === 2 &&
      paddings[2].length === 2 &&
      paddings[3].length === 2;
  assert(wellFormed, function() {
    return "Invalid number of paddings. Must be length of 2 each.";
  });
  return pad(x, paddings, fill);
}
var pad4d = op({pad4d_});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Validates its arguments, then runs the backend's `spaceToBatchND`
 * (the `SpaceToBatchND` kernel).
 *
 * Checks performed:
 * - rank(x) >= 1 + blockShape.length;
 * - paddings.length === blockShape.length;
 * - for each block dim d, shape[d+1] + paddings[d][0] + paddings[d][1] is
 *   divisible by blockShape[d].
 *
 * @param x Tensor or tensor-like input.
 * @param blockShape Block sizes for the dimensions after the first.
 * @param paddings [start, end] padding per block dimension.
 * @returns The tensor produced by the kernel.
 */
function spaceToBatchND_(x, blockShape, paddings) {
  var $x = convertToTensor(x, "x", "spaceToBatchND");
  var blockDims = blockShape.length;
  assert($x.rank >= 1 + blockDims, function() {
    return "input rank " + $x.rank + " should be > than [blockShape] " + blockShape.length;
  });
  assert(paddings.length === blockDims, function() {
    return "paddings.shape[0] " + paddings.length + " must be equal to [blockShape] " + blockShape.length;
  });
  // Stop checking as soon as one dimension fails.
  var divisible = true;
  for (var dim = 1; divisible && dim <= blockDims && dim < $x.shape.length; dim++) {
    var padded = $x.shape[dim] + paddings[dim - 1][0] + paddings[dim - 1][1];
    divisible = padded % blockShape[dim - 1] === 0;
  }
  assert(divisible, function() {
    return "input spatial dimensions " + $x.shape.slice(1) + " with paddings " + paddings.toString() + " must be divisible by blockShapes " + blockShape.toString();
  });
  return ENGINE.runKernelFunc(function(backend2) {
    return backend2.spaceToBatchND($x, blockShape, paddings);
  }, {x: $x}, null, SpaceToBatchND, {blockShape: blockShape, paddings: paddings});
}
var spaceToBatchND = op({spaceToBatchND_});
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Average or max pooling with optional dilation. When either dilation is
 * not 1, the input goes through spaceToBatchND, is pooled with "valid"
 * padding, and the result is mapped back with batchToSpaceND.
 *
 * @param input Rank-3 or rank-4 tensor or tensor-like. Rank 3 is temporarily
 *     given a leading dimension of 1.
 * @param windowShape Pooling window, forwarded to avgPool/maxPool.
 * @param poolingType "avg" selects avgPool; anything else selects maxPool.
 * @param pad2 Padding mode. A numeric 0 is treated as "valid".
 * @param dilations Defaults to [1, 1].
 * @param strides Defaults to 1.
 * @returns Pooled tensor, with the same rank as `input`.
 */
function pool_(input, windowShape, poolingType, pad2, dilations, strides) {
  if (dilations == null) {
    dilations = [1, 1];
  }
  if (strides == null) {
    strides = 1;
  }
  if (pad2 === 0) {
    pad2 = "valid";
  }
  var $x = convertToTensor(input, "x", "maxPool");
  var reshapedTo4D = $x.rank === 3;
  var x4D = reshapedTo4D ? reshape($x, [1, $x.shape[0], $x.shape[1], $x.shape[2]]) : $x;
  assert(eitherStridesOrDilationsAreOne(strides, dilations), function() {
    return "Error in pool: Either strides or dilations must be 1. " + ("Got strides " + strides + " and dilations '" + dilations + "'");
  });
  var convInfo = computePool2DInfo(x4D.shape, windowShape, strides, dilations, pad2);
  var dilation = [convInfo.dilationHeight, convInfo.dilationWidth];
  var basePadding = pad2 === "same"
    ? withSpaceToBatchBasePaddings([convInfo.filterHeight, convInfo.filterWidth], dilation)
    : [[0, 0], [0, 0]];
  var isDilationOne = dilation[0] === 1 && dilation[1] === 1;
  var spaceToBatch = requiredSpaceToBatchPaddings([convInfo.inHeight, convInfo.inWidth], dilation, basePadding);
  var adjustedPadding = spaceToBatch[0];
  var adjustedCrops = spaceToBatch[1];
  // The dilated path already applied padding via spaceToBatchND.
  var convertedPad = isDilationOne ? pad2 : "valid";
  var convertedX = isDilationOne ? x4D : spaceToBatchND(x4D, dilation, adjustedPadding);
  var y = poolingType === "avg"
    ? avgPool(convertedX, windowShape, strides, convertedPad)
    : maxPool(convertedX, windowShape, strides, convertedPad);
  var res = isDilationOne ? y : batchToSpaceND(y, dilation, adjustedCrops);
  return reshapedTo4D ? reshape(res, [res.shape[1], res.shape[2], res.shape[3]]) : res;
}
/**
 * Computes the paddings and crops needed to run a dilated op through
 * spaceToBatchND / batchToSpaceND. This mirrors TensorFlow's
 * `array_ops.required_space_to_batch_paddings`.
 *
 * @param inputShape Spatial input sizes, e.g. [inHeight, inWidth].
 * @param blockShape Block size per spatial dimension.
 * @param basePadding Per-dimension [start, end] padding to apply first.
 * @returns [paddings, crops]:
 *     - paddings[i] = [start, end + extra], chosen so that
 *       inputShape[i] + start + end + extra is a multiple of blockShape[i];
 *     - crops[i] = [0, extra].
 */
function requiredSpaceToBatchPaddings(inputShape, blockShape, basePadding) {
  var padStart = basePadding.map(function(b) {
    return b[0];
  });
  var origPadEnd = basePadding.map(function(b) {
    return b[1];
  });
  // Element-wise total size, as in TF. The base padding has to be counted
  // when rounding each dimension up to a multiple of its block size.
  var fullInputShape = inputShape.map(function(s, i) {
    return s + padStart[i] + origPadEnd[i];
  });
  // Extra end padding that rounds each total up to a multiple of the block size.
  var padEndExtra = blockShape.map(function(b, i) {
    return (b - fullInputShape[i] % b) % b;
  });
  var padEnd = origPadEnd.map(function(s, i) {
    return s + padEndExtra[i];
  });
  var paddings = blockShape.map(function(_, i) {
    return [padStart[i], padEnd[i]];
  });
  var crops = blockShape.map(function(_, i) {
    return [0, padEndExtra[i]];
  });
  return [paddings, crops];
}
/**
 * Base "same"-style padding for a dilated filter. For each dimension:
 * - effective filter size e = s + (s - 1) * (dilation - 1);
 * - total padding is e - 1;
 * - start gets floor(total / 2) and end gets the remainder.
 *
 * @param filterShape Filter size per dimension.
 * @param dilation Dilation rate per dimension.
 * @returns Array of [start, end] pairs, one per filter dimension.
 */
function withSpaceToBatchBasePaddings(filterShape, dilation) {
  return filterShape.map(function(size, dim) {
    var effectiveSize = size + (size - 1) * (dilation[dim] - 1);
    var total = effectiveSize - 1;
    var before = Math.floor(total / 2);
    return [before, total - before];
  });
}
// Exported `pool` op: pool_ (defined above) passed through this file's op()
// helper.
var pool = op({pool_});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Raises `base` to `exp2` with the backend's `pow` (the `Pow` kernel). The
 * operands get matching dtypes, and base, exponent and result are all saved.
 *
 * @param base Base (tensor or tensor-like).
 * @param exp2 Exponent (tensor or tensor-like).
 * @returns The tensor produced by the kernel.
 */
function pow_(base, exp2) {
  var matched = makeTypesMatch(convertToTensor(base, "base", "pow"), convertToTensor(exp2, "exp", "pow"));
  var $base = matched[0];
  var $exp = matched[1];
  function forward(backend2, save) {
    var result = backend2.pow($base, $exp);
    save([$base, $exp, result]);
    return result;
  }
  return ENGINE.runKernelFunc(forward, {a: $base, b: $exp}, null, Pow);
}
var pow = op({pow_});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Runs the backend's `prelu` on `x` with `alpha` (the `Prelu` kernel), saving
 * both inputs.
 *
 * @param x Input (tensor or tensor-like).
 * @param alpha Alpha (tensor or tensor-like).
 * @returns The tensor produced by the kernel.
 */
function prelu_(x, alpha) {
  var $x = convertToTensor(x, "x", "prelu");
  var $alpha = convertToTensor(alpha, "alpha", "prelu");
  function forward(backend2, save) {
    var result = backend2.prelu($x, $alpha);
    save([$x, $alpha]);
    return result;
  }
  return ENGINE.runKernelFunc(forward, {x: $x, alpha: $alpha}, null, Prelu);
}
var prelu = op({prelu_});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Product-reduces `x` over `axis` (the `Prod` kernel). Bool inputs are cast to
 * int32 first. If the reduced axes need a permutation (per
 * getAxesPermutation), `x` is transposed and the innermost axes are reduced.
 *
 * @param x Tensor or tensor-like input.
 * @param axis Axes to reduce; defaults to null (resolved by parseAxisParam).
 * @param keepDims When truthy, reduced dims are re-inserted with
 *     expandShapeToKeepDim. Defaults to false.
 * @returns The reduced tensor.
 */
function prod_(x, axis, keepDims) {
  var reduceAxis = axis === void 0 ? null : axis;
  var keep = keepDims === void 0 ? false : keepDims;
  var converted = convertToTensor(x, "x", "prod");
  var $x = converted.dtype === "bool" ? cast(converted, "int32") : converted;
  function forward(backend2) {
    var axes = parseAxisParam(reduceAxis, $x.shape);
    var permutation = getAxesPermutation(axes, $x.rank);
    var needsTranspose = permutation != null;
    var permutedX = needsTranspose ? transpose($x, permutation) : $x;
    var reductionAxes = needsTranspose ? getInnerMostAxes(axes.length, $x.rank) : axes;
    var value = backend2.prod(permutedX, reductionAxes);
    return keep ? reshape(value, expandShapeToKeepDim(value.shape, axes)) : value;
  }
  return ENGINE.runKernelFunc(forward, {x: $x}, null, Prod, {axis: reduceAxis, keepDims: keep});
}
var prod = op({prod_});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Builds a tensor of `shape` by calling `randFunction()` once per element.
 *
 * @param shape Target shape.
 * @param randFunction Zero-argument value generator.
 * @param dtype "float32" (also used when null/undefined), "int32" or "bool".
 * @returns A tensor created with ENGINE.makeTensor(values, shape, dtype).
 * @throws Error for any other dtype.
 */
function rand_(shape, randFunction, dtype) {
  var size = sizeFromShape(shape);
  var values;
  switch (dtype == null ? "float32" : dtype) {
    case "float32":
      values = new Float32Array(size);
      break;
    case "int32":
      values = new Int32Array(size);
      break;
    case "bool":
      values = new Uint8Array(size);
      break;
    default:
      throw new Error("Unknown data type " + dtype);
  }
  for (var idx = 0; idx < size; idx++) {
    values[idx] = randFunction();
  }
  return ENGINE.makeTensor(values, shape, dtype);
}
var rand = op({rand_});
// Global object for the vendored CommonJS modules below. The first defined
// candidate among globalThis, window, global and self wins; otherwise an empty
// object is used.
var commonjsGlobal = (function() {
  if (typeof globalThis !== "undefined") {
    return globalThis;
  }
  if (typeof window !== "undefined") {
    return window;
  }
  if (typeof global !== "undefined") {
    return global;
  }
  if (typeof self !== "undefined") {
    return self;
  }
  return {};
})();
/**
 * Runs a CommonJS-style factory against a fresh `{exports: {}}` module object
 * and returns whatever ends up in `module.exports`.
 *
 * @param fn Factory invoked as fn(module, module.exports).
 * @param module2 Ignored; a new module object is always created.
 * @returns The module's final `exports` value.
 */
function createCommonjsModule(fn, module2) {
  module2 = {exports: {}};
  fn(module2, module2.exports);
  return module2.exports;
}
// Vendored `alea` PRNG from the seedrandom package; it is attached below as
// `seedrandom.alea` and aliased as `seedrandom_1`. createCommonjsModule
// supplies module2, and the UMD tail stores `impl` in module2.exports.
var alea = createCommonjsModule(function(module2) {
(function(global2, module3, define2) {
// Generator state: three fractional words s0..s2 plus an integer carry c.
function Alea(seed) {
var me = this, mash = Mash();
// One step: t combines s0 with the carry scaled by 2^-32. The integer part
// of t becomes the new carry, and the fractional part is stored in s2 and
// returned.
me.next = function() {
var t = 2091639 * me.s0 + me.c * 23283064365386963e-26;
me.s0 = me.s1;
me.s1 = me.s2;
return me.s2 = t - (me.c = t | 0);
};
// Seeding: each word starts as mash(" ") and has mash(seed) subtracted. A
// negative result is wrapped by adding 1, which keeps each word in [0, 1).
me.c = 1;
me.s0 = mash(" ");
me.s1 = mash(" ");
me.s2 = mash(" ");
me.s0 -= mash(seed);
if (me.s0 < 0) {
me.s0 += 1;
}
me.s1 -= mash(seed);
if (me.s1 < 0) {
me.s1 += 1;
}
me.s2 -= mash(seed);
if (me.s2 < 0) {
me.s2 += 1;
}
// The hasher is only needed while seeding.
mash = null;
}
// Copy the generator state fields from f to t; returns t.
function copy(f, t) {
t.c = f.c;
t.s0 = f.s0;
t.s1 = f.s1;
t.s2 = f.s2;
return t;
}
// Factory: impl(seed, opts) returns the generator's `next` function
// decorated with extra methods:
// - .int32: next() * 2^32 truncated to a signed 32-bit int;
// - .double: one draw plus a second 21-bit draw scaled by 2^-53;
// - .quick: the plain call;
// - .state(): present only when opts.state is truthy; returns a snapshot copy.
function impl(seed, opts) {
var xg = new Alea(seed), state = opts && opts.state, prng = xg.next;
prng.int32 = function() {
return xg.next() * 4294967296 | 0;
};
prng.double = function() {
return prng() + (prng() * 2097152 | 0) * 11102230246251565e-32;
};
prng.quick = prng;
if (state) {
// If opts.state is an object, it is treated as a saved snapshot and restored.
if (typeof state == "object")
copy(state, xg);
prng.state = function() {
return copy(xg, {});
};
}
return prng;
}
// Seeding hash. `n` lives in the closure, so each result depends on every
// earlier call to the same hasher. Each call returns (n >>> 0) * 2^-32,
// which lies in [0, 1). data.toString() throws for null/undefined input.
function Mash() {
var n = 4022871197;
var mash = function(data) {
data = data.toString();
for (var i = 0; i < data.length; i++) {
n += data.charCodeAt(i);
var h = 0.02519603282416938 * n;
n = h >>> 0;
h -= n;
h *= n;
n = h >>> 0;
h -= n;
n += h * 4294967296;
}
return (n >>> 0) * 23283064365386963e-26;
};
return mash;
}
// UMD export. It is invoked below with the CommonJS-style module2 and
// define2 = false, so the first branch runs.
if (module3 && module3.exports) {
module3.exports = impl;
} else if (define2 && define2.amd) {
define2(function() {
return impl;
});
} else {
this.alea = impl;
}
})(commonjsGlobal, module2, false);
});
// Vendored `xor128` PRNG from the seedrandom package; it is attached below as
// `seedrandom.xor128`. State is four 32-bit words x, y, z, w.
var xor128 = createCommonjsModule(function(module2) {
(function(global2, module3, define2) {
function XorGen(seed) {
var me = this, strseed = "";
me.x = 0;
me.y = 0;
me.z = 0;
me.w = 0;
// One step: shift the words down and return the new w, which is a signed
// 32-bit value because it is produced by bitwise ops.
me.next = function() {
var t = me.x ^ me.x << 11;
me.x = me.y;
me.y = me.z;
me.z = me.w;
return me.w ^= me.w >>> 19 ^ t ^ t >>> 8;
};
// An int32 seed initializes x directly. Any other seed is stringified and
// its char codes are mixed in below.
if (seed === (seed | 0)) {
me.x = seed;
} else {
strseed += seed;
}
// XOR each char code into x and step once per char, then take 64 further
// steps. Past the end of strseed, charCodeAt gives NaN and `| 0` maps it to 0.
for (var k = 0; k < strseed.length + 64; k++) {
me.x ^= strseed.charCodeAt(k) | 0;
me.next();
}
}
// Copy the generator state from f to t; returns t.
function copy(f, t) {
t.x = f.x;
t.y = f.y;
t.z = f.z;
t.w = f.w;
return t;
}
// Factory: impl(seed, opts) returns prng(), which gives the unsigned next()
// divided by 2^32, i.e. a value in [0, 1). Extra methods:
// - .double: 53 bits built from two draws; retried while the result is 0;
// - .int32: the raw signed next();
// - .quick: the plain call;
// - .state(): present only when opts.state is truthy.
function impl(seed, opts) {
var xg = new XorGen(seed), state = opts && opts.state, prng = function() {
return (xg.next() >>> 0) / 4294967296;
};
prng.double = function() {
do {
var top = xg.next() >>> 11, bot = (xg.next() >>> 0) / 4294967296, result = (top + bot) / (1 << 21);
} while (result === 0);
return result;
};
prng.int32 = xg.next;
prng.quick = prng;
if (state) {
// If opts.state is an object, it is treated as a saved snapshot and restored.
if (typeof state == "object")
copy(state, xg);
prng.state = function() {
return copy(xg, {});
};
}
return prng;
}
// UMD export. It is invoked below with module2 and define2 = false, so the
// first branch runs.
if (module3 && module3.exports) {
module3.exports = impl;
} else if (define2 && define2.amd) {
define2(function() {
return impl;
});
} else {
this.xor128 = impl;
}
})(commonjsGlobal, module2, false);
});
// Vendored `xorwow` PRNG from the seedrandom package; it is attached below as
// `seedrandom.xorwow`. State is five xor-shift words x..v plus a counter d
// that grows by 362437 each step.
var xorwow = createCommonjsModule(function(module2) {
(function(global2, module3, define2) {
function XorGen(seed) {
var me = this, strseed = "";
// One step: shift the words down and update v. The return value is the sum
// of the advanced counter d and the new v, truncated to a signed 32-bit int.
me.next = function() {
var t = me.x ^ me.x >>> 2;
me.x = me.y;
me.y = me.z;
me.z = me.w;
me.w = me.v;
return (me.d = me.d + 362437 | 0) + (me.v = me.v ^ me.v << 4 ^ (t ^ t << 1)) | 0;
};
me.x = 0;
me.y = 0;
me.z = 0;
me.w = 0;
me.v = 0;
// An int32 seed initializes x directly; any other seed is stringified.
if (seed === (seed | 0)) {
me.x = seed;
} else {
strseed += seed;
}
// Mix the char codes into x and take 64 extra steps. Counter d is assigned
// only when k reaches strseed.length. If an earlier step runs first (string
// seeds), it computes `undefined + 362437 | 0`, which is 0.
for (var k = 0; k < strseed.length + 64; k++) {
me.x ^= strseed.charCodeAt(k) | 0;
if (k == strseed.length) {
me.d = me.x << 10 ^ me.x >>> 4;
}
me.next();
}
}
// Copy the generator state from f to t; returns t.
function copy(f, t) {
t.x = f.x;
t.y = f.y;
t.z = f.z;
t.w = f.w;
t.v = f.v;
t.d = f.d;
return t;
}
// Factory: impl(seed, opts) returns prng(), a value in [0, 1) from
// (next() >>> 0) / 2^32. Extra methods:
// - .double: 53 bits from two draws; retried while the result is 0;
// - .int32: the raw next();
// - .quick: the plain call;
// - .state(): present only when opts.state is truthy.
function impl(seed, opts) {
var xg = new XorGen(seed), state = opts && opts.state, prng = function() {
return (xg.next() >>> 0) / 4294967296;
};
prng.double = function() {
do {
var top = xg.next() >>> 11, bot = (xg.next() >>> 0) / 4294967296, result = (top + bot) / (1 << 21);
} while (result === 0);
return result;
};
prng.int32 = xg.next;
prng.quick = prng;
if (state) {
// If opts.state is an object, it is treated as a saved snapshot and restored.
if (typeof state == "object")
copy(state, xg);
prng.state = function() {
return copy(xg, {});
};
}
return prng;
}
// UMD export. It is invoked below with module2 and define2 = false, so the
// first branch runs.
if (module3 && module3.exports) {
module3.exports = impl;
} else if (define2 && define2.amd) {
define2(function() {
return impl;
});
} else {
this.xorwow = impl;
}
})(commonjsGlobal, module2, false);
});
// Vendored `xorshift7` PRNG from the seedrandom package; it is attached below
// as `seedrandom.xorshift7`. State is an 8-word array x plus a rotating index i.
var xorshift7 = createCommonjsModule(function(module2) {
(function(global2, module3, define2) {
function XorGen(seed) {
var me = this;
// One step: combine shifted copies of X[i], X[i+1], X[i+3], X[i+4] and
// X[i+7] (indices mod 8). The result is stored in X[i] and returned, and i
// advances by one mod 8.
me.next = function() {
var X = me.x, i = me.i, t, v;
t = X[i];
t ^= t >>> 7;
v = t ^ t << 24;
t = X[i + 1 & 7];
v ^= t ^ t >>> 10;
t = X[i + 3 & 7];
v ^= t ^ t >>> 3;
t = X[i + 4 & 7];
v ^= t ^ t << 7;
t = X[i + 7 & 7];
t = t ^ t << 13;
v ^= t ^ t << 9;
X[i] = v;
me.i = i + 1 & 7;
return v;
};
// Seeding. An int32 seed goes into X[0]; otherwise the stringified seed's
// char codes are folded into X.
// NOTE(review): on the first pass, X[j + 1 & 7] has not been written yet, so
// `charCode + undefined` is NaN and `<< 13` turns it into 0. As a result,
// string seeds shorter than 8 characters all leave X as zeros and end up in
// the same X[7] = -1 state. Changing this would alter every seeded
// sequence, so it is left as-is.
function init(me2, seed2) {
var j, w, X = [];
if (seed2 === (seed2 | 0)) {
w = X[0] = seed2;
} else {
seed2 = "" + seed2;
for (j = 0; j < seed2.length; ++j) {
X[j & 7] = X[j & 7] << 15 ^ seed2.charCodeAt(j) + X[j + 1 & 7] << 13;
}
}
// Pad to 8 words. If every word is 0, force X[7] to -1 so the state is not
// all zero. `w` is computed here but not used afterwards.
while (X.length < 8)
X.push(0);
for (j = 0; j < 8 && X[j] === 0; ++j)
;
if (j == 8)
w = X[7] = -1;
else
w = X[j];
me2.x = X;
me2.i = 0;
// Take 256 steps before the generator is handed out.
for (j = 256; j > 0; --j) {
me2.next();
}
}
init(me, seed);
}
// Copy the generator state from f to t (the array is cloned); returns t.
function copy(f, t) {
t.x = f.x.slice();
t.i = f.i;
return t;
}
// Factory. A null/undefined seed falls back to the current time in ms.
// prng() returns (next() >>> 0) / 2^32, a value in [0, 1). Extra methods:
// - .double: 53 bits from two draws; retried while the result is 0;
// - .int32: the raw next();
// - .quick: the plain call;
// - .state(): present only when opts.state is truthy.
function impl(seed, opts) {
if (seed == null)
seed = +new Date();
var xg = new XorGen(seed), state = opts && opts.state, prng = function() {
return (xg.next() >>> 0) / 4294967296;
};
prng.double = function() {
do {
var top = xg.next() >>> 11, bot = (xg.next() >>> 0) / 4294967296, result = (top + bot) / (1 << 21);
} while (result === 0);
return result;
};
prng.int32 = xg.next;
prng.quick = prng;
if (state) {
// A snapshot is restored only if it carries an `x` array.
if (state.x)
copy(state, xg);
prng.state = function() {
return copy(xg, {});
};
}
return prng;
}
// UMD export. It is invoked below with module2 and define2 = false, so the
// first branch runs.
if (module3 && module3.exports) {
module3.exports = impl;
} else if (define2 && define2.amd) {
define2(function() {
return impl;
});
} else {
this.xorshift7 = impl;
}
})(commonjsGlobal, module2, false);
});
// Vendored `xor4096` PRNG from the seedrandom package; it is attached below as
// `seedrandom.xor4096`. State is a 128-word array X, an index i, and a
// counter w that grows by 1640531527 each step.
var xor4096 = createCommonjsModule(function(module2) {
(function(global2, module3, define2) {
function XorGen(seed) {
var me = this;
// One step: advance w, then xor-shift-combine X[i+34] with X[i+1] (indices
// mod 128) into X[i+1]. The return value is that word plus w ^ (w >>> 16),
// truncated to a signed 32-bit int.
me.next = function() {
var w = me.w, X = me.X, i = me.i, t, v;
me.w = w = w + 1640531527 | 0;
v = X[i + 34 & 127];
t = X[i = i + 1 & 127];
v ^= v << 13;
t ^= t << 17;
v ^= v >>> 15;
t ^= t >>> 12;
v = X[i] = v ^ t;
me.i = i;
return v + (w ^ w >>> 16) | 0;
};
// Seeding. An int32 seed starts v directly. Otherwise the string seed gets
// a trailing "\0", and the loop runs over max(128, seed length) entries.
function init(me2, seed2) {
var t, v, i, j, w, X = [], limit = 128;
if (seed2 === (seed2 | 0)) {
v = seed2;
seed2 = null;
} else {
seed2 = seed2 + "\0";
v = 0;
limit = Math.max(limit, seed2.length);
}
// The first 32 iterations (j < 0) only scramble v; w is taken from v at
// j === 0. From then on, each X[j & 127] gets v + w XORed in (unset entries
// coerce to 0), and i counts consecutive zero results.
for (i = 0, j = -32; j < limit; ++j) {
if (seed2)
v ^= seed2.charCodeAt((j + 32) % seed2.length);
if (j === 0)
w = v;
v ^= v << 10;
v ^= v >>> 15;
v ^= v << 4;
v ^= v >>> 13;
if (j >= 0) {
w = w + 1640531527 | 0;
t = X[j & 127] ^= v + w;
i = t == 0 ? i + 1 : 0;
}
}
// If 128 or more consecutive zeros were seen, set one word to -1.
if (i >= 128) {
X[(seed2 && seed2.length || 0) & 127] = -1;
}
// Take 512 mixing steps starting from index 127. This mirrors next() but
// neither advances w nor returns a value.
i = 127;
for (j = 4 * 128; j > 0; --j) {
v = X[i + 34 & 127];
t = X[i = i + 1 & 127];
v ^= v << 13;
t ^= t << 17;
v ^= v >>> 15;
t ^= t >>> 12;
X[i] = v ^ t;
}
me2.w = w;
me2.X = X;
me2.i = i;
}
init(me, seed);
}
// Copy the generator state from f to t (the array is cloned); returns t.
function copy(f, t) {
t.i = f.i;
t.w = f.w;
t.X = f.X.slice();
return t;
}
// Factory. A null/undefined seed falls back to the current time in ms.
// prng() returns (next() >>> 0) / 2^32, a value in [0, 1). Extra methods:
// - .double: 53 bits from two draws; retried while the result is 0;
// - .int32: the raw next();
// - .quick: the plain call;
// - .state(): present only when opts.state is truthy.
function impl(seed, opts) {
if (seed == null)
seed = +new Date();
var xg = new XorGen(seed), state = opts && opts.state, prng = function() {
return (xg.next() >>> 0) / 4294967296;
};
prng.double = function() {
do {
var top = xg.next() >>> 11, bot = (xg.next() >>> 0) / 4294967296, result = (top + bot) / (1 << 21);
} while (result === 0);
return result;
};
prng.int32 = xg.next;
prng.quick = prng;
if (state) {
// A snapshot is restored only if it carries an `X` array.
if (state.X)
copy(state, xg);
prng.state = function() {
return copy(xg, {});
};
}
return prng;
}
// UMD export. It is invoked below with module2 and define2 = false, so the
// first branch runs.
if (module3 && module3.exports) {
module3.exports = impl;
} else if (define2 && define2.amd) {
define2(function() {
return impl;
});
} else {
this.xor4096 = impl;
}
})(commonjsGlobal, module2, false);
});
// Vendored `tychei` PRNG from the seedrandom package; it is attached below as
// `seedrandom.tychei`. State is four 32-bit words a, b, c, d.
var tychei = createCommonjsModule(function(module2) {
(function(global2, module3, define2) {
function XorGen(seed) {
var me = this, strseed = "";
// One step: two rounds of rotate/xor/subtract mixing over a, b, c, d.
// The new `a` is stored and returned as a signed 32-bit value.
me.next = function() {
var b = me.b, c = me.c, d = me.d, a = me.a;
b = b << 25 ^ b >>> 7 ^ c;
c = c - d | 0;
d = d << 24 ^ d >>> 8 ^ a;
a = a - b | 0;
me.b = b = b << 20 ^ b >>> 12 ^ c;
me.c = c = c - d | 0;
me.d = d << 16 ^ c >>> 16 ^ a;
return me.a = a - b | 0;
};
// Fixed initial values for c and d; a and b come from the seed.
me.a = 0;
me.b = 0;
me.c = 2654435769 | 0;
me.d = 1367130551;
// An integer-valued number seed provides a (seed / 2^32, truncated) and
// b (the low 32 bits). Any other seed is stringified and mixed in below.
if (seed === Math.floor(seed)) {
me.a = seed / 4294967296 | 0;
me.b = seed | 0;
} else {
strseed += seed;
}
// XOR each char code into b and step once per char, then take 20 further
// steps. Past the end, charCodeAt gives NaN and `| 0` maps it to 0.
for (var k = 0; k < strseed.length + 20; k++) {
me.b ^= strseed.charCodeAt(k) | 0;
me.next();
}
}
// Copy the generator state from f to t; returns t.
function copy(f, t) {
t.a = f.a;
t.b = f.b;
t.c = f.c;
t.d = f.d;
return t;
}
// Factory: impl(seed, opts) returns prng(), a value in [0, 1) from
// (next() >>> 0) / 2^32. Extra methods:
// - .double: 53 bits from two draws; retried while the result is 0;
// - .int32: the raw next();
// - .quick: the plain call;
// - .state(): present only when opts.state is truthy.
function impl(seed, opts) {
var xg = new XorGen(seed), state = opts && opts.state, prng = function() {
return (xg.next() >>> 0) / 4294967296;
};
prng.double = function() {
do {
var top = xg.next() >>> 11, bot = (xg.next() >>> 0) / 4294967296, result = (top + bot) / (1 << 21);
} while (result === 0);
return result;
};
prng.int32 = xg.next;
prng.quick = prng;
if (state) {
// If opts.state is an object, it is treated as a saved snapshot and restored.
if (typeof state == "object")
copy(state, xg);
prng.state = function() {
return copy(xg, {});
};
}
return prng;
}
// UMD export. It is invoked below with module2 and define2 = false, so the
// first branch runs.
if (module3 && module3.exports) {
module3.exports = impl;
} else if (define2 && define2.amd) {
define2(function() {
return impl;
});
} else {
this.tychei = impl;
}
})(commonjsGlobal, module2, false);
});
var seedrandom = createCommonjsModule(function(module2) {
(function(pool2, math2) {
var global2 = this, width = 256, chunks = 6, digits = 52, rngname = "random", startdenom = math2.pow(width, chunks), significance = math2.pow(2, digits), overflow = significance * 2, mask = width - 1, nodecrypto;
function seedrandom2(seed, options, callback) {
var key = [];
options = options == true ? {entropy: true} : options || {};
var shortseed = mixkey(flatten2(options.entropy ? [seed, tostring(pool2)] : seed == null ? autoseed() : seed, 3), key);
var arc4 = new ARC4(key);
var prng = function() {
var n = arc4.g(chunks), d = startdenom, x = 0;
while (n < significance) {
n = (n + x) * width;
d *= width;
x = arc4.g(1);
}
while (n >= overflow) {
n /= 2;
d /= 2;
x >>>= 1;
}
return (n + x) / d;
};
prng.int32 = function() {
return arc4.g(4) | 0;
};
prng.quick = function() {
return arc4.g(4) / 4294967296;
};
prng.double = prng;
mixkey(tostring(arc4.S), pool2);
return (options.pass || callback || function(prng2, seed2, is_math_call, state) {
if (state) {
if (state.S) {
copy(state, arc4);
}
prng2.state = function() {
return copy(arc4, {});
};
}
if (is_math_call) {
math2[rngname] = prng2;
return seed2;
} else
return prng2;
})(prng, shortseed, "global" in options ? options.global : this == math2, options.state);
}
math2["seed" + rngname] = seedrandom2;
function ARC4(key) {
var t, keylen = key.length, me = this, i = 0, j = me.i = me.j = 0, s = me.S = [];
if (!keylen) {
key = [keylen++];
}
while (i < width) {
s[i] = i++;
}
for (i = 0; i < width; i++) {
s[i] = s[j = mask & j + key[i % keylen] + (t = s[i])];
s[j] = t;
}
(me.g = function(count) {
var t2, r = 0, i2 = me.i, j2 = me.j, s2 = me.S;
while (count--) {
t2 = s2[i2 = mask & i2 + 1];
r = r * width + s2[mask & (s2[i2] = s2[j2 = mask & j2 + t2]) + (s2[j2] = t2)];
}
me.i = i2;
me.j = j2;
return r;
})(width);
}
function copy(f, t) {
t.i = f.i;
t.j = f.j;
t.S = f.S.slice();
return t;
}
function flatten2(obj, depth) {
var result = [], typ = typeof obj, prop;
if (depth && typ == "object") {
for (prop in obj) {
try {
result.push(flatten2(obj[prop], depth - 1));
} catch (e) {
}
}
}
return result.length ? result : typ == "string" ? obj : obj + "\0";
}
function mixkey(seed, key) {
var stringseed = seed + "", smear, j = 0;
while (j < stringseed.length) {
key[mask & j] = mask & (smear ^= key[mask & j] * 19) + stringseed.charCodeAt(j++);
}
return tostring(key);
}
function autoseed() {
  // Produce a seed from a strong randomness source: Node's crypto.randomBytes
  // when nodecrypto exposes it, otherwise the Web Crypto getRandomValues API.
  // If both are unavailable (or throw), fall back to a best-effort array of
  // environment values plus the current entropy pool.
  try {
    var randomBytes = nodecrypto && nodecrypto.randomBytes;
    var bytes;
    if (randomBytes) {
      bytes = randomBytes(width);
    } else {
      bytes = new Uint8Array(width);
      var webCrypto = global2.crypto || global2.msCrypto;
      // Called as a method so getRandomValues keeps its receiver.
      webCrypto.getRandomValues(bytes);
    }
    return tostring(bytes);
  } catch (e) {
    var nav = global2.navigator;
    var plugins = nav && nav.plugins;
    return [+new Date(), global2, plugins, global2.screen, tostring(pool2)];
  }
}
function tostring(a) {
  // Treat each element of the array-like `a` as a UTF-16 code unit and join
  // them into a single string.
  var fromCharCode = String.fromCharCode;
  return fromCharCode.apply(String, a);
}
// Stir a Math.random()-style value into the shared entropy pool (math2 is
// presumably the `Math` object passed at this IIFE's call site).
mixkey(math2.random(), pool2);
// CommonJS path: export seedrandom2 as the module value and try to load Node's
// crypto for autoseed(). In this bundle require_crypto is an empty stub that
// returns {}, so nodecrypto.randomBytes is undefined and autoseed() uses
// global2.crypto / global2.msCrypto instead.
if (module2.exports) {
module2.exports = seedrandom2;
try {
nodecrypto = require_crypto();
} catch (ex) {
}
}
})([], Math);
});
// Expose the alternative generators (defined outside this view) as properties
// of the seedrandom export, assigned in the same order as individual writes.
Object.assign(seedrandom, {alea, xor128, xorwow, xorshift7, xor4096, tychei});
var seedrandom$1 = seedrandom;
// The alea generator is what the tfjs random samplers below are built on.
var seedrandom_1 = seedrandom$1.alea;
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var MPRandGauss = function() {
  /**
   * Normal-distribution sampler using the Marsaglia polar method. Each round
   * yields two samples; the second is cached and returned by the next call.
   *
   * @param {number} mean2 - Mean of the distribution.
   * @param {number} stdDeviation - Standard deviation of the distribution.
   * @param {string} [dtype] - null/undefined or "float32" returns raw samples;
   *   any other dtype rounds each sample with Math.round.
   * @param {boolean} truncated - When true, samples more than two standard
   *   deviations from the mean are rejected.
   * @param {number|string} [seed] - Seed for the alea PRNG. Any non-null value
   *   (including 0) is honored; null/undefined picks a random seed.
   */
  function MPRandGauss2(mean2, stdDeviation, dtype, truncated, seed) {
    this.mean = mean2;
    this.stdDev = stdDeviation;
    this.dtype = dtype;
    this.nextVal = NaN;
    this.truncated = truncated;
    if (this.truncated) {
      this.upper = this.mean + this.stdDev * 2;
      this.lower = this.mean - this.stdDev * 2;
    }
    // Test for null explicitly (as UniformRandom does): a truthiness check
    // would silently replace a valid seed of 0 with Math.random().
    var seedValue = seed != null ? seed : Math.random();
    this.random = seedrandom_1(seedValue.toString());
  }
  /** Returns the next sample, converted according to `dtype`. */
  MPRandGauss2.prototype.nextValue = function() {
    // Serve the cached second sample from the previous round, if any.
    if (!isNaN(this.nextVal)) {
      var value = this.nextVal;
      this.nextVal = NaN;
      return value;
    }
    var resultX, resultY;
    var isValid = false;
    while (!isValid) {
      var v1 = void 0, v2 = void 0, s = void 0;
      // Pick a point uniformly inside the unit circle, excluding the origin.
      do {
        v1 = 2 * this.random() - 1;
        v2 = 2 * this.random() - 1;
        s = v1 * v1 + v2 * v2;
      } while (s >= 1 || s === 0);
      var mul2 = Math.sqrt(-2 * Math.log(s) / s);
      resultX = this.mean + this.stdDev * v1 * mul2;
      resultY = this.mean + this.stdDev * v2 * mul2;
      if (!this.truncated || this.isValidTruncated(resultX)) {
        isValid = true;
      }
    }
    // Cache the partner sample only if it also respects the truncation bounds.
    if (!this.truncated || this.isValidTruncated(resultY)) {
      this.nextVal = this.convertValue(resultY);
    }
    return this.convertValue(resultX);
  };
  /** Leaves float32 (or unspecified) samples as-is; rounds for other dtypes. */
  MPRandGauss2.prototype.convertValue = function(value) {
    if (this.dtype == null || this.dtype === "float32") {
      return value;
    }
    return Math.round(value);
  };
  /** True when `value` lies within the inclusive truncation bounds. */
  MPRandGauss2.prototype.isValidTruncated = function(value) {
    return value <= this.upper && value >= this.lower;
  };
  return MPRandGauss2;
}();
var RandGamma = function() {
  /**
   * Gamma-distribution sampler (Marsaglia & Tsang squeeze/rejection method).
   *
   * @param {number} alpha - Shape parameter. For alpha < 1 the sampler draws
   *   with shape alpha + 1 and rescales by U^(1 / alpha).
   * @param {number} beta - Rate (inverse scale); samples have mean alpha / beta.
   * @param {string} dtype - "float32" returns raw samples; anything else is
   *   rounded with Math.round.
   * @param {number|string} [seed] - Seed for the alea PRNG. Any non-null value
   *   (including 0) is honored; null/undefined picks a random seed.
   */
  function RandGamma2(alpha, beta, dtype, seed) {
    this.alpha = alpha;
    // Stored as the scale parameter (1 / rate).
    this.beta = 1 / beta;
    this.dtype = dtype;
    var seedValue = seed != null ? seed : Math.random();
    this.randu = seedrandom_1(seedValue.toString());
    // The normal deviates feed the acceptance test and must stay continuous.
    // Only the final sample is converted to `dtype`, so request float32 here
    // rather than forwarding `dtype` (which rounded every deviate for int32).
    this.randn = new MPRandGauss(0, 1, "float32", false, this.randu());
    if (alpha < 1) {
      this.d = alpha + 2 / 3;
    } else {
      this.d = alpha - 1 / 3;
    }
    this.c = 1 / Math.sqrt(9 * this.d);
  }
  /** Returns the next Gamma(alpha, rate = beta) sample, converted to `dtype`. */
  RandGamma2.prototype.nextValue = function() {
    var x2, v0, v1, x, u, v;
    while (true) {
      do {
        x = this.randn.nextValue();
        v = 1 + this.c * x;
      } while (v <= 0);
      v *= v * v;
      x2 = x * x;
      // Squeeze bound 1 - 0.0331 x^4 from the paper: it accepts a subset of
      // what the exact log test accepts and only saves the log evaluation.
      v0 = 1 - 0.0331 * x2 * x2;
      v1 = 0.5 * x2 + this.d * (1 - v + Math.log(v));
      u = this.randu();
      if (u < v0 || Math.log(u) < v1) {
        break;
      }
    }
    // d * v is a unit-scale gamma draw; multiply by the scale
    // (this.beta === 1 / rate) so the mean is alpha / rate.
    v = this.beta * this.d * v;
    if (this.alpha < 1) {
      v *= Math.pow(this.randu(), 1 / this.alpha);
    }
    return this.convertValue(v);
  };
  /** Leaves float32 samples as-is; rounds for any other dtype. */
  RandGamma2.prototype.convertValue = function(value) {
    if (this.dtype === "float32") {
      return value;
    }
    return Math.round(value);
  };
  return RandGamma2;
}();
var UniformRandom = function() {
  /**
   * Uniform sampler producing min + (max - min) * random() from a seeded alea PRNG.
   *
   * @param {number} [min2=0] - Lower bound.
   * @param {number} [max2=1] - Upper bound.
   * @param {string} [dtype] - null/undefined or "float32" returns raw samples;
   *   otherwise samples are rounded and the range must exceed 1.
   * @param {number|string} [seed] - Seed; numbers are stringified and
   *   null/undefined picks a random seed.
   * @throws {Error} When rounding is required and max2 - min2 <= 1.
   */
  function UniformRandom2(min2, max2, dtype, seed) {
    var lower = min2 === void 0 ? 0 : min2;
    var upper = max2 === void 0 ? 1 : max2;
    this.canReturnFloat = () => this.dtype == null || this.dtype === "float32";
    this.min = lower;
    this.range = upper - lower;
    this.dtype = dtype;
    var seedText = seed == null
      ? Math.random().toString()
      : typeof seed === "number" ? seed.toString() : seed;
    if (!this.canReturnFloat() && this.range <= 1) {
      throw new Error("The difference between " + lower + " - " + upper + " <= 1 and dtype is not float");
    }
    this.random = seedrandom_1(seedText);
  }
  /** Returns `value` unchanged for float output, otherwise Math.round(value). */
  UniformRandom2.prototype.convertValue = function(value) {
    return this.canReturnFloat() ? value : Math.round(value);
  };
  /** Draws the next sample in the configured range. */
  UniformRandom2.prototype.nextValue = function() {
    var sample = this.min + this.range * this.random();
    return this.convertValue(sample);
  };
  return UniformRandom2;
}();
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Fills a tensor of `shape` with samples from a RandGamma sampler.
 *
 * @param shape - Output shape.
 * @param alpha - Shape parameter forwarded to RandGamma.
 * @param [beta=1] - Forwarded to RandGamma; null/undefined become 1.
 * @param [dtype="float32"] - "float32" or "int32"; null/undefined become "float32".
 * @param [seed] - Forwarded to RandGamma.
 * @throws {Error} For any other dtype.
 */
function randomGamma_(shape, alpha, beta, dtype, seed) {
  // `== null` covers both an omitted argument and an explicit null.
  const rate = beta == null ? 1 : beta;
  const outDtype = dtype == null ? "float32" : dtype;
  if (outDtype !== "float32" && outDtype !== "int32") {
    throw new Error("Unsupported data type " + outDtype);
  }
  const sampler = new RandGamma(alpha, rate, outDtype, seed);
  const buf = buffer(shape, outDtype);
  const vals = buf.values;
  for (let k = 0; k < vals.length; k++) {
    vals[k] = sampler.nextValue();
  }
  return buf.toTensor();
}
var randomGamma = op({randomGamma_});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Fills a tensor of `shape` with samples from an (untruncated) MPRandGauss sampler.
 *
 * @param shape - Output shape.
 * @param [mean2=0] - Mean.
 * @param [stdDev=1] - Standard deviation.
 * @param [dtype] - Output dtype; "bool" is rejected.
 * @param [seed] - Forwarded to MPRandGauss.
 * @throws {Error} When dtype is "bool".
 */
function randomNormal_(shape, mean2, stdDev, dtype, seed) {
  const mu = mean2 === void 0 ? 0 : mean2;
  const sigma = stdDev === void 0 ? 1 : stdDev;
  if (dtype === "bool") {
    throw new Error("Unsupported data type " + dtype);
  }
  const sampler = new MPRandGauss(mu, sigma, dtype, false, seed);
  const buf = buffer(shape, dtype);
  const vals = buf.values;
  for (let k = 0; k < vals.length; k++) {
    vals[k] = sampler.nextValue();
  }
  return buf.toTensor();
}
var randomNormal = op({randomNormal_});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Fills a tensor of `shape` with samples drawn as minval + (maxval - minval) * random().
 *
 * @param shape - Output shape.
 * @param [minval=0] - Lower bound.
 * @param [maxval=1] - Upper bound.
 * @param [dtype="float32"] - Dtype of the output buffer.
 * @param [seed] - Forwarded to UniformRandom.
 */
function randomUniform_(shape, minval, maxval, dtype, seed) {
  const lo = minval === void 0 ? 0 : minval;
  const hi = maxval === void 0 ? 1 : maxval;
  const outDtype = dtype === void 0 ? "float32" : dtype;
  const buf = buffer(shape, outDtype);
  // The sampler is deliberately built with a null dtype so it never rounds;
  // the values are stored into the buffer unrounded. NOTE(review): for int32
  // this presumably relies on the buffer's typed array truncating on store.
  const sampler = new UniformRandom(lo, hi, null, seed);
  const vals = buf.values;
  for (let k = 0; k < vals.length; k++) {
    vals[k] = sampler.nextValue();
  }
  return buf.toTensor();
}
var randomUniform = op({randomUniform_});
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Builds a rank-1 tensor from a flat array or TypedArray.
 *
 * @param values - Non-null flat values.
 * @param [dtype] - Optional dtype forwarded to shape inference and makeTensor.
 * @throws {Error} When the inferred shape is not rank 1.
 */
function tensor1d(values, dtype) {
  assertNonNull(values);
  const shape = inferShape(values, dtype);
  if (shape.length !== 1) {
    throw new Error("tensor1d() requires values to be a flat/TypedArray");
  }
  // No explicit shape (null) is passed; the inferred shape goes alongside it.
  return makeTensor(values, null, shape, dtype);
}
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Creates a rank-1 tensor holding start, start + step, ... up to but not
 * including `stop`.
 *
 * @param {number} start - First value.
 * @param {number} stop - Exclusive end value.
 * @param {number} [step2=1] - Increment. The default step of exactly 1 over a
 *   decreasing range (stop < start) is treated as "not set" and flipped to -1.
 *   Any other step whose sign points away from `stop` yields an empty tensor.
 * @param {string} [dtype="float32"] - Output dtype.
 * @returns Rank-1 tensor produced by the Range kernel.
 * @throws {Error} If `step2` is 0.
 */
function range(start, stop, step2, dtype) {
  if (step2 === void 0) {
    step2 = 1;
  }
  if (dtype === void 0) {
    dtype = "float32";
  }
  if (step2 === 0) {
    throw new Error("Cannot have a step of zero");
  }
  var forward = function() {
    var isDecreasing = stop < start;
    // Auto-flip the default step for decreasing ranges. This uses a local
    // rather than reassigning the captured parameter.
    var step = isDecreasing && step2 === 1 ? -1 : step2;
    var sameStartStop = start === stop;
    var increasingRangeNegativeStep = start < stop && step < 0;
    // Any positive step over a decreasing range is empty. This includes
    // fractional steps in (0, 1), which would otherwise walk away from `stop`.
    var decreasingRangePositiveStep = isDecreasing && step > 0;
    if (sameStartStop || increasingRangeNegativeStep || decreasingRangePositiveStep) {
      return zeros([0], dtype);
    }
    // Use the sign-corrected step so Math.ceil rounds the element count up in
    // both directions, e.g. range(5, 0.5) -> [5, 4, 3, 2, 1].
    var numElements = Math.abs(Math.ceil((stop - start) / step));
    var values = makeZerosTypedArray(numElements, dtype);
    values[0] = start;
    for (var i = 1; i < values.length; i++) {
      values[i] = values[i - 1] + step;
    }
    return tensor1d(values, dtype);
  };
  var attrs = {start, stop, step: step2, dtype};
  return ENGINE.runKernelFunc(forward, {}, null, Range, attrs);
}
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Runs the Reciprocal kernel on `x`. The fallback forward function calls
 * `backend.reciprocal` and records the input via `save`.
 *
 * @param x - Tensor or tensor-like input.
 */
function reciprocal_(x) {
  const input = convertToTensor(x, "x", "reciprocal");
  const kernelInputs = {x: input};
  function forward(backend2, save) {
    const out = backend2.reciprocal(input);
    save([input]);
    return out;
  }
  return ENGINE.runKernelFunc(forward, kernelInputs, null, Reciprocal);
}
var reciprocal = op({reciprocal_});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Runs the Relu kernel on `x`. In the fallback forward function a bool input
 * is cast to int32 instead of going through `backend.relu`.
 *
 * @param x - Tensor or tensor-like input.
 */
function relu_(x) {
  const input = convertToTensor(x, "x", "relu");
  const forward = (backend2, save) => {
    save([input]);
    return input.dtype === "bool" ? cast(input, "int32") : backend2.relu(input);
  };
  return ENGINE.runKernelFunc(forward, {x: input}, null, Relu);
}
var relu = op({relu_});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Runs the Relu6 kernel on `x`. In the fallback forward function a bool input
 * is cast to int32 instead of going through `backend.relu6`.
 *
 * @param x - Tensor or tensor-like input.
 */
function relu6_(x) {
  const input = convertToTensor(x, "x", "relu6");
  const forward = (backend2, save) => {
    save([input]);
    return input.dtype === "bool" ? cast(input, "int32") : backend2.relu6(input);
  };
  return ENGINE.runKernelFunc(forward, {x: input}, null, Relu6);
}
var relu6 = op({relu6_});
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Runs the Reverse kernel on `x` along `axis` (passed as the `dims` attr).
 * The fallback forward function returns a clone for scalars and otherwise
 * reshapes the backend result back to the input shape.
 *
 * @param x - Tensor or tensor-like input.
 * @param axis - Axis or axes to reverse, parsed by parseAxisParam.
 */
function reverse_(x, axis) {
  const input = convertToTensor(x, "x", "reverse");
  function forward(backend2) {
    const axes = parseAxisParam(axis, input.shape);
    if (input.rank === 0) {
      return clone(input);
    }
    return reshape(backend2.reverse(input, axes), input.shape);
  }
  return ENGINE.runKernelFunc(forward, {x: input}, null, Reverse, {dims: axis});
}
var reverse = op({reverse_});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Reverses a rank-1 tensor along axis 0.
 *
 * @param x - Rank-1 tensor or tensor-like input.
 * @throws {Error} When `x` is not rank 1.
 */
function reverse1d_(x) {
  const input = convertToTensor(x, "x", "reverse");
  const rank = input.rank;
  assert(rank === 1, () => "Error in reverse1D: x must be rank 1 but got rank " + rank + ".");
  return reverse(input, 0);
}
var reverse1d = op({reverse1d_});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Reverses a rank-2 tensor along `axis`.
 *
 * @param x - Rank-2 tensor or tensor-like input.
 * @param axis - Axis or axes forwarded to reverse().
 * @throws {Error} When `x` is not rank 2.
 */
function reverse2d_(x, axis) {
  const input = convertToTensor(x, "x", "reverse");
  const rank = input.rank;
  assert(rank === 2, () => "Error in reverse2D: x must be rank 2 but got rank " + rank + ".");
  return reverse(input, axis);
}
var reverse2d = op({reverse2d_});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Reverses a rank-3 tensor along `axis`.
 *
 * @param x - Rank-3 tensor or tensor-like input.
 * @param axis - Axis or axes forwarded to reverse().
 * @throws {Error} When `x` is not rank 3.
 */
function reverse3d_(x, axis) {
  const input = convertToTensor(x, "x", "reverse");
  const rank = input.rank;
  assert(rank === 3, () => "Error in reverse3D: x must be rank 3 but got rank " + rank + ".");
  return reverse(input, axis);
}
var reverse3d = op({reverse3d_});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Reverses a rank-4 tensor along `axis`.
 *
 * @param x - Rank-4 tensor or tensor-like input.
 * @param axis - Axis or axes forwarded to reverse().
 * @throws {Error} When `x` is not rank 4.
 */
function reverse4d_(x, axis) {
  const input = convertToTensor(x, "x", "reverse");
  const rank = input.rank;
  assert(rank === 4, () => "Error in reverse4D: x must be rank 4 but got rank " + rank + ".");
  return reverse(input, axis);
}
var reverse4d = op({reverse4d_});
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Runs the Round kernel on `x`; the fallback calls `backend.round`.
 *
 * @param x - Tensor or tensor-like input.
 */
function round_(x) {
  const input = convertToTensor(x, "x", "round");
  const forward = (backend2) => backend2.round(input);
  return ENGINE.runKernelFunc(forward, {x: input}, null, Round);
}
var round = op({round_});
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Runs the Rsqrt kernel on `x`. The fallback calls `backend.rsqrt` and
 * records the input via `save`.
 *
 * @param x - Tensor or tensor-like input.
 */
function rsqrt_(x) {
  const input = convertToTensor(x, "x", "rsqrt");
  const kernelInputs = {x: input};
  function forward(backend2, save) {
    const out = backend2.rsqrt(input);
    save([input]);
    return out;
  }
  return ENGINE.runKernelFunc(forward, kernelInputs, null, Rsqrt);
}
var rsqrt = op({rsqrt_});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Runs the Selu kernel on `x`. The fallback calls `backend.selu` and records
 * the input via `save`.
 *
 * @param x - Tensor or tensor-like input.
 */
function selu_(x) {
  const input = convertToTensor(x, "x", "selu");
  return ENGINE.runKernelFunc((backend2, save) => {
    const out = backend2.selu(input);
    save([input]);
    return out;
  }, {x: input}, null, Selu);
}
var selu = op({selu_});
/**
 * Separable 2-D convolution: a depthwise convolution of `x` with
 * `depthwiseFilter`, followed by a 1x1 (pointwise) convolution with
 * `pointwiseFilter` using stride 1 and "valid" padding.
 *
 * @param x - Rank-3 or rank-4 input. A rank-3 input is treated as a batch of
 *   one and the result is reshaped back to rank 3.
 * @param depthwiseFilter - Rank-4 filter; shape[2] is the input channel count
 *   and shape[3] the channel multiplier.
 * @param pointwiseFilter - Rank-4 filter whose first two dimensions are 1 and
 *   whose third dimension equals inChannels * channelMultiplier.
 * @param strides - Forwarded to the depthwise convolution.
 * @param pad2 - Padding forwarded to the depthwise convolution.
 * @param [dilation=[1, 1]] - Forwarded to the depthwise convolution.
 * @param [dataFormat="NHWC"] - Only "NHWC" is supported.
 * @throws {Error} For dataFormat "NCHW" or when any shape assertion fails.
 */
function separableConv2d_(x, depthwiseFilter, pointwiseFilter, strides, pad2, dilation, dataFormat) {
  if (dilation === void 0) {
    dilation = [1, 1];
  }
  if (dataFormat === void 0) {
    dataFormat = "NHWC";
  }
  var $x = convertToTensor(x, "x", "separableConv2d");
  var $depthwiseFilter = convertToTensor(depthwiseFilter, "depthwiseFilter", "separableConv2d");
  var $pointwiseFilter = convertToTensor(pointwiseFilter, "pointwiseFilter", "separableConv2d");
  var x4D = $x;
  var reshapedTo4D = false;
  if ($x.rank === 3) {
    reshapedTo4D = true;
    x4D = reshape($x, [1, $x.shape[0], $x.shape[1], $x.shape[2]]);
  }
  if (dataFormat === "NCHW") {
    throw new Error("separableConv2d currently does not support dataFormat NCHW; only NHWC is supported");
  }
  assert(x4D.rank === 4, function() {
    return "Error in separableConv2d: input must be rank 4, but got " + ("rank " + x4D.rank + ".");
  });
  assert($depthwiseFilter.rank === 4, function() {
    return "Error in separableConv2d: depthwise filter must be rank 4, but " + ("got rank " + $depthwiseFilter.rank + ".");
  });
  // Report the pointwise filter's own rank (previously the depthwise rank was shown).
  assert($pointwiseFilter.rank === 4, function() {
    return "Error in separableConv2d: pointwise filter must be rank 4, but " + ("got rank " + $pointwiseFilter.rank + ".");
  });
  assert($pointwiseFilter.shape[0] === 1, function() {
    return "Error in separableConv2d: the first dimension of pointwise filter " + ("must be 1, but got " + $pointwiseFilter.shape[0] + ".");
  });
  assert($pointwiseFilter.shape[1] === 1, function() {
    return "Error in separableConv2d: the second dimension of pointwise " + ("filter must be 1, but got " + $pointwiseFilter.shape[1] + ".");
  });
  var inChannels = $depthwiseFilter.shape[2];
  var channelMultiplier = $depthwiseFilter.shape[3];
  assert($pointwiseFilter.shape[2] === inChannels * channelMultiplier, function() {
    return "Error in separableConv2d: the third dimension of pointwise filter " + ("must be " + inChannels * channelMultiplier + ", ") + ("but got " + $pointwiseFilter.shape[2] + ".");
  });
  var depthwise = depthwiseConv2d(x4D, $depthwiseFilter, strides, pad2, dataFormat, dilation);
  var pointwiseStride = 1;
  var res = conv2d(depthwise, $pointwiseFilter, pointwiseStride, "valid", dataFormat);
  if (reshapedTo4D) {
    // Drop the batch dimension that was added for a rank-3 input.
    return reshape(res, [res.shape[1], res.shape[2], res.shape[3]]);
  }
  return res;
}
var separableConv2d = op({separableConv2d_});
/**
 * Asynchronous set difference of two rank-1 tensors.
 *
 * This is a down-leveled async function. `__awaiter`/`__generator` drive the
 * `switch` state machine below, and each `case` resumes after an awaited
 * `data()` read (`_a.sent()` yields the awaited value). Presumably per tslib's
 * `__generator`: `[4, p]` awaits p and `[2, v]` returns v.
 *
 * @param x - Rank-1 tensor (or tensor-like) whose elements are kept unless they appear in `y`.
 * @param y - Rank-1 tensor (or tensor-like) of values to remove; must have x's dtype.
 * @returns Promise of [values, indices]: the elements of `x` not present in `y`,
 *   in their original order (duplicates in `x` are kept), and their int32
 *   positions in `x`. Membership uses a `Set`, i.e. SameValueZero comparison.
 *   Assertion failures presumably surface as a rejected promise, since the
 *   body runs inside `__awaiter`.
 */
function setdiff1dAsync_(x, y) {
return __awaiter(this, void 0, void 0, function() {
var $x, $y, xVals, yVals, ySet, outputSize, i, buffer2, indices, i, p;
return __generator(this, function(_a) {
switch (_a.label) {
case 0:
// Validate inputs, then await x's values.
$x = convertToTensor(x, "x", "setdiff1d");
$y = convertToTensor(y, "y", "setdiff1d");
assert($x.dtype === $y.dtype, function() {
return "x and y should have the same dtype, but got x (" + $x.dtype + ") and y (" + $y.dtype + ").";
});
assert($x.rank === 1, function() {
return "x should be 1D tensor, but got x (" + $x.shape + ").";
});
assert($y.rank === 1, function() {
return "y should be 1D tensor, but got y (" + $y.shape + ").";
});
return [4, $x.data()];
case 1:
// x's values are available; await y's values.
xVals = _a.sent();
return [4, $y.data()];
case 2:
yVals = _a.sent();
ySet = new Set(yVals);
// First pass: count survivors so the output buffers can be sized exactly.
outputSize = 0;
for (i = 0; i < xVals.length; i++) {
if (!ySet.has(xVals[i])) {
outputSize++;
}
}
buffer2 = new TensorBuffer([outputSize], $x.dtype);
indices = new TensorBuffer([outputSize], "int32");
// Second pass: copy surviving values and record their positions in x.
for (i = 0, p = 0; i < xVals.length; i++) {
if (!ySet.has(xVals[i])) {
buffer2.values[p] = xVals[i];
indices.values[p] = i;
p++;
}
}
return [2, [buffer2.toTensor(), indices.toTensor()]];
}
});
});
}
var setdiff1dAsync = setdiff1dAsync_;
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Runs the Sign kernel on `x`; the fallback calls `backend.sign`.
 *
 * @param x - Tensor or tensor-like input.
 */
function sign_(x) {
  const input = convertToTensor(x, "x", "sign");
  const forward = (backend2) => backend2.sign(input);
  return ENGINE.runKernelFunc(forward, {x: input}, null, Sign);
}
var sign = op({sign_});
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Runs the Sin kernel on `x`. The fallback calls `backend.sin` and records
 * the input via `save`.
 *
 * @param x - Tensor or tensor-like input.
 */
function sin_(x) {
  const input = convertToTensor(x, "x", "sin");
  const kernelInputs = {x: input};
  function forward(backend2, save) {
    const out = backend2.sin(input);
    save([input]);
    return out;
  }
  return ENGINE.runKernelFunc(forward, kernelInputs, null, Sin);
}
var sin = op({sin_});
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Runs the Sinh kernel on `x`. The fallback calls `backend.sinh` and records
 * the input via `save`.
 *
 * @param x - Tensor or tensor-like input.
 */
function sinh_(x) {
  const input = convertToTensor(x, "x", "sinh");
  const kernelInputs = {x: input};
  function forward(backend2, save) {
    const out = backend2.sinh(input);
    save([input]);
    return out;
  }
  return ENGINE.runKernelFunc(forward, kernelInputs, null, Sinh);
}
var sinh = op({sinh_});
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Slices a rank-1 tensor.
 *
 * @param x - Rank-1 tensor or tensor-like input.
 * @param begin - Start index (scalar).
 * @param size - Slice length (scalar).
 * @throws {Error} When `x` is not rank 1.
 */
function slice1d_(x, begin, size) {
  const input = convertToTensor(x, "x", "slice1d");
  const rank = input.rank;
  assert(rank === 1, () => "slice1d expects a rank-1 tensor, but got a rank-" + rank + " tensor");
  // slice() takes per-axis arrays, so wrap the scalar arguments.
  return slice(input, [begin], [size]);
}
var slice1d = op({slice1d_});
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Slices a rank-2 tensor.
 *
 * @param x - Rank-2 tensor or tensor-like input.
 * @param begin - Per-axis start indices, forwarded to slice().
 * @param size - Per-axis sizes, forwarded to slice().
 * @throws {Error} When `x` is not rank 2.
 */
function slice2d_(x, begin, size) {
  const input = convertToTensor(x, "x", "slice2d");
  const rank = input.rank;
  assert(rank === 2, () => "slice2d expects a rank-2 tensor, but got a rank-" + rank + " tensor");
  return slice(input, begin, size);
}
var slice2d = op({slice2d_});
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Slices a rank-3 tensor.
 *
 * @param x - Rank-3 tensor or tensor-like input.
 * @param begin - Per-axis start indices, forwarded to slice().
 * @param size - Per-axis sizes, forwarded to slice().
 * @throws {Error} When `x` is not rank 3.
 */
function slice3d_(x, begin, size) {
  const input = convertToTensor(x, "x", "slice3d");
  const rank = input.rank;
  assert(rank === 3, () => "slice3d expects a rank-3 tensor, but got a rank-" + rank + " tensor");
  return slice(input, begin, size);
}
var slice3d = op({slice3d_});
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Slices a rank-4 tensor.
 *
 * @param x - Rank-4 tensor or tensor-like input.
 * @param begin - Per-axis start indices, forwarded to slice().
 * @param size - Per-axis sizes, forwarded to slice().
 * @throws {Error} When `x` is not rank 4.
 */
function slice4d_(x, begin, size) {
  const input = convertToTensor(x, "x", "slice4d");
  const rank = input.rank;
  assert(rank === 4, () => "slice4d expects a rank-4 tensor, but got a rank-" + rank + " tensor");
  return slice(input, begin, size);
}
var slice4d = op({slice4d_});
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Softmax of `logits` along one dimension.
 *
 * Only the last dimension is supported: `dim` may be -1 (the default) or the
 * explicit index of the last dimension.
 *
 * @param logits Tensor (or tensor-like), converted with dtype "float32".
 * @param dim Dimension to normalize over; defaults to -1.
 * @returns The tensor produced by the backend's softmax.
 * @throws Error when `dim` does not refer to the last dimension.
 */
function softmax_(logits, dim) {
  let axis = dim === undefined ? -1 : dim;
  const $logits = convertToTensor(logits, "logits", "softmax", "float32");
  const lastAxis = $logits.rank - 1;
  if (axis === -1) {
    axis = lastAxis;
  }
  if (axis !== lastAxis) {
    throw new Error(`Softmax along a non-last dimension is not yet supported. Logits was rank ${$logits.rank} and dim was ${axis}`);
  }
  const forward = (backend2, save) => {
    const probabilities = backend2.softmax($logits, axis);
    // The kernel output (not the input) is what gets saved.
    save([probabilities]);
    return probabilities;
  };
  return ENGINE.runKernelFunc(forward, {logits: $logits}, null, Softmax, {dim: axis});
}
var softmax = op({softmax_});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Complex-to-complex FFT over the innermost dimension of `input`.
 *
 * The input is flattened to 2-D ([batch, innerDimensionSize]) for the backend
 * kernel, and the result is reshaped back to the input's shape.
 *
 * @param input complex64 tensor.
 * @returns complex tensor with the same shape as `input`.
 */
function fft_(input) {
  assert(input.dtype === "complex64", function() {
    return "The dtype for tf.spectral.fft() must be complex64 " + ("but got " + input.dtype + ".");
  });
  var inputs = {input};
  return ENGINE.runKernelFunc(function(backend2) {
    var innerDimensionSize = input.shape[input.shape.length - 1];
    var batch = input.size / innerDimensionSize;
    // Use the functional `reshape` op, exactly as ifft_ does, instead of the
    // chained Tensor.as2D / Tensor.reshape methods.
    var input2D = reshape(input, [batch, innerDimensionSize]);
    var result = backend2.fft(input2D);
    return reshape(result, input.shape);
  }, inputs, null, FFT);
}
var fft = op({fft_});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Inverse complex FFT over the innermost dimension of `input`.
 *
 * The input is viewed as [rows, lastDim] for the backend kernel, and the
 * result is reshaped back to the input's shape.
 *
 * @param input complex64 tensor.
 * @returns complex tensor with the same shape as `input`.
 */
function ifft_(input) {
  const isComplex = input.dtype === "complex64";
  assert(isComplex, () => `The dtype for tf.spectral.ifft() must be complex64 but got ${input.dtype}.`);
  const forward = (backend2) => {
    const lastDim = input.shape[input.shape.length - 1];
    const rows = input.size / lastDim;
    const spectrum = backend2.ifft(reshape(input, [rows, lastDim]));
    return reshape(spectrum, input.shape);
  };
  return ENGINE.runKernelFunc(forward, {input}, null, IFFT);
}
var ifft = op({ifft_});
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Inverse real FFT over the innermost dimension of `input`.
 *
 * When the innermost size is greater than 2, the missing half of the spectrum
 * is rebuilt before calling `ifft`. That half is the reversed interior columns
 * [1, lastDim - 1), with the imaginary part negated. The real part of the
 * inverse transform is returned. Rank-3 inputs with a non-zero leading
 * dimension are reshaped back to three dimensions.
 *
 * @param input complex tensor (presumably the half spectrum produced by rfft).
 * @returns real-valued tensor.
 */
function irfft_(input) {
  const lastDim = input.shape[input.shape.length - 1];
  const rows = input.size / lastDim;
  let complexInput;
  if (lastDim > 2) {
    const fullWidth = 2 * (lastDim - 1);
    const re = reshape(real(input), [rows, lastDim]);
    const im = reshape(imag(input), [rows, lastDim]);
    // Mirror the interior columns to recover the conjugate-symmetric half.
    const reMirror = reverse(slice(re, [0, 1], [rows, lastDim - 2]), 1);
    const imMirror = mul(reverse(slice(im, [0, 1], [rows, lastDim - 2]), 1), scalar(-1));
    const fullRe = concat([re, reMirror], 1);
    const fullIm = concat([im, imMirror], 1);
    complexInput = reshape(complex(fullRe, fullIm), [rows, fullWidth]);
  } else {
    complexInput = reshape(input, [rows, lastDim]);
  }
  let result = real(ifft(complexInput));
  if (input.rank === 3 && input.shape[0] !== 0) {
    const flat = result;
    const outer = input.shape[0];
    result = reshape(flat, [outer, flat.shape[0] / outer, flat.shape[1]]);
    flat.dispose();
  }
  return result;
}
var irfft = op({irfft_});
/**
 * Resolves the per-piece sizes for splitting `x` along `axis`.
 *
 * @param x Object with a `shape` array (a tensor).
 * @param numOrSizeSplits Either a number of equal pieces, or an array of piece
 *     sizes in which at most one entry may be -1. A -1 entry receives whatever
 *     remains of the axis after the positive entries are subtracted.
 * @param axis Axis to split along; defaults to 0.
 * @returns An array of piece sizes that sums to `x.shape[axis]`. When an
 *     array is given, a new array is returned and the caller's array is left
 *     untouched.
 * @throws (via `assert`) if the number does not evenly divide the axis, if
 *     more than one -1 is present, or if the sizes do not sum to the axis
 *     length.
 */
function prepareSplitSize(x, numOrSizeSplits, axis) {
  if (axis === void 0) {
    axis = 0;
  }
  const axisSize = x.shape[axis];
  if (typeof numOrSizeSplits === "number") {
    assert(axisSize % numOrSizeSplits === 0, function() {
      return "Number of splits must evenly divide the axis.";
    });
    return new Array(numOrSizeSplits).fill(axisSize / numOrSizeSplits);
  }
  const numOfNegs = numOrSizeSplits.filter((value) => value === -1).length;
  assert(numOfNegs <= 1, function() {
    return "There should be only one negative value in split array.";
  });
  // Work on a copy so resolving -1 does not mutate the caller's array.
  const splitSizes = numOrSizeSplits.slice();
  const negIndex = splitSizes.indexOf(-1);
  if (negIndex !== -1) {
    // Start from 0. Without an initial value, reduce would seed the sum with
    // the first element, so a leading -1 would leak into the total.
    const total = splitSizes.reduce((acc, size) => (size > 0 ? acc + size : acc), 0);
    splitSizes[negIndex] = axisSize - total;
  }
  assert(axisSize === splitSizes.reduce((acc, size) => acc + size, 0), function() {
    return "The sum of sizes must match the size of the axis dimension.";
  });
  return splitSizes;
}
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Splits `x` into pieces along `axis` (the SplitV kernel).
 *
 * The axis is resolved with `parseAxisParam`, and the piece sizes with
 * `prepareSplitSize`, inside the forward function.
 *
 * @param x Tensor (or tensor-like).
 * @param numOrSizeSplits Number of equal pieces, or an array of sizes (one
 *     entry may be -1).
 * @param axis Axis to split along; defaults to 0.
 * @returns The backend's split result (an array of tensors).
 */
function split_(x, numOrSizeSplits, axis) {
  if (axis === undefined) {
    axis = 0;
  }
  const $x = convertToTensor(x, "x", "split");
  const inputs = {x: $x};
  const attr = {numOrSizeSplits, axis};
  const forward = (backend2) => {
    const [resolvedAxis] = parseAxisParam(axis, $x.shape);
    const sizes = prepareSplitSize($x, numOrSizeSplits, resolvedAxis);
    return backend2.split($x, sizes, resolvedAxis);
  };
  return ENGINE.runKernelFunc(forward, inputs, null, SplitV, attr);
}
var split = op({split_});
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Real-to-complex FFT over the innermost dimension of `input`.
 *
 * If `fftLength` is given, the innermost dimension is first cropped (when it
 * is shorter than the data) or zero-padded (when it is longer). Only the first
 * floor(width / 2) + 1 frequency bins are kept.
 *
 * @param input float32 tensor.
 * @param fftLength Optional FFT length for the innermost dimension.
 * @returns complex tensor whose innermost dimension is floor(width / 2) + 1.
 */
function rfft_(input, fftLength) {
  assert(input.dtype === "float32", () => `The dtype for rfft() must be real value but got ${input.dtype}`);
  const rank = input.shape.length;
  let width = input.shape[rank - 1];
  const batch = input.size / width;
  let adjustedInput = input;
  if (fftLength != null) {
    if (fftLength < width) {
      // Crop the innermost dimension down to fftLength.
      const begin = input.shape.map(() => 0);
      const size = [...input.shape];
      size[rank - 1] = fftLength;
      adjustedInput = slice(input, begin, size);
      width = fftLength;
    } else if (fftLength > width) {
      // Zero-pad the innermost dimension up to fftLength.
      const padShape = [...input.shape];
      padShape[rank - 1] = fftLength - width;
      adjustedInput = concat([input, zeros(padShape)], rank - 1);
      width = fftLength;
    }
  }
  const spectrum = fft(reshape(complex(adjustedInput, zerosLike(adjustedInput)), [batch, width]));
  const half = Math.floor(width / 2) + 1;
  const re = real(spectrum);
  const im = imag(spectrum);
  const [reHalf] = split(re, [half, width - half], re.shape.length - 1);
  const [imHalf] = split(im, [half, width - half], im.shape.length - 1);
  const outShape = adjustedInput.shape.slice();
  outShape[outShape.length - 1] = half;
  return reshape(complex(reHalf, imHalf), outShape);
}
var rfft = op({rfft_});
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Element-wise square root via the backend's `sqrt` kernel.
 *
 * @param x Tensor (or tensor-like).
 * @returns The backend result; the input tensor is saved.
 */
function sqrt_(x) {
  const $x = convertToTensor(x, "x", "sqrt");
  const forward = (backend2, save) => {
    const root = backend2.sqrt($x);
    save([$x]);
    return root;
  };
  return ENGINE.runKernelFunc(forward, {x: $x}, null, Sqrt);
}
var sqrt = op({sqrt_});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Element-wise squared difference of `a` and `b` via the backend kernel.
 *
 * Both operands are converted, brought to a common dtype with
 * `makeTypesMatch`, and checked with `assertAndGetBroadcastShape`.
 *
 * @param a Tensor (or tensor-like).
 * @param b Tensor (or tensor-like).
 * @returns The backend result; both (type-matched) inputs are saved.
 */
function squaredDifference_(a, b) {
  let $a = convertToTensor(a, "a", "squaredDifference");
  let $b = convertToTensor(b, "b", "squaredDifference");
  [$a, $b] = makeTypesMatch($a, $b);
  // Only the check matters here; the broadcast shape itself is unused
  // (presumably this throws on incompatible shapes).
  assertAndGetBroadcastShape($a.shape, $b.shape);
  const forward = (backend2, save) => {
    const diffSquared = backend2.squaredDifference($a, $b);
    save([$a, $b]);
    return diffSquared;
  };
  return ENGINE.runKernelFunc(forward, {a: $a, b: $b}, null, SquaredDifference, {});
}
var squaredDifference = op({squaredDifference_});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Removes size-1 dimensions from `x` by reshaping to the `newShape` computed
 * by `squeezeShape`.
 *
 * @param x Tensor (or tensor-like).
 * @param axis Optional axes to squeeze, forwarded to `squeezeShape`.
 * @returns The reshaped tensor.
 */
function squeeze_(x, axis) {
  const $x = convertToTensor(x, "x", "squeeze");
  const {newShape} = squeezeShape($x.shape, axis);
  return reshape($x, newShape);
}
var squeeze = op({squeeze_});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Stacks tensors of identical shape and dtype along a new axis.
 *
 * Each tensor gets a new dimension inserted with `expandDims`, and the results
 * are concatenated along that axis.
 *
 * @param tensors Non-empty list of tensors (or tensor-likes).
 * @param axis Position of the new axis; defaults to 0 and must be <= rank.
 * @returns The stacked tensor.
 */
function stack_(tensors, axis) {
  if (axis === undefined) {
    axis = 0;
  }
  const $tensors = convertToTensorArray(tensors, "tensors", "stack");
  assert($tensors.length >= 1, () => "Pass at least one tensor to tf.stack");
  if ($tensors.length === 1) {
    // Nothing to concatenate: just insert the new axis.
    return expandDims($tensors[0], axis);
  }
  const {rank, shape, dtype} = $tensors[0];
  assert(axis <= rank, () => "Axis must be <= rank of the tensor");
  for (const t of $tensors) {
    assertShapesMatch(shape, t.shape, "All tensors passed to stack must have matching shapes");
    assert(dtype === t.dtype, () => "All tensors passed to stack must have matching dtypes");
  }
  return concat($tensors.map((t) => expandDims(t, axis)), axis);
}
var stack = op({stack_});
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Element-wise step function via the backend's `step` kernel.
 *
 * @param x Tensor (or tensor-like).
 * @param alpha Scalar parameter passed to the kernel; defaults to 0.
 * @returns The backend result.
 */
function step_(x, alpha) {
  const alphaValue = alpha === undefined ? 0 : alpha;
  const $x = convertToTensor(x, "x", "step");
  const forward = (backend2) => backend2.step($x, alphaValue);
  return ENGINE.runKernelFunc(forward, {x: $x}, null, Step, {alpha: alphaValue});
}
var step = op({step_});
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Extracts a strided slice of `x` (the StridedSlice kernel).
 *
 * Mask arguments are bit masks decoded with `maskToAxes`. Coordinates are
 * normalized by `getNormalizedAxes`. If every normalized stride is 1, the
 * result comes from `slice`; otherwise it comes from the backend's
 * `stridedSlice` kernel. Either way it is reshaped to `outShape`.
 *
 * @param x Input tensor (or tensor-like).
 * @param begin Per-axis start coordinates.
 * @param end Per-axis end coordinates.
 * @param strides Per-axis strides. When null/undefined, the forward function
 *     substitutes `new Array(begin.length)`.
 * @param beginMask Bit mask, defaults to 0.
 * @param endMask Bit mask, defaults to 0.
 * @param ellipsisMask Bit mask, defaults to 0. At most one bit may be set, and
 *     it cannot be combined with newAxisMask or shrinkAxisMask.
 * @param newAxisMask Bit mask, defaults to 0. A size-1 dimension is inserted
 *     into the input shape at each selected axis.
 * @param shrinkAxisMask Bit mask, defaults to 0. Each selected axis is sliced
 *     to length 1 and then removed from the output shape.
 * @returns The sliced tensor.
 * @throws Error on multiple ellipses or an unsupported mask combination.
 *
 * NOTE(review): when newAxisMask is set, the forward function writes into the
 * caller's `begin` / `end` arrays in place.
 * NOTE(review): `attrs` is built before the forward function runs, so it
 * references the arrays passed in, not the normalized ones. `strides` may
 * still be null/undefined in `attrs`.
 */
function stridedSlice_(x, begin, end, strides, beginMask, endMask, ellipsisMask, newAxisMask, shrinkAxisMask) {
if (beginMask === void 0) {
beginMask = 0;
}
if (endMask === void 0) {
endMask = 0;
}
if (ellipsisMask === void 0) {
ellipsisMask = 0;
}
if (newAxisMask === void 0) {
newAxisMask = 0;
}
if (shrinkAxisMask === void 0) {
shrinkAxisMask = 0;
}
var $x = convertToTensor(x, "x", "stridedSlice");
var forward = function(backend2) {
// Default strides: an empty array of the right length (presumably filled
// in by getNormalizedAxes — TODO confirm).
if (strides == null) {
strides = new Array(begin.length);
}
var ellipsisAxes = maskToAxes(ellipsisMask);
if (ellipsisAxes.length > 1) {
throw new Error("Multiple ellipses in slice is not allowed.");
}
if (ellipsisMask !== 0 && newAxisMask !== 0) {
throw new Error("Using both ellipsisMask and newAxisMask is not yet supported.");
}
if (ellipsisMask !== 0 && shrinkAxisMask !== 0) {
throw new Error("Using both ellipsisMask and shrinkAxisMask is not yet supported.");
}
// Number of input axes not covered by `begin` (presumably the axes an
// ellipsis expands to).
var numInterpolatedAxes = $x.rank - begin.length;
var expandAxes = maskToAxes(newAxisMask);
var newShape = $x.shape.slice();
// For each new axis: select its single element and insert a size-1
// dimension into the shape.
expandAxes.forEach(function(axis) {
begin[axis] = 0;
end[axis] = 1;
newShape.splice(axis, 0, 1);
});
$x = reshape($x, newShape);
// Resolve masks/ellipsis into concrete begin, end and strides.
var _a = getNormalizedAxes($x.shape, ellipsisAxes, numInterpolatedAxes, begin, end, strides, beginMask, endMask, ellipsisMask), normalizedBegin = _a.begin, normalizedEnd = _a.end, normalizedStrides = _a.strides;
begin = normalizedBegin;
end = normalizedEnd;
strides = normalizedStrides;
// Shrunk axes take exactly one element with unit stride...
var shrinkAxes = maskToAxes(shrinkAxisMask);
shrinkAxes.forEach(function(axis) {
end[axis] = begin[axis] + 1;
strides[axis] = 1;
});
var size = computeOutShape(begin, end, strides);
// ...and are then dropped from the final shape.
var outShape = size.filter(function(_, axis) {
return shrinkAxes.indexOf(axis) === -1;
});
var nonStrided = strides.every(function(v) {
return v === 1;
});
// With unit strides a plain slice is sufficient.
if (nonStrided) {
return reshape(slice($x, begin, size), outShape);
}
var res = backend2.stridedSlice($x, begin, end, strides);
return reshape(res, outShape);
};
var inputs = {x: $x};
var attrs = {
begin,
end,
strides,
beginMask,
endMask,
ellipsisMask,
newAxisMask,
shrinkAxisMask
};
return ENGINE.runKernelFunc(forward, inputs, null, StridedSlice, attrs);
}
var stridedSlice = op({stridedSlice_});
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Element-wise tangent via the backend's `tan` kernel.
 *
 * @param x Tensor (or tensor-like).
 * @returns The backend result; the input tensor is saved.
 */
function tan_(x) {
  const $x = convertToTensor(x, "x", "tan");
  const forward = (backend2, save) => {
    const tangent = backend2.tan($x);
    save([$x]);
    return tangent;
  };
  return ENGINE.runKernelFunc(forward, {x: $x}, null, Tan);
}
var tan = op({tan_});
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Creates a rank-2 tensor.
 *
 * @param values Nested number[][] or a flat array / TypedArray.
 * @param shape Optional 2-element shape; required when `values` is flat.
 * @param dtype Optional dtype, forwarded to `inferShape` and `makeTensor`.
 * @returns The tensor built by `makeTensor`.
 * @throws Error on a wrong shape length, wrong value nesting, or a flat
 *     input without a shape.
 */
function tensor2d(values, shape, dtype) {
  assertNonNull(values);
  const hasShape = shape != null;
  if (hasShape && shape.length !== 2) {
    throw new Error("tensor2d() requires shape to have two numbers");
  }
  const inferredShape = inferShape(values, dtype);
  const inferredRank = inferredShape.length;
  if (inferredRank !== 2 && inferredRank !== 1) {
    throw new Error("tensor2d() requires values to be number[][] or flat/TypedArray");
  }
  if (inferredRank === 1 && !hasShape) {
    throw new Error("tensor2d() requires shape to be provided when `values` are a flat/TypedArray");
  }
  return makeTensor(values, shape, inferredShape, dtype);
}
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Creates a rank-4 tensor.
 *
 * @param values Nested number[][][][] or a flat array / TypedArray.
 * @param shape Optional 4-element shape; required when `values` is flat.
 * @param dtype Optional dtype, forwarded to `inferShape` and `makeTensor`.
 * @returns The tensor built by `makeTensor`.
 * @throws Error on a wrong shape length, wrong value nesting, or a flat
 *     input without a shape.
 */
function tensor4d(values, shape, dtype) {
  assertNonNull(values);
  const hasShape = shape != null;
  if (hasShape && shape.length !== 4) {
    throw new Error("tensor4d() requires shape to have four numbers");
  }
  const inferredShape = inferShape(values, dtype);
  const inferredRank = inferredShape.length;
  if (inferredRank !== 4 && inferredRank !== 1) {
    throw new Error("tensor4d() requires values to be number[][][][] or flat/TypedArray");
  }
  if (inferredRank === 1 && !hasShape) {
    throw new Error("tensor4d() requires shape to be provided when `values` are a flat array");
  }
  return makeTensor(values, shape, inferredShape, dtype);
}
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Creates a rank-5 tensor.
 *
 * @param values Nested number[][][][][] or a flat array / TypedArray.
 * @param shape Optional 5-element shape; required when `values` is flat.
 * @param dtype Optional dtype, forwarded to `inferShape` and `makeTensor`.
 * @returns The tensor built by `makeTensor`.
 * @throws Error on a wrong shape length, wrong value nesting, or a flat
 *     input without a shape.
 */
function tensor5d(values, shape, dtype) {
  assertNonNull(values);
  const hasShape = shape != null;
  if (hasShape && shape.length !== 5) {
    throw new Error("tensor5d() requires shape to have five numbers");
  }
  const inferredShape = inferShape(values, dtype);
  const inferredRank = inferredShape.length;
  if (inferredRank !== 5 && inferredRank !== 1) {
    throw new Error("tensor5d() requires values to be number[][][][][] or flat/TypedArray");
  }
  if (inferredRank === 1 && !hasShape) {
    throw new Error("tensor5d() requires shape to be provided when `values` are a flat array");
  }
  return makeTensor(values, shape, inferredShape, dtype);
}
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Creates a rank-6 tensor.
 *
 * Unlike the lower-rank helpers, this one passes `inferredShape` as the shape
 * when no shape is given.
 *
 * @param values Nested number[][][][][][] or a flat array / TypedArray.
 * @param shape Optional 6-element shape; required when `values` is flat.
 * @param dtype Optional dtype, forwarded to `inferShape` and `makeTensor`.
 * @returns The tensor built by `makeTensor`.
 * @throws Error on a wrong shape length, wrong value nesting, or a flat
 *     input without a shape.
 */
function tensor6d(values, shape, dtype) {
  assertNonNull(values);
  const hasShape = shape != null;
  if (hasShape && shape.length !== 6) {
    throw new Error("tensor6d() requires shape to have six numbers");
  }
  const inferredShape = inferShape(values, dtype);
  const inferredRank = inferredShape.length;
  if (inferredRank !== 6 && inferredRank !== 1) {
    throw new Error("tensor6d() requires values to be number[][][][][][] or flat/TypedArray");
  }
  if (inferredRank === 1 && !hasShape) {
    throw new Error("tensor6d() requires shape to be provided when `values` are a flat array");
  }
  const finalShape = shape || inferredShape;
  return makeTensor(values, finalShape, inferredShape, dtype);
}
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Top-k entries along the last dimension via the backend's `topk` kernel.
 *
 * @param x Tensor (or tensor-like) of rank >= 1.
 * @param k Number of entries to keep; defaults to 1 and must be <= the last
 *     dimension.
 * @param sorted Forwarded to the kernel; defaults to true.
 * @returns `{values, indices}`, the two tensors returned by the kernel.
 * @throws Error for scalar input or when `k` exceeds the last dimension.
 */
function topk_(x, k, sorted) {
  const count = k === undefined ? 1 : k;
  const isSorted = sorted === undefined ? true : sorted;
  const $x = convertToTensor(x, "x", "topk");
  if ($x.rank === 0) {
    throw new Error("topk() expects the input to be of rank 1 or higher");
  }
  const lastDim = $x.shape[$x.shape.length - 1];
  if (count > lastDim) {
    throw new Error(`'k' passed to topk() must be <= the last dimension (${lastDim}) but got ${count}`);
  }
  const forward = (b) => b.topk($x, count, isSorted);
  const [values, indices] = ENGINE.runKernelFunc(forward, {x: $x}, null, TopK, {k: count, sorted: isSorted});
  return {values, indices};
}
var topk = op({topk_});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Creates a tensor of samples drawn by `MPRandGauss`. The generator is
 * constructed with its fourth argument set to `true` (presumably the
 * truncation flag — TODO confirm).
 *
 * @param shape Output shape, forwarded to `buffer`.
 * @param mean2 Mean; defaults to 0.
 * @param stdDev Standard deviation; defaults to 1.
 * @param dtype Output dtype; "bool" is rejected.
 * @param seed Optional seed, forwarded to `MPRandGauss`.
 * @returns The filled buffer converted with `toTensor()`.
 * @throws Error when dtype is "bool".
 */
function truncatedNormal_(shape, mean2, stdDev, dtype, seed) {
  if (mean2 === void 0) {
    mean2 = 0;
  }
  if (stdDev === void 0) {
    stdDev = 1;
  }
  if (dtype === "bool") {
    // Concatenate explicitly so the offending dtype appears in the message.
    throw new Error("Unsupported data type " + dtype);
  }
  var randGauss = new MPRandGauss(mean2, stdDev, dtype, true, seed);
  var res = buffer(shape, dtype);
  for (var i = 0; i < res.values.length; i++) {
    res.values[i] = randGauss.nextValue();
  }
  return res.toTensor();
}
var truncatedNormal = op({truncatedNormal_});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Unique elements of `x` along `axis`, via the Unique kernel.
 *
 * @param x Tensor (or tensor-like) of rank >= 1.
 * @param axis Axis to operate on; defaults to 0.
 * @returns `{values, indices}`, the two tensors returned by the kernel.
 */
function unique_(x, axis) {
  const resolvedAxis = axis === undefined ? 0 : axis;
  const $x = convertToTensor(x, "x", "unique", null);
  assert($x.rank > 0, () => "The input tensor must be at least 1D");
  const [values, indices] = ENGINE.runKernel(Unique, {x: $x}, {axis: resolvedAxis});
  return {values, indices};
}
var unique = op({unique_});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Sums slices of `x` grouped by `segmentIds`, via the backend's
 * `unsortedSegmentSum` kernel.
 *
 * @param x Tensor (or tensor-like).
 * @param segmentIds Tensor (or tensor-like), converted with dtype "int32".
 * @param numSegments Integer number of segments.
 * @returns The backend result; only `segmentIds` is saved.
 */
function unsortedSegmentSum_(x, segmentIds, numSegments) {
  const $x = convertToTensor(x, "x", "unsortedSegmentSum");
  const $segmentIds = convertToTensor(segmentIds, "segmentIds", "unsortedSegmentSum", "int32");
  assert(isInt(numSegments), () => "numSegments must be of dtype int");
  const forward = (backend2, save) => {
    const summed = backend2.unsortedSegmentSum($x, $segmentIds, numSegments);
    save([$segmentIds]);
    return summed;
  };
  return ENGINE.runKernelFunc(forward, {x: $x, segmentIds: $segmentIds}, null, UnsortedSegmentSum, {numSegments});
}
var unsortedSegmentSum = op({unsortedSegmentSum_});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Unstacks `x` along `axis` via the backend's `unstack` kernel (Unpack).
 *
 * @param x Tensor (or tensor-like).
 * @param axis Axis in [-rank, rank); defaults to 0. Negative values are
 *     converted to their positive equivalent.
 * @returns The backend's unstack result.
 */
function unstack_(x, axis) {
  let dim = axis === undefined ? 0 : axis;
  const $x = convertToTensor(x, "x", "unstack");
  const rank = $x.shape.length;
  assert(dim >= -rank && dim < rank, () => `Axis = ${dim} is not in [-${rank}, ${rank})`);
  if (dim < 0) {
    dim += rank;
  }
  const forward = (backend2) => backend2.unstack($x, dim);
  return ENGINE.runKernelFunc(forward, {value: $x}, null, Unpack, {axis: dim});
}
var unstack = op({unstack_});
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Creates a variable through `ENGINE.makeVariable`.
 *
 * @param initialValue Initial value, forwarded unchanged.
 * @param trainable Defaults to true.
 * @param name Optional variable name.
 * @param dtype Optional dtype.
 * @returns The variable created by the engine.
 */
function variable(initialValue, trainable, name, dtype) {
  const isTrainable = trainable === undefined ? true : trainable;
  return ENGINE.makeVariable(initialValue, isTrainable, name, dtype);
}
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Lists the locations of every truthy entry of a flattened condition. Used by
 * whereAsync_.
 *
 * @param condShape Shape of the condition tensor.
 * @param condVals Flat condition values, indexed 0..length-1.
 * @returns An int32 tensor of shape [numTruthy, condShape.length]. Row r holds
 *     the location that `indexToLoc` returns for the r-th truthy flat index.
 */
function whereImpl(condShape, condVals) {
  const rank = condShape.length;
  const hits = [];
  for (let flat = 0; flat < condVals.length; flat += 1) {
    if (condVals[flat]) {
      hits.push(flat);
    }
  }
  const inBuffer = buffer(condShape, "int32");
  const out = buffer([hits.length, rank], "int32");
  hits.forEach((flatIndex, row) => {
    out.values.set(inBuffer.indexToLoc(flatIndex), row * rank);
  });
  return out.toTensor();
}
/**
 * Asynchronous where: resolves to the locations of the true entries of
 * `condition`, as computed by whereImpl.
 *
 * This is a compiled async state machine (`__awaiter` / `__generator`).
 * Label 0 converts `condition` to a bool tensor and awaits its data. Label 1
 * runs whereImpl on that data and disposes the converted tensor if
 * convertToTensor created a new one.
 *
 * @param condition Tensor (or tensor-like), converted with dtype "bool".
 * @returns Promise of an int32 tensor of shape [numTrue, condition rank].
 *
 * NOTE(review): if `$condition.data()` rejects, a tensor created by
 * convertToTensor is never disposed.
 */
function whereAsync_(condition) {
return __awaiter(this, void 0, void 0, function() {
var $condition, vals, res;
return __generator(this, function(_a) {
switch (_a.label) {
case 0:
$condition = convertToTensor(condition, "condition", "whereAsync", "bool");
// Presumably the tslib `__generator` "await" opcode; resumes at label 1.
return [4, $condition.data()];
case 1:
vals = _a.sent();
res = whereImpl($condition.shape, vals);
// Only dispose a tensor that this function created.
if (condition !== $condition) {
$condition.dispose();
}
// Presumably the `__generator` "return" opcode.
return [2, res];
}
});
});
}
// Not wrapped in op(), unlike the synchronous ops (presumably because op
// scopes cannot span an await — TODO confirm).
var whereAsync = whereAsync_;
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Asynchronously selects the slices of `tensor2` where `mask` is truthy.
 * The mask covers `mask.rank` consecutive dimensions of `tensor2` starting at
 * `axis` (default 0). Those dimensions are collapsed into one and the selected
 * entries are gathered along it, so the result has rank
 * `tensor2.rank - mask.rank + 1`.
 * @param tensor2 - tensor (or tensor-like) to mask.
 * @param mask - bool mask; must not be a scalar, and its shape must equal the
 *     masked dimensions of `tensor2`.
 * @param axis - first dimension of `tensor2` the mask applies to.
 * @returns Promise resolving to the gathered tensor.
 */
function booleanMaskAsync_(tensor2, mask, axis) {
return __awaiter(this, void 0, void 0, function() {
var $tensor, $mask, axisFrom, maskDim, tensorShape, leadingSize, i, targetTensorShape, reshapedTensor, reshapedMask, positivePositions, indices, res;
// tslib state machine: [4, p] awaits p, [2, v] returns v.
return __generator(this, function(_a) {
switch (_a.label) {
case 0:
$tensor = convertToTensor(tensor2, "tensor", "boolMask");
$mask = convertToTensor(mask, "mask", "boolMask", "bool");
axisFrom = axis == null ? 0 : axis;
maskDim = $mask.rank;
tensorShape = $tensor.shape;
assert(maskDim > 0, function() {
return "mask cannot be scalar";
});
assertShapesMatch(tensorShape.slice(axisFrom, axisFrom + maskDim), $mask.shape, "mask's shape must match the first K dimensions of tensor's shape,");
// Product of the masked dimensions; they are collapsed into a single axis.
leadingSize = 1;
for (i = axisFrom; i < axisFrom + maskDim; i++) {
leadingSize *= tensorShape[i];
}
targetTensorShape = tensorShape.slice(0, axisFrom).concat([leadingSize], tensorShape.slice(axisFrom + maskDim));
reshapedTensor = reshape($tensor, targetTensorShape);
reshapedMask = reshape($mask, [-1]);
// The flattened mask is rank 1, so the positions come back as [numTrue, 1].
return [4, whereAsync(reshapedMask)];
case 1:
positivePositions = _a.sent();
// Drop the size-1 column to get a flat index vector.
indices = squeeze(positivePositions, [1]);
res = gather(reshapedTensor, indices, axisFrom);
// Release converted inputs (only if created here) and all intermediates.
if (tensor2 !== $tensor) {
$tensor.dispose();
}
if (mask !== $mask) {
$mask.dispose();
}
indices.dispose();
reshapedTensor.dispose();
reshapedMask.dispose();
positivePositions.dispose();
return [2, res];
}
});
});
}
var booleanMaskAsync = booleanMaskAsync_;
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/** Deprecated: element-wise `notEqual` that first requires identical shapes of `a` and `b`. */
function notEqualStrict_(a, b) {
  deprecationWarn("strict variants of ops have been deprecated and will be removed in future");
  var lhs = convertToTensor(a, "a", "notEqualStrict");
  var rhs = convertToTensor(b, "b", "notEqualStrict");
  assertShapesMatch(lhs.shape, rhs.shape, "Error in notEqualStrict: ");
  return notEqual(lhs, rhs);
}
/** Deprecated: element-wise `less` that first requires identical shapes of `a` and `b`. */
function lessStrict_(a, b) {
  deprecationWarn("strict variants of ops have been deprecated and will be removed in future");
  var lhs = convertToTensor(a, "a", "lessStrict");
  var rhs = convertToTensor(b, "b", "lessStrict");
  assertShapesMatch(lhs.shape, rhs.shape, "Error in lessStrict: ");
  return less(lhs, rhs);
}
/** Deprecated: element-wise `equal` that first requires identical shapes of `a` and `b`. */
function equalStrict_(a, b) {
  deprecationWarn("strict variants of ops have been deprecated and will be removed in future");
  var lhs = convertToTensor(a, "a", "equalStrict");
  var rhs = convertToTensor(b, "b", "equalStrict");
  assertShapesMatch(lhs.shape, rhs.shape, "Error in equalStrict: ");
  return equal(lhs, rhs);
}
/** Deprecated: element-wise `lessEqual` that first requires identical shapes of `a` and `b`. */
function lessEqualStrict_(a, b) {
  deprecationWarn("strict variants of ops have been deprecated and will be removed in future");
  var lhs = convertToTensor(a, "a", "lessEqualStrict");
  var rhs = convertToTensor(b, "b", "lessEqualStrict");
  assertShapesMatch(lhs.shape, rhs.shape, "Error in lessEqualStrict: ");
  return lessEqual(lhs, rhs);
}
/** Deprecated: element-wise `greater` that first requires identical shapes of `a` and `b`. */
function greaterStrict_(a, b) {
  deprecationWarn("strict variants of ops have been deprecated and will be removed in future");
  var lhs = convertToTensor(a, "a", "greaterStrict");
  var rhs = convertToTensor(b, "b", "greaterStrict");
  assertShapesMatch(lhs.shape, rhs.shape, "Error in greaterStrict: ");
  return greater(lhs, rhs);
}
/** Deprecated: element-wise `greaterEqual` that first requires identical shapes of `a` and `b`. */
function greaterEqualStrict_(a, b) {
  deprecationWarn("strict variants of ops have been deprecated and will be removed in future");
  var lhs = convertToTensor(a, "a", "greaterEqualStrict");
  var rhs = convertToTensor(b, "b", "greaterEqualStrict");
  assertShapesMatch(lhs.shape, rhs.shape, "Error in greaterEqualStrict: ");
  return greaterEqual(lhs, rhs);
}
// op()-wrapped public entry points for the shape-strict comparison ops, in
// definition order. Each wrapped function emits a deprecation warning.
var notEqualStrict = op({notEqualStrict_});
var lessStrict = op({lessStrict_});
var equalStrict = op({equalStrict_});
var lessEqualStrict = op({lessEqualStrict_});
var greaterStrict = op({greaterStrict_});
var greaterEqualStrict = op({greaterEqualStrict_});
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/** Deprecated: element-wise add that first requires identical shapes of `a` and `b`. */
function addStrict_(a, b) {
  deprecationWarn("strict variants of ops have been deprecated and will be removed in future");
  var left = convertToTensor(a, "a", "addStrict");
  var right = convertToTensor(b, "b", "addStrict");
  assertShapesMatch(left.shape, right.shape, "Error in addStrict: ");
  return add$1(left, right);
}
/** Deprecated: element-wise subtraction that first requires identical shapes of `a` and `b`. */
function subStrict_(a, b) {
  deprecationWarn("strict variants of ops have been deprecated and will be removed in future");
  var left = convertToTensor(a, "a", "subStrict");
  var right = convertToTensor(b, "b", "subStrict");
  assertShapesMatch(left.shape, right.shape, "Error in subStrict: ");
  return sub(left, right);
}
/**
 * Deprecated: element-wise power that first requires identical shapes of
 * `base` and `exp2`.
 * Both inputs are converted with convertToTensor like every other strict op.
 * Previously, tensor-likes (arrays, numbers) reached assertShapesMatch with
 * undefined shapes. Tensor inputs are passed through unchanged.
 * @param base - tensor or tensor-like base.
 * @param exp2 - tensor or tensor-like exponent.
 * @returns pow(base, exp2).
 */
function powStrict_(base, exp2) {
  deprecationWarn("strict variants of ops have been deprecated and will be removed in future");
  var $base = convertToTensor(base, "base", "powStrict");
  var $exp = convertToTensor(exp2, "exp", "powStrict");
  assertShapesMatch($base.shape, $exp.shape, "Error in powStrict: ");
  return pow($base, $exp);
}
/**
 * Deprecated: element-wise multiply that first requires identical shapes of
 * `a` and `b`.
 * Inputs are converted under the op name "mulStrict" (previously "mul"), so
 * conversion errors name the op the caller actually used, as in the sibling
 * strict ops.
 */
function mulStrict_(a, b) {
  deprecationWarn("strict variants of ops have been deprecated and will be removed in future");
  var $a = convertToTensor(a, "a", "mulStrict");
  var $b = convertToTensor(b, "b", "mulStrict");
  assertShapesMatch($a.shape, $b.shape, "Error in multiplyStrict: ");
  return mul($a, $b);
}
/**
 * Deprecated: element-wise division that first requires identical shapes of
 * `a` and `b`.
 * Inputs are converted under the op name "divStrict" (previously "div"), so
 * conversion errors name the op the caller actually used, as in the sibling
 * strict ops.
 */
function divStrict_(a, b) {
  deprecationWarn("strict variants of ops have been deprecated and will be removed in future");
  var $a = convertToTensor(a, "a", "divStrict");
  var $b = convertToTensor(b, "b", "divStrict");
  assertShapesMatch($a.shape, $b.shape, "Error in divideStrict: ");
  return div($a, $b);
}
/** Deprecated: element-wise modulus that first requires identical shapes of `a` and `b`. */
function modStrict_(a, b) {
  deprecationWarn("strict variants of ops have been deprecated and will be removed in future");
  var left = convertToTensor(a, "a", "modStrict");
  var right = convertToTensor(b, "b", "modStrict");
  assertShapesMatch(left.shape, right.shape, "Error in modStrict: ");
  return mod(left, right);
}
/** Deprecated: element-wise minimum that first requires identical shapes of `a` and `b`. */
function minimumStrict_(a, b) {
  deprecationWarn("strict variants of ops have been deprecated and will be removed in future");
  var left = convertToTensor(a, "a", "minimumStrict");
  var right = convertToTensor(b, "b", "minimumStrict");
  assertShapesMatch(left.shape, right.shape, "Error in minimumStrict: ");
  return minimum(left, right);
}
/** Deprecated: element-wise maximum that first requires identical shapes of `a` and `b`. */
function maximumStrict_(a, b) {
  deprecationWarn("strict variants of ops have been deprecated and will be removed in future");
  var left = convertToTensor(a, "a", "maximumStrict");
  var right = convertToTensor(b, "b", "maximumStrict");
  assertShapesMatch(left.shape, right.shape, "Error in maximumStrict: ");
  return maximum(left, right);
}
/** Deprecated: element-wise (a - b)^2 that first requires identical shapes of `a` and `b`. */
function squaredDifferenceStrict_(a, b) {
  deprecationWarn("strict variants of ops have been deprecated and will be removed in future");
  var left = convertToTensor(a, "a", "squaredDifferenceStrict");
  var right = convertToTensor(b, "b", "squaredDifferenceStrict");
  assertShapesMatch(left.shape, right.shape, "Error in squaredDifferenceStrict: ");
  return squaredDifference(left, right);
}
// op()-wrapped public entry points for the shape-strict arithmetic ops, in
// definition order. Each wrapped function emits a deprecation warning.
var addStrict = op({addStrict_});
var subStrict = op({subStrict_});
var powStrict = op({powStrict_});
var mulStrict = op({mulStrict_});
var divStrict = op({divStrict_});
var modStrict = op({modStrict_});
var minimumStrict = op({minimumStrict_});
var maximumStrict = op({maximumStrict_});
var squaredDifferenceStrict = op({squaredDifferenceStrict_});
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Computes a vector or matrix norm of `x` (delegates to `normImpl`).
 * @param x - tensor or tensor-like input.
 * @param ord - norm order, default "euclidean". Vector norms accept
 *     1, 2, Infinity, -Infinity, "euclidean". Matrix norms accept
 *     1, Infinity, -Infinity, "fro", "euclidean".
 * @param axis - null (all elements, default), one axis, or two axes for a
 *     matrix norm.
 * @param keepDims - when truthy, reduced axes are kept with size 1.
 * @returns the norm tensor.
 */
function norm_(x, ord, axis, keepDims) {
  var order = ord === void 0 ? "euclidean" : ord;
  var reduceAxis = axis === void 0 ? null : axis;
  var keep = keepDims === void 0 ? false : keepDims;
  var $x = convertToTensor(x, "x", "norm");
  var result = normImpl($x, order, reduceAxis);
  var outShape = result.shape;
  if (keep) {
    outShape = expandShapeToKeepDim(result.shape, parseAxisParam(reduceAxis, $x.shape));
  }
  return reshape(result, outShape);
}
/**
 * Recursive worker for `norm_`.
 * - rank-0 `x`: |x|.
 * - `axis === null` on rank > 1: flattens `x` and computes a vector norm.
 * - Vector norm (rank-1 `x`, numeric axis, or one-element axis array):
 *   p = 1 (sum |x|), Infinity (max |x|), -Infinity (min |x|),
 *   2 / "euclidean" (sqrt of the sum of squares).
 * - Matrix norm (axis = [rowAxis, colAxis], negative values allowed):
 *   p = 1 (max column sum of |x|), Infinity (max row sum),
 *   -Infinity (min row sum), "fro" / "euclidean" (Frobenius).
 * @throws Error for an unsupported `p`, an unsupported axis form, or two
 *     matrix axes that refer to the same dimension.
 */
function normImpl(x, p, axis) {
  if (axis === void 0) {
    axis = null;
  }
  if (x.rank === 0) {
    return abs(x);
  }
  if (x.rank !== 1 && axis === null) {
    return normImpl(reshape(x, [-1]), p, axis);
  }
  if (x.rank === 1 || typeof axis === "number" || Array.isArray(axis) && axis.length === 1) {
    if (p === 1) {
      return sum$1(abs(x), axis);
    }
    if (p === Infinity) {
      return max(abs(x), axis);
    }
    if (p === -Infinity) {
      return min(abs(x), axis);
    }
    if (p === "euclidean" || p === 2) {
      return sqrt(sum$1(pow(abs(x), scalar(2, "int32")), axis));
    }
    throw new Error("Error in norm: invalid ord value: " + p);
  }
  if (Array.isArray(axis) && axis.length === 2) {
    // The index arithmetic below needs non-negative axes. parseAxisParam is
    // expected to map negative axes to their positive positions (as it is
    // used for keepDims in norm_).
    var axes = parseAxisParam(axis, x.shape);
    var rowAxis = axes[0];
    var colAxis = axes[1];
    if (rowAxis === colAxis) {
      throw new Error("Error in norm: invalid axis: " + axis);
    }
    // Reducing one axis shifts every later axis down by one. The old code
    // hard-coded `axis[1] - 1` / `axis[0]`, which was wrong for negative axes
    // and for axis[0] > axis[1].
    var colAfterRowSum = colAxis > rowAxis ? colAxis - 1 : colAxis;
    var rowAfterColSum = rowAxis > colAxis ? rowAxis - 1 : rowAxis;
    if (p === 1) {
      return max(sum$1(abs(x), rowAxis), colAfterRowSum);
    }
    if (p === Infinity) {
      return max(sum$1(abs(x), colAxis), rowAfterColSum);
    }
    if (p === -Infinity) {
      return min(sum$1(abs(x), colAxis), rowAfterColSum);
    }
    if (p === "fro" || p === "euclidean") {
      return sqrt(sum$1(square(x), axis));
    }
    throw new Error("Error in norm: invalid ord value: " + p);
  }
  throw new Error("Error in norm: invalid axis: " + axis);
}
var norm = op({norm_});
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Exponential moving-average update: v + (x - v) * (1 - decay), optionally
 * divided by (1 - decay^step) for zero-debiasing.
 * @param v - current average; must match `x` in dtype and shape.
 * @param x - new value.
 * @param decay - decay factor.
 * @param step2 - step count; required when zero-debiasing.
 * @param zeroDebias - defaults to true when omitted.
 * @returns the updated average.
 */
function movingAverage_(v, x, decay, step2, zeroDebias) {
  var useDebias = zeroDebias === void 0 ? true : zeroDebias;
  var $v = convertToTensor(v, "v", "movingAverage");
  var $x = convertToTensor(x, "x", "movingAverage");
  var $decay = convertToTensor(decay, "decay", "movingAverage");
  assertTypesMatch($v, $x);
  assert(arraysEqual($v.shape, $x.shape), function() {
    return "Shape mismatch in v and x";
  });
  var one = scalar(1);
  var oneMinusDecay = sub(one, $decay);
  var delta = mul(sub($x, $v), oneMinusDecay);
  if (!useDebias) {
    return add$1($v, delta);
  }
  assert(step2 != null, function() {
    return "When using zeroDebias: true, step is required.";
  });
  var $step = convertToTensor(step2, "step", "movingAverage");
  var debiased = div(delta, sub(one, pow($decay, $step)));
  return add$1($v, debiased);
}
var movingAverage = op({movingAverage_});
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Scatters `updates` into a zero tensor of `shape` at `indices` via the
 * ScatterNd kernel (after `validateInput` checks).
 * @param indices - int32 indices tensor or tensor-like.
 * @param updates - values to scatter.
 * @param shape - output shape.
 */
function scatterND_(indices, updates, shape) {
  var $indices = convertToTensor(indices, "indices", "scatterND", "int32");
  var $updates = convertToTensor(updates, "updates", "scatterND");
  validateInput($updates, $indices, shape);
  return ENGINE.runKernelFunc(
    function(backend2) {
      return backend2.scatterND($indices, $updates, shape);
    },
    {indices: $indices, updates: $updates},
    null,
    ScatterNd,
    {shape: shape}
  );
}
var scatterND = op({scatterND_});
/**
 * Validates the arguments of `sparseToDense`, throwing an Error describing
 * the first problem found.
 * @param sparseIndices - int32 tensor of rank <= 2.
 * @param sparseValues - scalar, or vector with one value per index row.
 * @param outputShape - dense shape; its length must equal the index width
 *     (shape[1] for rank-2 indices, otherwise 1).
 * @param defaultValues - tensor whose dtype must equal `sparseValues.dtype`.
 */
function validateInput$1(sparseIndices, sparseValues, outputShape, defaultValues) {
  if (sparseIndices.dtype !== "int32") {
    throw new Error("tf.sparseToDense() expects the indices to be int32 type," + (" but the dtype was " + sparseIndices.dtype + "."));
  }
  if (sparseIndices.rank > 2) {
    throw new Error("sparseIndices should be a scalar, vector, or matrix," + (" but got shape " + sparseIndices.shape + "."));
  }
  var numElems = sparseIndices.rank > 0 ? sparseIndices.shape[0] : 1;
  var numDims = sparseIndices.rank > 1 ? sparseIndices.shape[1] : 1;
  if (outputShape.length !== numDims) {
    // Message previously read "elements:, N" (stray comma).
    throw new Error("outputShape has incorrect number of elements: " + outputShape.length + ", should be: " + numDims + ".");
  }
  var numValues = sparseValues.size;
  if (!(sparseValues.rank === 0 || sparseValues.rank === 1 && numValues === numElems)) {
    throw new Error("sparseValues has incorrect shape " + (sparseValues.shape + ", should be [] or [" + numElems + "]"));
  }
  if (sparseValues.dtype !== defaultValues.dtype) {
    throw new Error("sparseValues.dtype must match defaultValues.dtype");
  }
}
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Builds a dense tensor of `outputShape` from sparse indices/values, filling
 * the remaining entries with `defaultValue` (0 when omitted).
 * `defaultValue` is converted with the dtype of `sparseValues`.
 */
function sparseToDense_(sparseIndices, sparseValues, outputShape, defaultValue) {
  var fill = defaultValue === void 0 ? 0 : defaultValue;
  var $sparseIndices = convertToTensor(sparseIndices, "sparseIndices", "sparseToDense", "int32");
  var $sparseValues = convertToTensor(sparseValues, "sparseValues", "sparseToDense");
  var $defaultValue = convertToTensor(fill, "defaultValue", "sparseToDense", $sparseValues.dtype);
  validateInput$1($sparseIndices, $sparseValues, outputShape, $defaultValue);
  var run = function(backend2) {
    return backend2.sparseToDense($sparseIndices, $sparseValues, outputShape, $defaultValue);
  };
  return ENGINE.runKernelFunc(
    run,
    {sparseIndices: $sparseIndices, sparseValues: $sparseValues, defaultValue: $defaultValue},
    null,
    SparseToDense,
    {outputShape: outputShape}
  );
}
var sparseToDense = op({sparseToDense_});
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Gathers slices of `x` addressed by the int32 `indices` via the GatherNd
 * kernel.
 */
function gatherND_(x, indices) {
  var $indices = convertToTensor(indices, "indices", "gatherND", "int32");
  var $x = convertToTensor(x, "x", "gatherND");
  return ENGINE.runKernelFunc(
    function(backend2) {
      return backend2.gatherND($x, $indices);
    },
    {params: $x, indices: $indices},
    null,
    GatherNd
  );
}
var gatherND = op({gatherND_});
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Resolves the noise shape used by dropout.
 * @param x - tensor whose `shape` is the reference.
 * @param noiseShape - requested noise shape; null/undefined means "x's shape".
 * @returns one of the following:
 *     - a copy of `x.shape` when `noiseShape` is null/undefined;
 *     - `noiseShape` itself when it equals `x.shape` or has a different length;
 *     - otherwise a new array with null/undefined entries filled from `x.shape`.
 */
function getNoiseShape(x, noiseShape) {
  var xShape = x.shape;
  if (noiseShape == null) {
    return xShape.slice();
  }
  if (arraysEqual(xShape, noiseShape) || xShape.length !== noiseShape.length) {
    return noiseShape;
  }
  return xShape.map(function(dim, i) {
    var requested = noiseShape[i];
    return requested == null && dim != null ? dim : requested;
  });
}
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Randomly zeroes elements of a float32 `x` with probability `rate` and
 * scales the survivors by 1 / (1 - rate).
 * @param x - float32 tensor or tensor-like.
 * @param rate - drop probability in [0, 1).
 * @param noiseShape - optional shape of the random mask (see getNoiseShape).
 * @param seed - optional seed for randomUniform.
 * @returns the dropped-out tensor. When rate is 0, a clone of the input
 *     Tensor or the converted tensor is returned.
 */
function dropout_(x, rate, noiseShape, seed) {
  var $x = convertToTensor(x, "x", "dropout");
  assert($x.dtype === "float32", function() {
    return "x has to be a floating point tensor since it's going to be " + ("scaled, but got a " + $x.dtype + " tensor instead.");
  });
  assert(rate >= 0 && rate < 1, function() {
    return "rate must be a float in the range [0, 1), but got " + rate + ".";
  });
  if (rate === 0) {
    if (x instanceof Tensor) {
      return $x.clone();
    }
    return $x;
  }
  var maskShape = getNoiseShape($x, noiseShape);
  var keepProb = 1 - rate;
  // floor(U[0,1) + keepProb) is 1 with probability keepProb, else 0.
  var keepMask = floor(add$1(randomUniform(maskShape, 0, 1, "float32", seed), keepProb));
  return mul($x, div(keepMask, keepProb));
}
var dropout = op({dropout_});
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Smallest power of two that is >= `value` (for value >= 1).
 * Uses Math.log2, which is exact for powers of two. The previous
 * Math.log(v) / Math.log(2) could overshoot, e.g. 29.000000000000004 for
 * 2^29, and then return twice the correct size.
 * @param {number} value
 * @returns {number}
 */
function enclosingPowerOfTwo(value) {
  return Math.floor(Math.pow(2, Math.ceil(Math.log2(value))));
}
/**
 * Generalized cosine window: w[i] = a - b * cos(2*pi*i / (N + even - 1)).
 * `even` is 1 for even N, which gives denominator N. It is 0 for odd N,
 * which gives denominator N - 1.
 * A length-1 window returns [1], matching TensorFlow/NumPy window functions.
 * The formula would otherwise compute 0/0 and yield NaN.
 * @param windowLength - number of samples N.
 * @param a - constant term.
 * @param b - cosine coefficient.
 * @returns float32 1-D tensor of length N.
 */
function cosineWindow(windowLength, a, b) {
  if (windowLength === 1) {
    return tensor1d(new Float32Array([1]), "float32");
  }
  var even = 1 - windowLength % 2;
  var newValues = new Float32Array(windowLength);
  for (var i = 0; i < windowLength; ++i) {
    var cosArg = 2 * Math.PI * i / (windowLength + even - 1);
    newValues[i] = a - b * Math.cos(cosArg);
  }
  return tensor1d(newValues, "float32");
}
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Asynchronously checks, for each row of `predictions`, whether that row's
 * target index (from `targets`) is among its `k` largest values.
 * @param predictions - tensor (or tensor-like) of rank >= 2; the last
 *     dimension holds one row of scores.
 * @param targets - target indices; shape must equal `predictions.shape`
 *     without its last dimension.
 * @param k - number of top entries to consider; defaults to 1; must satisfy
 *     0 < k <= last dimension of `predictions`.
 * @returns Promise resolving to a bool tensor with `targets`' shape.
 */
function inTopKAsync_(predictions, targets, k) {
if (k === void 0) {
k = 1;
}
return __awaiter(this, void 0, void 0, function() {
var $predictions, $targets, lastDim, predictionsVals, targetsVals, _a, batch, size, precision, b, offset, vals, valAndInd, i, i;
// tslib state machine: [4, p] awaits p, [2, v] returns v.
return __generator(this, function(_b) {
switch (_b.label) {
case 0:
$predictions = convertToTensor(predictions, "predictions", "inTopK");
$targets = convertToTensor(targets, "targets", "inTopK");
assert($predictions.rank > 1, function() {
return "inTopK() expects the predictions to be of rank 2 or higher, " + ("but got " + $predictions.rank);
});
assert($predictions.rank - 1 === $targets.rank, function() {
return "predictions rank should be 1 larger than targets rank, but got predictions rank " + ($predictions.rank + " and targets rank " + $targets.rank);
});
assertShapesMatch($predictions.shape.slice(0, $predictions.shape.length - 1), $targets.shape, "predictions's shape should be align with the targets' shape, except the last dimension.");
lastDim = $predictions.shape[$predictions.shape.length - 1];
assert(k > 0 && k <= lastDim, function() {
return "'k' passed to inTopK() must be > 0 && <= the predictions last " + ("dimension (" + lastDim + "), but got " + k);
});
return [4, $predictions.data()];
case 1:
predictionsVals = _b.sent();
return [4, $targets.data()];
case 2:
targetsVals = _b.sent();
// All leading dimensions are treated as one flat batch of rows of length lastDim.
_a = [predictionsVals.length / lastDim, lastDim], batch = _a[0], size = _a[1];
// Presumably a Uint8Array-like buffer of length `batch` — TODO confirm getTypedArrayFromDType("bool").
precision = getTypedArrayFromDType("bool", batch);
for (b = 0; b < batch; b++) {
offset = b * size;
vals = predictionsVals.subarray(offset, offset + size);
valAndInd = [];
for (i = 0; i < vals.length; i++) {
valAndInd.push({value: vals[i], index: i});
}
// Descending by value. NOTE(review): ties are ordered by Array.prototype.sort (stable in ES2019+ engines).
valAndInd.sort(function(a, b2) {
return b2.value - a.value;
});
precision[b] = 0;
for (i = 0; i < k; i++) {
if (valAndInd[i].index === targetsVals[b]) {
precision[b] = 1;
break;
}
}
}
// Dispose only tensors created by convertToTensor above.
if (predictions !== $predictions) {
$predictions.dispose();
}
if (targets !== $targets) {
$targets.dispose();
}
return [2, tensor(precision, $targets.shape, "bool")];
}
});
});
}
var inTopKAsync = inTopKAsync_;
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Gradient of conv2d with respect to its filter.
 * @param x - conv input, rank 4 (a rank-3 input is treated as a batch of one).
 * @param dy - gradient of the conv output, rank 4 (rank 3 treated likewise).
 * @param filterShape - [filterHeight, filterWidth, inDepth, outDepth].
 * @param strides - conv strides.
 * @param pad2 - conv padding; must be an integer when `dimRoundingMode` is set.
 * @param dataFormat - "NHWC" (default) or "NCHW"; selects the channel axis.
 * @param dimRoundingMode - optional output-size rounding mode.
 * @returns the filter gradient computed by the Conv2DBackpropFilter kernel
 *     (dilations fixed at 1).
 */
function conv2DBackpropFilter_(x, dy, filterShape, strides, pad2, dataFormat, dimRoundingMode) {
  if (dataFormat === void 0) {
    dataFormat = "NHWC";
  }
  var x4D = x;
  if (x.rank === 3) {
    x4D = reshape(x, [1, x.shape[0], x.shape[1], x.shape[2]]);
  }
  var dy4D = dy;
  if (dy4D.rank === 3) {
    dy4D = reshape(dy, [1, dy.shape[0], dy.shape[1], dy.shape[2]]);
  }
  assert(x4D.rank === 4, function() {
    return "Error in conv2dDerFilter: input must be rank 4, but got shape " + (x4D.shape + ".");
  });
  assert(dy4D.rank === 4, function() {
    return "Error in conv2dDerFilter: dy must be rank 4, but got shape " + (dy4D.shape + ".");
  });
  assert(filterShape.length === 4, function() {
    return "Error in conv2dDerFilter: filterShape must be length 4, but got " + (filterShape + ".");
  });
  var inDepth = dataFormat === "NHWC" ? x4D.shape[3] : x4D.shape[1];
  var outDepth = dataFormat === "NHWC" ? dy4D.shape[3] : dy4D.shape[1];
  // Message previously had unbalanced parentheses.
  assert(inDepth === filterShape[2], function() {
    return "Error in conv2dDerFilter: depth of input (" + inDepth + ") must " + ("match input depth in filter (" + filterShape[2] + ").");
  });
  assert(outDepth === filterShape[3], function() {
    return "Error in conv2dDerFilter: depth of dy (" + outDepth + ") must " + ("match output depth for filter (" + filterShape[3] + ").");
  });
  if (dimRoundingMode != null) {
    assert(isInt(pad2), function() {
      return "Error in conv2dDerFilter: pad must be an integer when using, " + ("dimRoundingMode " + dimRoundingMode + " but got pad " + pad2 + ".");
    });
  }
  var forward = function(backend2) {
    var dilations = 1;
    var $dataFormat = convertConv2DDataFormat(dataFormat);
    var convInfo = computeConv2DInfo(x4D.shape, filterShape, strides, dilations, pad2, dimRoundingMode, false, $dataFormat);
    return backend2.conv2dDerFilter(x4D, dy4D, convInfo);
  };
  var inputs = {x: x4D, dy: dy4D};
  var attrs = {strides, pad: pad2, dataFormat, dimRoundingMode, filterShape};
  return ENGINE.runKernelFunc(forward, inputs, null, Conv2DBackpropFilter, attrs);
}
var conv2DBackpropFilter = op({conv2DBackpropFilter_});
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Backpropagates `dy` through a fused activation.
 * @param dy - upstream gradient.
 * @param y - forward output of the fused op.
 * @param activation - null/undefined or "linear" (identity) or "relu"
 *     (dy * step(y)).
 * @throws Error for any other activation.
 */
function getFusedDyActivation(dy, y, activation) {
  var isIdentity = activation == null || activation === "linear";
  if (isIdentity) {
    return dy;
  }
  if (activation !== "relu") {
    throw new Error("Cannot compute gradient for fused activation " + activation + ".");
  }
  return mul(dy, step(y));
}
/**
 * Reduces the activation gradient to the bias shape by summing over the
 * broadcast axes (from getReductionAxes), then reshaping to `bias.shape`.
 */
function getFusedBiasGradient(bias, dyActivation) {
  var axesToReduce = getReductionAxes(bias.shape, dyActivation.shape);
  var reduced = axesToReduce.length > 0 ? sum$1(dyActivation, axesToReduce) : dyActivation;
  return reshape(reduced, bias.shape);
}
/**
 * Applies a named activation for fused ops.
 * @param x - input tensor.
 * @param activation - "linear", "relu", "elu", "relu6" or "prelu".
 * @param preluActivationWeights - alpha weights, used only by "prelu".
 * @throws Error for any other activation name.
 */
function applyActivation(x, activation, preluActivationWeights) {
  switch (activation) {
    case "linear":
      return x;
    case "relu":
      return relu(x);
    case "elu":
      return elu(x);
    case "relu6":
      return relu6(x);
    case "prelu":
      return prelu(x, preluActivationWeights);
    default:
      throw new Error("Unknown fused activation " + activation + ".");
  }
}
/**
 * Decides whether a fused kernel may be used.
 * Fusion is allowed when gradients are not being recorded (gradientDepth is
 * not > 0) or when the activation is "linear".
 * @param gradientDepth - engine gradient depth (ENGINE.state.gradientDepth).
 * @param activation - requested activation name.
 * @returns {boolean}
 */
var shouldFuse = function(gradientDepth, activation) {
  var recordingGradients = gradientDepth > 0;
  if (!recordingGradients) {
    return true;
  }
  return activation === "linear";
};
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function fusedConv2d_(_a) {
var x = _a.x, filter = _a.filter, strides = _a.strides, pad2 = _a.pad, _b = _a.dataFormat, dataFormat = _b === void 0 ? "NHWC" : _b, _c = _a.dilations, dilations = _c === void 0 ? [1, 1] : _c, dimRoundingMode = _a.dimRoundingMode, bias = _a.bias, _d = _a.activation, activation = _d === void 0 ? "linear" : _d, preluActivationWeights = _a.preluActivationWeights;
activation = activation || "linear";
if (shouldFuse(ENGINE.state.gradientDepth, activation) === false) {
var result = conv2d(x, filter, strides, pad2, dataFormat, dilations, dimRoundingMode);
if (bias != null) {
result = add$1(result, bias);
}
return applyActivation(result, activation, preluActivationWeights);
}
var $x = convertToTensor(x, "x", "conv2d");
var $filter = convertToTensor(filter, "filter", "conv2d");
var x4D = $x;
var reshapedTo4D = false;
if ($x.rank === 3) {
reshapedTo4D = true;
x4D = reshape($x, [1, $x.shape[0], $x.shape[1], $x.shape[2]]);
}
assert(x4D.rank === 4, function() {
return "Error in fused conv2d: input must be rank 4, but got rank " + (x4D.rank + ".");
});
assert($filter.rank === 4, function() {
return "Error in fused conv2d: filter must be rank 4, but got rank " + ($filter.rank + ".");
});
if (dimRoundingMode != null) {
assert(isInt(pad2), function() {
return "Error in fused conv2d: pad must be an integer when using, " + ("dimRoundingMode " + dimRoundingMode + " but got pad " + pad2 + ".");
});
}
assert(x4D.shape[3] === $filter.shape[2], function() {
return "Error in conv2d: depth of input (" + x4D.shape[3] + ") must match " + ("input depth for filter " + $filter.shape[2] + ".");
});
assert(eitherStridesOrDilationsAreOne(strides, dilations), function() {
return "Error in conv2D: Either strides or dilations must be 1. " + ("Got strides " + strides + " and dilations '" + dilations + "'");
});
assert(dataFormat === "NHWC", function() {
return "Error in conv2d: got dataFormat of " + dataFormat + " but only NHWC is currently supported.";
});
var convInfo = computeConv2DInfo(x4D.shape, $filter.shape, strides, dilations, pad2, dimRoundingMode);
var $bias;
if (bias != null) {
$bias = convertToTensor(bias, "bias", "fused conv2d");
$bias = makeTypesMatch($bias, $x)[0];
assertAndGetBroadcastShape(convInfo.outShape, $bias.shape);
}
var $preluActivationWeights;
if (preluActivationWeights != null) {
$preluActivationWeights = convertToTensor(preluActivationWeights, "prelu weights", "fused conv2d");
}
var grad2 = function(dy, saved) {
var _a2 = saved, $filter2 = _a2[0], x4D2 = _a2[1], y = _a2[2], $bias2 = _a2[3];
var dyActivation = getFusedDyActivation(dy, y, activation);
assert(tupleValuesAreOne(dilations), function() {
return "Error in gradient of fused conv2D: dilation rates greater than 1 " + ("are not yet supported in gradients. Got dilations '" + dilations + "'");
});
var xDer = conv2DBackpropInput(x4D2.shape, dyActivation, $filter2, strides, pad2);
var filterDer = conv2DBackpropFilter(x4D2, dyActivation, $filter2.shape, strides, pad2);
var der = [xDer, filterDer];
if ($bias2 != null) {
var biasDer = getFusedBiasGradient($bias2, dyActivation);
der.push(biasDer);
}
return der;
};
var forward = function(backend2) {
var res = backend2.fusedConv2d({
input: x4D,
filter: $filter,
convInfo,
bias: $bias,
activation,
preluActivationWeights: $preluActivationWeights
});
return res;
};
var inputs = {
x: x4D,
filter: $filter,
bias: $bias,
preluActivationWeights: $preluActivationWeights
};
var attrs = {strides, pad: pad2, dataFormat, dilations, dimRoundingMode, activation};
if (bias == null) {
var customOp = customGrad(function(x4D2, filter2, save) {
var res = ENGINE.runKernelFunc(forward, inputs, null, FusedConv2D, attrs);
save([filter2, x4D2, res]);
if (reshapedTo4D) {
res = reshape(res, [res.shape[1], res.shape[2], res.shape[3]]);
}
return {value: res, gradFunc: grad2};
});
return customOp(x4D, $filter);
} else {
var customOpWithBias = customGrad(function(x4D2, filter2, bias2, save) {
var res = ENGINE.runKernelFunc(forward, inputs, null, FusedConv2D, attrs);
save([filter2, x4D2, res, bias2]);
if (reshapedTo4D) {
res = reshape(res, [res.shape[1], res.shape[2], res.shape[3]]);
}
return {value: res, gradFunc: grad2};
});
return customOpWithBias(x4D, $filter, $bias);
}
}
// Public op for the fused conv2d implementation: wraps fusedConv2d_ with op(...).
// Exposed to callers as `conv2d` on the fused_ops namespace object below.
var conv2d$1 = op({fusedConv2d_});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Gradient of a depthwise 2D convolution with respect to its filter.
 *
 * Rank-3 `x` / `dy` are promoted to rank 4 by prepending a batch dimension of
 * size 1 before being passed to the kernel.
 *
 * @param x Input of the forward convolution (rank 3 or rank 4).
 * @param dy Upstream gradient (rank 3 or rank 4).
 * @param filterShape Shape of the filter whose gradient is computed.
 * @param strides Convolution strides.
 * @param pad2 Padding mode or explicit pad amount.
 * @param dilations Dilation rates; defaults to [1, 1].
 * @param dimRoundingMode Optional rounding mode, forwarded to computeConv2DInfo.
 * @returns Result of the DepthwiseConv2dNativeBackpropFilter kernel.
 */
function depthwiseConv2dNativeBackpropFilter_(x, dy, filterShape, strides, pad2, dilations, dimRoundingMode) {
  if (dilations === void 0) {
    dilations = [1, 1];
  }
  var x4D = x;
  if (x.rank === 3) {
    x4D = reshape(x, [1, x.shape[0], x.shape[1], x.shape[2]]);
  }
  var dy4D = dy;
  if (dy.rank === 3) {
    dy4D = reshape(dy, [1, dy.shape[0], dy.shape[1], dy.shape[2]]);
  }
  var forward = function(backend2) {
    // Use the rank-4 (possibly reshaped) input shape, matching the tensor
    // actually handed to the backend. Every other computeConv2DInfo call in
    // this module also passes a rank-4 input shape.
    var convInfo = computeConv2DInfo(x4D.shape, filterShape, strides, dilations, pad2, dimRoundingMode, true);
    return backend2.depthwiseConv2DDerFilter(x4D, dy4D, convInfo);
  };
  var inputs = {x: x4D, dy: dy4D};
  var attrs = {strides, pad: pad2, dimRoundingMode, dilations, filterShape};
  return ENGINE.runKernelFunc(forward, inputs, null, DepthwiseConv2dNativeBackpropFilter, attrs);
}
var depthwiseConv2dNativeBackpropFilter = op({depthwiseConv2dNativeBackpropFilter_});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Gradient of a depthwise 2D convolution with respect to its input.
 *
 * A rank-3 `dy` is treated as a single-item batch: it is reshaped to rank 4
 * for the kernel, and the kernel's result is reshaped back to rank 3.
 *
 * @param xShape Shape of the forward-pass input (also sent as `inputShape`).
 * @param dy Upstream gradient (rank 3 or rank 4).
 * @param filter Filter tensor of the forward convolution.
 * @param strides Convolution strides.
 * @param pad2 Padding mode or explicit pad amount.
 * @param dilations Dilation rates; defaults to [1, 1].
 * @param dimRoundingMode Optional rounding mode, forwarded to computeConv2DInfo.
 * @returns Result of the DepthwiseConv2dNativeBackpropInput kernel.
 */
function depthwiseConv2dNativeBackpropInput_(xShape, dy, filter, strides, pad2, dilations, dimRoundingMode) {
  if (dilations === void 0) {
    dilations = [1, 1];
  }
  const squeezeBack = dy.rank === 3;
  const grad4D = squeezeBack ? reshape(dy, [1, dy.shape[0], dy.shape[1], dy.shape[2]]) : dy;
  const kernelInputs = {dy: grad4D, filter};
  const kernelAttrs = {strides, pad: pad2, dimRoundingMode, dilations, inputShape: xShape};
  const computed = ENGINE.runKernelFunc(function(backend2) {
    const info = computeConv2DInfo(xShape, filter.shape, strides, dilations, pad2, dimRoundingMode, true);
    return backend2.depthwiseConv2DDerInput(grad4D, filter, info);
  }, kernelInputs, null, DepthwiseConv2dNativeBackpropInput, kernelAttrs);
  if (!squeezeBack) {
    return computed;
  }
  return reshape(computed, [computed.shape[1], computed.shape[2], computed.shape[3]]);
}
var depthwiseConv2dNativeBackpropInput = op({depthwiseConv2dNativeBackpropInput_});
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Fused depthwise 2D convolution: depthwiseConv2d, an optional bias add and an
 * activation, run as one FusedDepthwiseConv2D kernel when fusing is allowed.
 *
 * If shouldFuse(ENGINE.state.gradientDepth, activation) is false, the op falls
 * back to the unfused sequence depthwiseConv2d -> add$1 -> applyActivation.
 *
 * @param _a Options object: {x, filter, strides, pad, dataFormat = "NHWC",
 *     dilations = [1, 1], dimRoundingMode, bias, activation = "linear",
 *     preluActivationWeights}.
 * @returns On the fused path, the kernel output, reshaped back to rank 3 when
 *     `x` was rank 3.
 */
function fusedDepthwiseConv2d_(_a) {
  var x = _a.x, filter = _a.filter, strides = _a.strides, pad2 = _a.pad, _b = _a.dataFormat, dataFormat = _b === void 0 ? "NHWC" : _b, _c = _a.dilations, dilations = _c === void 0 ? [1, 1] : _c, dimRoundingMode = _a.dimRoundingMode, bias = _a.bias, _d = _a.activation, activation = _d === void 0 ? "linear" : _d, preluActivationWeights = _a.preluActivationWeights;
  if (shouldFuse(ENGINE.state.gradientDepth, activation) === false) {
    var result = depthwiseConv2d(x, filter, strides, pad2, dataFormat, dilations, dimRoundingMode);
    if (bias != null) {
      result = add$1(result, bias);
    }
    return applyActivation(result, activation, preluActivationWeights);
  }
  var $x = convertToTensor(x, "x", "depthwiseConv2d");
  var $filter = convertToTensor(filter, "filter", "depthwiseConv2d");
  // Rank-3 input is treated as a batch of one; the output is squeezed back below.
  var x4D = $x;
  var reshapedTo4D = false;
  if ($x.rank === 3) {
    reshapedTo4D = true;
    x4D = reshape($x, [1, $x.shape[0], $x.shape[1], $x.shape[2]]);
  }
  assert(x4D.rank === 4, function() {
    return "Error in fused depthwiseConv2d: input must be rank 4, but got " + ("rank " + x4D.rank + ".");
  });
  assert($filter.rank === 4, function() {
    return "Error in fused depthwiseConv2d: filter must be rank 4, " + ("but got rank " + $filter.rank + ".");
  });
  assert(x4D.shape[3] === $filter.shape[2], function() {
    return "Error in fused depthwiseConv2d: number of input channels " + ("(" + x4D.shape[3] + ") must match the inChannels dimension in ") + ("filter " + $filter.shape[2] + ".");
  });
  if (dilations == null) {
    dilations = [1, 1];
  }
  assert(eitherStridesOrDilationsAreOne(strides, dilations), function() {
    return "Error in fused depthwiseConv2d: Either strides or dilations must " + ("be 1. Got strides " + strides + " and dilations '" + dilations + "'");
  });
  if (dimRoundingMode != null) {
    assert(isInt(pad2), function() {
      return "Error in fused depthwiseConv2d: pad must be an integer when " + ("using dimRoundingMode " + dimRoundingMode + " but got pad " + pad2 + ".");
    });
  }
  var convInfo = computeConv2DInfo(x4D.shape, $filter.shape, strides, dilations, pad2, dimRoundingMode, true);
  var $bias;
  if (bias != null) {
    // Report conversion errors under this op's own name.
    $bias = convertToTensor(bias, "bias", "fused depthwiseConv2d");
    $bias = makeTypesMatch($bias, $x)[0];
    assertAndGetBroadcastShape(convInfo.outShape, $bias.shape);
  }
  var $preluActivationWeights;
  if (preluActivationWeights != null) {
    $preluActivationWeights = convertToTensor(preluActivationWeights, "prelu weights", "fused depthwiseConv2d");
  }
  // Gradient function; `saved` is [filter, x4D, y] or [filter, x4D, y, bias],
  // as recorded by save(...) in the custom ops below.
  var grad2 = function(dy, saved) {
    assert(tupleValuesAreOne(dilations), function() {
      return "Error in gradient of fused depthwiseConv2d: dilation rates greater than 1 are not yet supported. Got dilations " + ("'" + dilations + "'");
    });
    var $filter2 = saved[0], x4D2 = saved[1], y = saved[2], bias2 = saved[3];
    var dyActivation = getFusedDyActivation(dy, y, activation);
    var xDer = depthwiseConv2dNativeBackpropInput(x4D2.shape, dyActivation, $filter2, strides, pad2, dilations, dimRoundingMode);
    var filterDer = depthwiseConv2dNativeBackpropFilter(x4D2, dyActivation, $filter2.shape, strides, pad2, dilations, dimRoundingMode);
    if (bias2 != null) {
      // Use the saved bias, as the fused conv2d and matMul gradients do.
      var biasDer = getFusedBiasGradient(bias2, dyActivation);
      return [xDer, filterDer, biasDer];
    }
    return [xDer, filterDer];
  };
  var forward = function(backend2) {
    var res = backend2.fusedDepthwiseConv2D({
      input: x4D,
      filter: $filter,
      convInfo,
      bias: $bias,
      activation,
      preluActivationWeights: $preluActivationWeights
    });
    return res;
  };
  var inputs = {
    x: x4D,
    filter: $filter,
    bias: $bias,
    preluActivationWeights: $preluActivationWeights
  };
  var attrs = {strides, pad: pad2, dataFormat, dilations, dimRoundingMode, activation};
  if (bias == null) {
    var customOp = customGrad(function(x4D2, filter2, save) {
      var res = ENGINE.runKernelFunc(forward, inputs, null, FusedDepthwiseConv2D, attrs);
      save([filter2, x4D2, res]);
      if (reshapedTo4D) {
        res = reshape(res, [res.shape[1], res.shape[2], res.shape[3]]);
      }
      return {value: res, gradFunc: grad2};
    });
    return customOp(x4D, $filter);
  } else {
    var customOpWithBias = customGrad(function(x4D2, filter2, bias2, save) {
      var res = ENGINE.runKernelFunc(forward, inputs, null, FusedDepthwiseConv2D, attrs);
      save([filter2, x4D2, res, bias2]);
      if (reshapedTo4D) {
        res = reshape(res, [res.shape[1], res.shape[2], res.shape[3]]);
      }
      return {value: res, gradFunc: grad2};
    });
    return customOpWithBias(x4D, $filter, $bias);
  }
}
var depthwiseConv2d$1 = op({fusedDepthwiseConv2d_});
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Fused matrix multiply: matMul(a, b), an optional bias add and an activation,
 * run as one _FusedMatMul kernel when fusing is allowed.
 *
 * If shouldFuse(ENGINE.state.gradientDepth, activation) is false, the op falls
 * back to matMul -> add$1 -> applyActivation.
 *
 * On the fused path, each operand is reshaped to rank 3, using the product of
 * its leading dimensions as the batch size. The kernel runs on those rank-3
 * tensors, and the result is reshaped to `outShape`.
 *
 * @param _a Options object: {a, b, transposeA = false, transposeB = false,
 *     bias, activation = "linear", preluActivationWeights}.
 * @returns On the fused path, a tensor of shape
 *     a.shape.slice(0, -2).concat([outerShapeA, outerShapeB]).
 */
function fusedMatMul_(_a) {
var _b;
var a = _a.a, b = _a.b, _c = _a.transposeA, transposeA = _c === void 0 ? false : _c, _d = _a.transposeB, transposeB = _d === void 0 ? false : _d, bias = _a.bias, _e = _a.activation, activation = _e === void 0 ? "linear" : _e, preluActivationWeights = _a.preluActivationWeights;
if (shouldFuse(ENGINE.state.gradientDepth, activation) === false) {
var result = matMul(a, b, transposeA, transposeB);
if (bias != null) {
result = add$1(result, bias);
}
return applyActivation(result, activation, preluActivationWeights);
}
var $a = convertToTensor(a, "a", "fused matMul");
var $b = convertToTensor(b, "b", "fused matMul");
_b = makeTypesMatch($a, $b), $a = _b[0], $b = _b[1];
// Inner (contracted) and outer dimensions of each operand, honoring the transpose flags.
var innerShapeA = transposeA ? $a.shape[$a.rank - 2] : $a.shape[$a.rank - 1];
var innerShapeB = transposeB ? $b.shape[$b.rank - 1] : $b.shape[$b.rank - 2];
var outerShapeA = transposeA ? $a.shape[$a.rank - 1] : $a.shape[$a.rank - 2];
var outerShapeB = transposeB ? $b.shape[$b.rank - 2] : $b.shape[$b.rank - 1];
// All leading dimensions are collapsed into a single batch dimension.
var outerDimsA = $a.shape.slice(0, -2);
var outerDimsB = $b.shape.slice(0, -2);
var batchDimA = sizeFromShape(outerDimsA);
var batchDimB = sizeFromShape(outerDimsB);
assert($a.rank >= 2 && $b.rank >= 2 && $a.rank === $b.rank, function() {
return "Error in fused matMul: inputs must have the same rank of at least " + ("2, got ranks " + $a.rank + " and " + $b.rank + ".");
});
assert(arraysEqual(outerDimsA, outerDimsB), function() {
return "Error in fused matMul: outer dimensions (" + outerDimsA + ") and (" + (outerDimsB + ") of Tensors with shapes " + $a.shape + " and ") + ($b.shape + " must match.");
});
assert(innerShapeA === innerShapeB, function() {
return "Error in fused matMul: inner shapes (" + innerShapeA + ") and (" + (innerShapeB + ") of Tensors with shapes " + $a.shape + " and ") + ($b.shape + " and transposeA=" + transposeA) + (" and transposeB=" + transposeB + " must match.");
});
var outShape = $a.shape.slice(0, -2).concat([outerShapeA, outerShapeB]);
// Rank-3 views of the operands; each keeps its original (transposed or not) layout.
var a3D = transposeA ? reshape($a, [batchDimA, innerShapeA, outerShapeA]) : reshape($a, [batchDimA, outerShapeA, innerShapeA]);
var b3D = transposeB ? reshape($b, [batchDimB, outerShapeB, innerShapeB]) : reshape($b, [batchDimB, innerShapeB, outerShapeB]);
var $bias;
if (bias != null) {
$bias = convertToTensor(bias, "bias", "fused matMul");
$bias = makeTypesMatch($bias, $a)[0];
assertAndGetBroadcastShape(outShape, $bias.shape);
}
var $preluActivationWeights;
if (preluActivationWeights != null) {
$preluActivationWeights = convertToTensor(preluActivationWeights, "prelu weights", "fused matMul");
}
// Gradient function; `saved` is [a3D, b3D, y] or [a3D, b3D, y, bias], as
// recorded by save(...) below. `y` is the rank-3 kernel output, so dy (which
// has outShape) is reshaped to y.shape first.
var grad2 = function(dy, saved) {
var a3D2 = saved[0], b3D2 = saved[1], y = saved[2], $bias2 = saved[3];
var dyActivation = getFusedDyActivation(reshape(dy, y.shape), y, activation);
var aDer;
var bDer;
// Each transpose combination needs a different product to produce gradients
// laid out like the corresponding rank-3 operand.
if (!transposeA && !transposeB) {
aDer = matMul(dyActivation, b3D2, false, true);
bDer = matMul(a3D2, dyActivation, true, false);
} else if (!transposeA && transposeB) {
aDer = matMul(dyActivation, b3D2, false, false);
bDer = matMul(dyActivation, a3D2, true, false);
} else if (transposeA && !transposeB) {
aDer = matMul(b3D2, dyActivation, false, true);
bDer = matMul(a3D2, dyActivation, false, false);
} else {
aDer = matMul(b3D2, dyActivation, true, true);
bDer = matMul(dyActivation, a3D2, true, true);
}
if (bias != null) {
var biasDer = getFusedBiasGradient($bias2, dyActivation);
return [aDer, bDer, biasDer];
} else {
return [aDer, bDer];
}
};
var forward = function(backend2) {
var y = backend2.fusedBatchMatMul({
a: a3D,
b: b3D,
transposeA,
transposeB,
bias: $bias,
activation,
preluActivationWeights: $preluActivationWeights
});
return y;
};
var inputs = {
a: a3D,
b: b3D,
bias: $bias,
preluActivationWeights: $preluActivationWeights
};
var attrs = {transposeA, transposeB, activation};
// Two customGrad variants, so the bias appears as a differentiable input only when it was supplied.
if (bias == null) {
var customOp = customGrad(function(a3D2, b3D2, save) {
var res = ENGINE.runKernelFunc(forward, inputs, null, _FusedMatMul, attrs);
save([a3D2, b3D2, res]);
return {value: reshape(res, outShape), gradFunc: grad2};
});
return customOp(a3D, b3D);
} else {
var customOpWithBias = customGrad(function(a3D2, b3D2, $bias2, save) {
var res = ENGINE.runKernelFunc(forward, inputs, null, _FusedMatMul, attrs);
save([a3D2, b3D2, res, $bias2]);
return {value: reshape(res, outShape), gradFunc: grad2};
});
return customOpWithBias(a3D, b3D, $bias);
}
}
// Public op wrapper; exposed as `matMul` on the fused_ops namespace object.
var matMul$1 = op({fusedMatMul_});
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
// Namespace object for the fused ops. It has a null prototype, so only these
// three own properties are reachable on it.
var fused_ops = Object.assign(Object.create(null), {
  conv2d: conv2d$1,
  depthwiseConv2d: depthwiseConv2d$1,
  matMul: matMul$1
});
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Hamming window of `windowLength` samples, built with cosineWindow using the
 * coefficients 0.54 and 0.46.
 */
function hammingWindow_(windowLength) {
  const alpha = 0.54;
  const beta = 0.46;
  const window = cosineWindow(windowLength, alpha, beta);
  return window;
}
var hammingWindow = op({hammingWindow_});
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Hann window of `windowLength` samples, built with cosineWindow using the
 * coefficients 0.5 and 0.5.
 */
function hannWindow_(windowLength) {
  const alpha = 0.5;
  const beta = 0.5;
  const window = cosineWindow(windowLength, alpha, beta);
  return window;
}
var hannWindow = op({hannWindow_});
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Slices a signal into frames of `frameLength` samples taken every
 * `frameStep` samples.
 *
 * When `padEnd` is true, the trailing frames that would run past the end of
 * the signal are also emitted, right-padded with `padValue`.
 *
 * @param signal2 Signal tensor, sliced with scalar begin/size (1-D slicing).
 * @param frameLength Number of samples per frame.
 * @param frameStep Distance between consecutive frame starts. Must be > 0,
 *     otherwise the framing loops would never advance.
 * @param padEnd Whether to emit padded trailing frames (default false).
 * @param padValue Value used for the padding (default 0).
 * @returns Tensor of shape [numFrames, frameLength]; shape [0, frameLength]
 *     when no frame is produced.
 */
function frame_(signal2, frameLength, frameStep, padEnd, padValue) {
  if (padEnd === void 0) {
    padEnd = false;
  }
  if (padValue === void 0) {
    padValue = 0;
  }
  // A non-positive step would leave `start` unchanged (or moving backwards),
  // so the loops below would keep allocating slices forever.
  assert(frameStep > 0, function() {
    return "Error in frame: frameStep must be a positive number, but got " + frameStep + ".";
  });
  var start = 0;
  var output = [];
  while (start + frameLength <= signal2.size) {
    output.push(slice(signal2, start, frameLength));
    start += frameStep;
  }
  if (padEnd) {
    while (start < signal2.size) {
      // Take what remains of the signal and fill the rest of the frame.
      var padLen = start + frameLength - signal2.size;
      var pad2 = concat([
        slice(signal2, start, frameLength - padLen),
        fill([padLen], padValue)
      ]);
      output.push(pad2);
      start += frameStep;
    }
  }
  if (output.length === 0) {
    return tensor2d([], [0, frameLength]);
  }
  return reshape(concat(output), [output.length, frameLength]);
}
var frame = op({frame_});
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Short-time Fourier transform.
 *
 * Frames the signal, multiplies every frame by `windowFn(frameLength)`, and
 * applies rfft to each windowed frame. The per-frame results are concatenated.
 *
 * @param signal2 Signal tensor to transform.
 * @param frameLength Samples per frame.
 * @param frameStep Hop between frames.
 * @param fftLength FFT size; when null/undefined, enclosingPowerOfTwo(frameLength).
 * @param windowFn Window generator; defaults to hannWindow when omitted.
 */
function stft_(signal2, frameLength, frameStep, fftLength, windowFn) {
  const makeWindow = windowFn === void 0 ? hannWindow : windowFn;
  const fftSize = fftLength == null ? enclosingPowerOfTwo(frameLength) : fftLength;
  const framed = frame(signal2, frameLength, frameStep);
  const windowed = mul(framed, makeWindow(frameLength));
  const numFrames = framed.shape[0];
  const spectra = [];
  for (let row = 0; row < numFrames; row++) {
    const oneFrame = slice(windowed, [row, 0], [1, frameLength]);
    spectra.push(rfft(oneFrame, fftSize));
  }
  return concat(spectra);
}
var stft = op({stft_});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Crops boxes out of a batch of images and resizes each crop to `cropSize`
 * using the CropAndResize kernel.
 *
 * @param image2 Rank-4 image batch.
 * @param boxes Rank-2 tensor with 4 columns per box (converted to float32).
 * @param boxInd Rank-1 tensor with one image index per box (converted to int32).
 * @param cropSize Two-element array; both entries must be >= 1.
 * @param method "bilinear" (default) or "nearest".
 * @param extrapolationValue Value forwarded to the kernel; defaults to 0 when
 *     null/undefined.
 * @returns Result of the CropAndResize kernel.
 */
function cropAndResize_(image2, boxes, boxInd, cropSize, method, extrapolationValue) {
  var $image = convertToTensor(image2, "image", "cropAndResize");
  var $boxes = convertToTensor(boxes, "boxes", "cropAndResize", "float32");
  var $boxInd = convertToTensor(boxInd, "boxInd", "cropAndResize", "int32");
  method = method || "bilinear";
  // Null-check rather than `||`, so an explicit NaN is forwarded instead of
  // being silently replaced with 0.
  extrapolationValue = extrapolationValue == null ? 0 : extrapolationValue;
  var numBoxes = $boxes.shape[0];
  assert($image.rank === 4, function() {
    return "Error in cropAndResize: image must be rank 4, but got rank " + $image.rank + ".";
  });
  assert($boxes.rank === 2 && $boxes.shape[1] === 4, function() {
    return "Error in cropAndResize: boxes must have size [" + numBoxes + ",4] " + ("but had shape " + $boxes.shape + ".");
  });
  assert($boxInd.rank === 1 && $boxInd.shape[0] === numBoxes, function() {
    // Report the shape of the tensor that actually failed the check.
    return "Error in cropAndResize: boxInd must have size [" + numBoxes + "] " + ("but had shape " + $boxInd.shape + ".");
  });
  assert(cropSize.length === 2, function() {
    return "Error in cropAndResize: cropSize must be of length 2, but got " + ("length " + cropSize.length + ".");
  });
  assert(cropSize[0] >= 1 && cropSize[1] >= 1, function() {
    return "cropSize must be at least [1,1], but was " + cropSize;
  });
  assert(method === "bilinear" || method === "nearest", function() {
    return "method must be bilinear or nearest, but was " + method;
  });
  var forward = function(backend2) {
    return backend2.cropAndResize($image, $boxes, $boxInd, cropSize, method, extrapolationValue);
  };
  var inputs = {image: $image, boxes: $boxes, boxInd: $boxInd};
  var attrs = {method, extrapolationValue, cropSize};
  return ENGINE.runKernelFunc(forward, inputs, null, CropAndResize, attrs);
}
var cropAndResize = op({cropAndResize_});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Runs the FlipLeftRight kernel on a batch of images.
 *
 * @param image2 Rank-4 image batch; converted to a float32 tensor.
 * @returns The kernel's output tensor.
 */
function flipLeftRight_(image2) {
  var $image = convertToTensor(image2, "image", "flipLeftRight", "float32");
  assert($image.rank === 4, function() {
    // Separate "rank 4," from "but" so the message reads correctly.
    return "Error in flipLeftRight: image must be rank 4, but got rank " + $image.rank + ".";
  });
  var inputs = {image: $image};
  return ENGINE.runKernel(FlipLeftRight, inputs, {});
}
var flipLeftRight = op({flipLeftRight_});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Runs the RotateWithOffset kernel on a batch of images.
 *
 * @param image2 Rank-4 image batch; converted to a float32 tensor.
 * @param radians Rotation angle forwarded to the kernel.
 * @param fillValue Forwarded to the kernel; defaults to 0. Presumably the
 *     value for pixels left uncovered by the rotation (confirm in the kernel).
 * @param center Forwarded to the kernel; defaults to 0.5.
 * @returns The kernel's output tensor.
 */
function rotateWithOffset_(image2, radians, fillValue, center) {
  if (fillValue === void 0) {
    fillValue = 0;
  }
  if (center === void 0) {
    center = 0.5;
  }
  var $image = convertToTensor(image2, "image", "rotateWithOffset", "float32");
  assert($image.rank === 4, function() {
    // Separate "rank 4," from "but" so the message reads correctly.
    return "Error in rotateWithOffset: image must be rank 4, but got rank " + $image.rank + ".";
  });
  var inputs = {image: $image};
  var attrs = {radians, fillValue, center};
  return ENGINE.runKernel(RotateWithOffset, inputs, attrs);
}
var rotateWithOffset = op({rotateWithOffset_});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Validates and normalizes the arguments shared by the non-max-suppression ops.
 *
 * Null/undefined thresholds fall back to defaults: iouThreshold 0.5,
 * scoreThreshold -Infinity, softNmsSigma 0. maxOutputSize is clamped to the
 * number of boxes (boxes.shape[0]). Throws through `assert` when the
 * boxes/scores shapes or the threshold ranges are invalid.
 *
 * @returns {{maxOutputSize, iouThreshold, scoreThreshold, softNmsSigma}}
 */
function nonMaxSuppSanityCheck(boxes, scores, maxOutputSize, iouThreshold, scoreThreshold, softNmsSigma) {
  const iou = iouThreshold == null ? 0.5 : iouThreshold;
  const minScore = scoreThreshold == null ? Number.NEGATIVE_INFINITY : scoreThreshold;
  const sigma = softNmsSigma == null ? 0 : softNmsSigma;
  const numBoxes = boxes.shape[0];
  const clampedMax = Math.min(maxOutputSize, numBoxes);
  assert(0 <= iou && iou <= 1, () => `iouThreshold must be in [0, 1], but was '${iou}'`);
  assert(boxes.rank === 2, () => `boxes must be a 2D tensor, but was of rank '${boxes.rank}'`);
  assert(boxes.shape[1] === 4, () => `boxes must have 4 columns, but 2nd dimension was ${boxes.shape[1]}`);
  assert(scores.rank === 1, () => "scores must be a 1D tensor");
  assert(scores.shape[0] === numBoxes, () => `scores has incompatible shape with boxes. Expected ${numBoxes}, but was ${scores.shape[0]}`);
  assert(0 <= sigma && sigma <= 1, () => `softNmsSigma must be in [0, 1], but was '${sigma}'`);
  return {
    maxOutputSize: clampedMax,
    iouThreshold: iou,
    scoreThreshold: minScore,
    softNmsSigma: sigma
  };
}
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Greedy non-max suppression (NonMaxSuppressionV3 kernel).
 *
 * Arguments are validated and normalized by nonMaxSuppSanityCheck; the
 * normalized values are both forwarded to the backend and recorded as kernel
 * attributes.
 *
 * @param boxes Boxes tensor (2-D, 4 columns).
 * @param scores Scores tensor (1-D, one score per box).
 * @param maxOutputSize Maximum number of boxes to select.
 * @param iouThreshold Overlap threshold; defaults to 0.5 when omitted.
 * @param scoreThreshold Minimum score; defaults to -Infinity when omitted.
 * @returns The kernel's output tensor.
 */
function nonMaxSuppression_(boxes, scores, maxOutputSize, iouThreshold, scoreThreshold) {
  const iouArg = iouThreshold === void 0 ? 0.5 : iouThreshold;
  const scoreArg = scoreThreshold === void 0 ? Number.NEGATIVE_INFINITY : scoreThreshold;
  const $boxes = convertToTensor(boxes, "boxes", "nonMaxSuppression");
  const $scores = convertToTensor(scores, "scores", "nonMaxSuppression");
  const {
    maxOutputSize: maxOut,
    iouThreshold: iou,
    scoreThreshold: minScore
  } = nonMaxSuppSanityCheck($boxes, $scores, maxOutputSize, iouArg, scoreArg);
  const attrs = {maxOutputSize: maxOut, iouThreshold: iou, scoreThreshold: minScore};
  const runOnBackend = (backend) => backend.nonMaxSuppression($boxes, $scores, maxOut, iou, minScore);
  return ENGINE.runKernelFunc(runOnBackend, {boxes: $boxes, scores: $scores}, null, NonMaxSuppressionV3, attrs);
}
var nonMaxSuppression = op({nonMaxSuppression_});
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Inserts `element` into the sorted array `arr` in place, keeping it sorted
 * under `comparator` (or defaultComparator when omitted). Among equal
 * elements, the new one is placed first.
 */
function binaryInsert(arr, element, comparator) {
  const pos = binarySearch(arr, element, comparator);
  arr.splice(pos >= 0 ? pos : -pos - 1, 0, element);
}
/**
 * Binary search over a sorted array.
 *
 * @returns The leftmost index of an element comparing equal to `target`, or
 *     -(insertionPoint + 1) when there is no such element.
 */
function binarySearch(arr, target, comparator) {
  const cmp = comparator ? comparator : defaultComparator;
  return binarySearch_(arr, target, cmp);
}
/** Three-way comparison based on the `>` and `<` operators. */
function defaultComparator(a, b) {
  if (a > b) {
    return 1;
  }
  if (a < b) {
    return -1;
  }
  return 0;
}
/**
 * Lower-bound search. `comparator(target, element)` must be positive when
 * target sorts after element, and zero (falsy) when they are equal.
 */
function binarySearch_(arr, target, comparator) {
  let lo = 0;
  let hi = arr.length;
  let exact = false;
  while (lo < hi) {
    const mid = lo + ((hi - lo) >>> 1);
    const order = comparator(target, arr[mid]);
    if (order > 0) {
      lo = mid + 1;
      continue;
    }
    // Narrow to the left half; remember whether this probe was an exact match.
    hi = mid;
    exact = !order;
  }
  return exact ? lo : -(lo + 1);
}
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/** Plain NMS (no soft-NMS); returns only the selectedIndices tensor. */
function nonMaxSuppressionV3Impl(boxes, scores, maxOutputSize, iouThreshold, scoreThreshold) {
  const noSoftNms = 0;
  const result = nonMaxSuppressionImpl_(boxes, scores, maxOutputSize, iouThreshold, scoreThreshold, noSoftNms);
  return result.selectedIndices;
}
/**
 * Plain NMS that optionally pads to maxOutputSize. Returns
 * {selectedIndices, validOutputs}.
 */
function nonMaxSuppressionV4Impl(boxes, scores, maxOutputSize, iouThreshold, scoreThreshold, padToMaxOutputSize) {
  const noSoftNms = 0;
  const withScores = false;
  const withValidCount = true;
  return nonMaxSuppressionImpl_(boxes, scores, maxOutputSize, iouThreshold, scoreThreshold, noSoftNms, withScores, padToMaxOutputSize, withValidCount);
}
/** Soft-NMS variant. Returns {selectedIndices, selectedScores}. */
function nonMaxSuppressionV5Impl(boxes, scores, maxOutputSize, iouThreshold, scoreThreshold, softNmsSigma) {
  const withScores = true;
  return nonMaxSuppressionImpl_(boxes, scores, maxOutputSize, iouThreshold, scoreThreshold, softNmsSigma, withScores);
}
/**
 * Shared (soft-)non-max-suppression implementation behind
 * nonMaxSuppressionV3Impl / V4Impl / V5Impl.
 *
 * Only boxes whose score is strictly greater than scoreThreshold become
 * candidates. Candidates are taken from the end of a list sorted with
 * ascendingComparator. That comparator is defined outside this chunk and
 * presumably orders by ascending score, so pop() yields the current best
 * (confirm). A candidate is discarded once its IoU with a previously
 * selected box is >= iouThreshold.
 *
 * When softNmsSigma > 0, the score is instead multiplied by
 * suppressWeight(iouThreshold, scale, iou), which is also defined outside this
 * chunk. A candidate whose score changed is re-inserted into the sorted list
 * rather than selected immediately, and is dropped if its score falls to
 * scoreThreshold or below.
 *
 * @param boxes Flat array-like with 4 coordinates per box (indexed in
 *     groups of 4 by intersectionOverUnion via subarray).
 * @param scores Array-like of per-box scores.
 * @param maxOutputSize Maximum number of indices to select.
 * @param iouThreshold IoU at or above which a candidate is suppressed.
 * @param scoreThreshold Scores must exceed this to be considered.
 * @param softNmsSigma 0 disables soft-NMS; otherwise scale = -0.5 / softNmsSigma.
 * @param returnScoresTensor Add a float32 `selectedScores` tensor (default false).
 * @param padToMaxOutputSize Pad indices/scores with zeros up to maxOutputSize (default false).
 * @param returnValidOutputs Add an int32 scalar `validOutputs` holding the
 *     unpadded count (default false).
 * @returns {selectedIndices, selectedScores?, validOutputs?}
 */
function nonMaxSuppressionImpl_(boxes, scores, maxOutputSize, iouThreshold, scoreThreshold, softNmsSigma, returnScoresTensor, padToMaxOutputSize, returnValidOutputs) {
if (returnScoresTensor === void 0) {
returnScoresTensor = false;
}
if (padToMaxOutputSize === void 0) {
padToMaxOutputSize = false;
}
if (returnValidOutputs === void 0) {
returnValidOutputs = false;
}
var candidates = [];
for (var i = 0; i < scores.length; i++) {
if (scores[i] > scoreThreshold) {
candidates.push({score: scores[i], boxIndex: i, suppressBeginIndex: 0});
}
}
candidates.sort(ascendingComparator);
// scale stays 0 when soft-NMS is disabled.
var scale = softNmsSigma > 0 ? -0.5 / softNmsSigma : 0;
var selectedIndices = [];
var selectedScores = [];
while (selectedIndices.length < maxOutputSize && candidates.length > 0) {
var candidate = candidates.pop();
var originalScore = candidate.score, boxIndex = candidate.boxIndex, suppressBeginIndex = candidate.suppressBeginIndex;
if (originalScore < scoreThreshold) {
break;
}
var ignoreCandidate = false;
// Compare only against boxes selected since this candidate was last examined
// (selected indices at positions >= suppressBeginIndex), newest first.
for (var j = selectedIndices.length - 1; j >= suppressBeginIndex; --j) {
var iou = intersectionOverUnion(boxes, boxIndex, selectedIndices[j]);
if (iou >= iouThreshold) {
ignoreCandidate = true;
break;
}
candidate.score = candidate.score * suppressWeight(iouThreshold, scale, iou);
if (candidate.score <= scoreThreshold) {
break;
}
}
candidate.suppressBeginIndex = selectedIndices.length;
if (!ignoreCandidate) {
// An unchanged score means no decay was applied, so the candidate can be
// selected now. A decayed score goes back into the sorted list.
if (candidate.score === originalScore) {
selectedIndices.push(boxIndex);
selectedScores.push(candidate.score);
} else if (candidate.score > scoreThreshold) {
binaryInsert(candidates, candidate, ascendingComparator);
}
}
}
var validOutputs = selectedIndices.length;
var elemsToPad = maxOutputSize - validOutputs;
if (padToMaxOutputSize && elemsToPad > 0) {
selectedIndices.push.apply(selectedIndices, new Array(elemsToPad).fill(0));
selectedScores.push.apply(selectedScores, new Array(elemsToPad).fill(0));
}
var result = {selectedIndices: tensor1d(selectedIndices, "int32")};
if (returnScoresTensor) {
result["selectedScores"] = tensor1d(selectedScores, "float32");
}
if (returnValidOutputs) {
result["validOutputs"] = scalar(validOutputs, "int32");
}
return result;
}
/**
 * Intersection-over-union of boxes `i` and `j` in a flat typed array.
 *
 * Each box occupies four consecutive entries [y1, x1, y2, x2]. The corners
 * may come in either order and are normalised per axis. A box with
 * non-positive area yields 0.
 */
function intersectionOverUnion(boxes, i, j) {
  var a = boxes.subarray(i * 4, i * 4 + 4);
  var b = boxes.subarray(j * 4, j * 4 + 4);
  var aTop = Math.min(a[0], a[2]);
  var aBottom = Math.max(a[0], a[2]);
  var aLeft = Math.min(a[1], a[3]);
  var aRight = Math.max(a[1], a[3]);
  var bTop = Math.min(b[0], b[2]);
  var bBottom = Math.max(b[0], b[2]);
  var bLeft = Math.min(b[1], b[3]);
  var bRight = Math.max(b[1], b[3]);
  var areaA = (aBottom - aTop) * (aRight - aLeft);
  var areaB = (bBottom - bTop) * (bRight - bLeft);
  if (areaA <= 0 || areaB <= 0) {
    return 0;
  }
  // Clamp at zero so disjoint boxes have no overlap.
  var overlapHeight = Math.max(Math.min(aBottom, bBottom) - Math.max(aTop, bTop), 0);
  var overlapWidth = Math.max(Math.min(aRight, bRight) - Math.max(aLeft, bLeft), 0);
  var overlap = overlapHeight * overlapWidth;
  return overlap / (areaA + areaB - overlap);
}
/**
 * Soft-NMS decay factor for a candidate overlapping a selected box.
 *
 * Returns exp(scale * iou^2) when iou <= iouThreshold, and 0 otherwise.
 * The "otherwise" case includes a NaN iou or threshold.
 */
function suppressWeight(iouThreshold, scale, iou) {
  if (!(iou <= iouThreshold)) {
    return 0;
  }
  return Math.exp(scale * iou * iou);
}
/**
 * Sort comparator ordering candidates by ascending score.
 *
 * On equal scores, the larger boxIndex sorts first. Callers pop from the end
 * of the array, so the smaller index is taken first.
 * Returns `false` (treated as 0 by sort) when the scores are incomparable.
 */
function ascendingComparator(c1, c2) {
  var scoreGap = c1.score - c2.score;
  if (scoreGap) {
    return scoreGap;
  }
  if (c1.score !== c2.score) {
    return false;
  }
  return c2.boxIndex - c1.boxIndex;
}
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Asynchronous hard non-max suppression.
 *
 * Downloads the box and score data with `data()` and runs
 * nonMaxSuppressionV3Impl on the CPU-side arrays. Resolves to the int32
 * tensor of selected box indices.
 *
 * @param boxes Boxes tensor (or tensor-like).
 * @param scores Scores tensor (or tensor-like).
 * @param maxOutputSize Maximum number of boxes to select (validated by
 *     nonMaxSuppSanityCheck).
 * @param iouThreshold Defaults to 0.5.
 * @param scoreThreshold Defaults to -Infinity.
 * @returns Promise resolving to the selected-index tensor.
 */
function nonMaxSuppressionAsync_(boxes, scores, maxOutputSize, iouThreshold, scoreThreshold) {
if (iouThreshold === void 0) {
iouThreshold = 0.5;
}
if (scoreThreshold === void 0) {
scoreThreshold = Number.NEGATIVE_INFINITY;
}
// TypeScript-downlevelled async body. Under tslib's __generator protocol,
// `[4, p]` yields (awaits) p and `[2, v]` returns v.
return __awaiter(this, void 0, void 0, function() {
var $boxes, $scores, inputs, boxesAndScores, boxesVals, scoresVals, res;
return __generator(this, function(_a) {
switch (_a.label) {
case 0:
$boxes = convertToTensor(boxes, "boxes", "nonMaxSuppressionAsync");
$scores = convertToTensor(scores, "scores", "nonMaxSuppressionAsync");
inputs = nonMaxSuppSanityCheck($boxes, $scores, maxOutputSize, iouThreshold, scoreThreshold);
maxOutputSize = inputs.maxOutputSize;
iouThreshold = inputs.iouThreshold;
scoreThreshold = inputs.scoreThreshold;
// Read both tensors' data in parallel.
return [4, Promise.all([$boxes.data(), $scores.data()])];
case 1:
boxesAndScores = _a.sent();
boxesVals = boxesAndScores[0];
scoresVals = boxesAndScores[1];
res = nonMaxSuppressionV3Impl(boxesVals, scoresVals, maxOutputSize, iouThreshold, scoreThreshold);
// This function is not wrapped with op(), so it disposes any tensor that
// convertToTensor created here. Caller-owned tensors are left alone.
// NOTE(review): if either data() promise rejects, this cleanup is skipped.
if ($boxes !== boxes) {
$boxes.dispose();
}
if ($scores !== scores) {
$scores.dispose();
}
return [2, res];
}
});
});
}
var nonMaxSuppressionAsync = nonMaxSuppressionAsync_;
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Non-max suppression (optionally soft-NMS) returning indices and scores.
 *
 * Validates the arguments with nonMaxSuppSanityCheck and dispatches the
 * NonMaxSuppressionV5 kernel.
 *
 * @param boxes Boxes tensor (or tensor-like).
 * @param scores Scores tensor (or tensor-like).
 * @param maxOutputSize Maximum number of boxes to select.
 * @param iouThreshold Defaults to 0.5.
 * @param scoreThreshold Defaults to -Infinity.
 * @param softNmsSigma Defaults to 0 (hard NMS).
 * @returns {selectedIndices, selectedScores} from the kernel outputs.
 */
function nonMaxSuppressionWithScore_(boxes, scores, maxOutputSize, iouThreshold, scoreThreshold, softNmsSigma) {
  var requestedIou = iouThreshold === void 0 ? 0.5 : iouThreshold;
  var requestedScore = scoreThreshold === void 0 ? Number.NEGATIVE_INFINITY : scoreThreshold;
  var requestedSigma = softNmsSigma === void 0 ? 0 : softNmsSigma;
  var $boxes = convertToTensor(boxes, "boxes", "nonMaxSuppression");
  var $scores = convertToTensor(scores, "scores", "nonMaxSuppression");
  var checked = nonMaxSuppSanityCheck($boxes, $scores, maxOutputSize, requestedIou, requestedScore, requestedSigma);
  var outputs = ENGINE.runKernel(NonMaxSuppressionV5, {boxes: $boxes, scores: $scores}, {
    maxOutputSize: checked.maxOutputSize,
    iouThreshold: checked.iouThreshold,
    scoreThreshold: checked.scoreThreshold,
    softNmsSigma: checked.softNmsSigma
  });
  return {selectedIndices: outputs[0], selectedScores: outputs[1]};
}
var nonMaxSuppressionWithScore = op({nonMaxSuppressionWithScore_});
/**
 * Asynchronous (soft-)non-max suppression returning indices and scores.
 *
 * Downloads the box and score data with `data()` and runs
 * nonMaxSuppressionV5Impl on the CPU-side arrays.
 *
 * @param boxes Boxes tensor (or tensor-like).
 * @param scores Scores tensor (or tensor-like).
 * @param maxOutputSize Maximum number of boxes to select.
 * @param iouThreshold Defaults to 0.5.
 * @param scoreThreshold Defaults to -Infinity.
 * @param softNmsSigma Defaults to 0 (hard NMS).
 * @returns Promise resolving to {selectedIndices, selectedScores}.
 */
function nonMaxSuppressionWithScoreAsync_(boxes, scores, maxOutputSize, iouThreshold, scoreThreshold, softNmsSigma) {
if (iouThreshold === void 0) {
iouThreshold = 0.5;
}
if (scoreThreshold === void 0) {
scoreThreshold = Number.NEGATIVE_INFINITY;
}
if (softNmsSigma === void 0) {
softNmsSigma = 0;
}
// TypeScript-downlevelled async body. Under tslib's __generator protocol,
// `[4, p]` yields (awaits) p and `[2, v]` returns v.
return __awaiter(this, void 0, void 0, function() {
var $boxes, $scores, params, boxesAndScores, boxesVals, scoresVals, res;
return __generator(this, function(_a) {
switch (_a.label) {
case 0:
$boxes = convertToTensor(boxes, "boxes", "nonMaxSuppressionAsync");
$scores = convertToTensor(scores, "scores", "nonMaxSuppressionAsync");
params = nonMaxSuppSanityCheck($boxes, $scores, maxOutputSize, iouThreshold, scoreThreshold, softNmsSigma);
maxOutputSize = params.maxOutputSize;
iouThreshold = params.iouThreshold;
scoreThreshold = params.scoreThreshold;
softNmsSigma = params.softNmsSigma;
// Read both tensors' data in parallel.
return [4, Promise.all([$boxes.data(), $scores.data()])];
case 1:
boxesAndScores = _a.sent();
boxesVals = boxesAndScores[0];
scoresVals = boxesAndScores[1];
res = nonMaxSuppressionV5Impl(boxesVals, scoresVals, maxOutputSize, iouThreshold, scoreThreshold, softNmsSigma);
// Dispose only the tensors that convertToTensor created here.
// NOTE(review): if either data() promise rejects, this cleanup is skipped.
if ($boxes !== boxes) {
$boxes.dispose();
}
if ($scores !== scores) {
$scores.dispose();
}
return [2, res];
}
});
});
}
var nonMaxSuppressionWithScoreAsync = nonMaxSuppressionWithScoreAsync_;
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Hard non-max suppression with optional zero-padding of the output.
 *
 * Validates the arguments with nonMaxSuppSanityCheck and dispatches the
 * NonMaxSuppressionV4 kernel.
 *
 * @param boxes Boxes tensor (or tensor-like).
 * @param scores Scores tensor (or tensor-like).
 * @param maxOutputSize Maximum number of boxes to select.
 * @param iouThreshold Defaults to 0.5.
 * @param scoreThreshold Defaults to -Infinity.
 * @param padToMaxOutputSize Defaults to false.
 * @returns {selectedIndices, validOutputs} from the kernel outputs.
 */
function nonMaxSuppressionPadded_(boxes, scores, maxOutputSize, iouThreshold, scoreThreshold, padToMaxOutputSize) {
  var requestedIou = iouThreshold === void 0 ? 0.5 : iouThreshold;
  var requestedScore = scoreThreshold === void 0 ? Number.NEGATIVE_INFINITY : scoreThreshold;
  var shouldPad = padToMaxOutputSize === void 0 ? false : padToMaxOutputSize;
  var $boxes = convertToTensor(boxes, "boxes", "nonMaxSuppression");
  var $scores = convertToTensor(scores, "scores", "nonMaxSuppression");
  var checked = nonMaxSuppSanityCheck($boxes, $scores, maxOutputSize, requestedIou, requestedScore, null);
  var outputs = ENGINE.runKernel(NonMaxSuppressionV4, {boxes: $boxes, scores: $scores}, {
    maxOutputSize: checked.maxOutputSize,
    iouThreshold: checked.iouThreshold,
    scoreThreshold: checked.scoreThreshold,
    padToMaxOutputSize: shouldPad
  });
  return {selectedIndices: outputs[0], validOutputs: outputs[1]};
}
var nonMaxSuppressionPadded = op({nonMaxSuppressionPadded_});
/**
 * Asynchronous hard non-max suppression with optional zero-padding.
 *
 * Downloads the box and score data with `data()` and runs
 * nonMaxSuppressionV4Impl on the CPU-side arrays.
 *
 * @param boxes Boxes tensor (or tensor-like).
 * @param scores Scores tensor (or tensor-like).
 * @param maxOutputSize Maximum number of boxes to select.
 * @param iouThreshold Defaults to 0.5.
 * @param scoreThreshold Defaults to -Infinity.
 * @param padToMaxOutputSize Defaults to false.
 * @returns Promise resolving to {selectedIndices, validOutputs}.
 */
function nonMaxSuppressionPaddedAsync_(boxes, scores, maxOutputSize, iouThreshold, scoreThreshold, padToMaxOutputSize) {
if (iouThreshold === void 0) {
iouThreshold = 0.5;
}
if (scoreThreshold === void 0) {
scoreThreshold = Number.NEGATIVE_INFINITY;
}
if (padToMaxOutputSize === void 0) {
padToMaxOutputSize = false;
}
// TypeScript-downlevelled async body. Under tslib's __generator protocol,
// `[4, p]` yields (awaits) p and `[2, v]` returns v.
return __awaiter(this, void 0, void 0, function() {
var $boxes, $scores, params, $maxOutputSize, $iouThreshold, $scoreThreshold, _a, boxesVals, scoresVals, res;
return __generator(this, function(_b) {
switch (_b.label) {
case 0:
$boxes = convertToTensor(boxes, "boxes", "nonMaxSuppressionAsync");
$scores = convertToTensor(scores, "scores", "nonMaxSuppressionAsync");
params = nonMaxSuppSanityCheck($boxes, $scores, maxOutputSize, iouThreshold, scoreThreshold, null);
$maxOutputSize = params.maxOutputSize;
$iouThreshold = params.iouThreshold;
$scoreThreshold = params.scoreThreshold;
// Read both tensors' data in parallel.
return [4, Promise.all([$boxes.data(), $scores.data()])];
case 1:
_a = _b.sent(), boxesVals = _a[0], scoresVals = _a[1];
res = nonMaxSuppressionV4Impl(boxesVals, scoresVals, $maxOutputSize, $iouThreshold, $scoreThreshold, padToMaxOutputSize);
// Dispose only the tensors that convertToTensor created here.
// NOTE(review): if either data() promise rejects, this cleanup is skipped.
if ($boxes !== boxes) {
$boxes.dispose();
}
if ($scores !== scores) {
$scores.dispose();
}
return [2, res];
}
});
});
}
var nonMaxSuppressionPaddedAsync = nonMaxSuppressionPaddedAsync_;
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Resizes images with bilinear interpolation.
 *
 * Accepts a rank-3 or rank-4 tensor. A rank-3 input gets a leading dim of 1
 * for the kernel call, and that dim is removed from the result afterwards.
 * The layout is presumably [height, width, channels] per image; confirm
 * against the backend kernel.
 *
 * @param images Rank-3 or rank-4 tensor (or tensor-like).
 * @param size [newHeight, newWidth].
 * @param alignCorners Forwarded to the kernel; defaults to false.
 * @returns Resized tensor with the same rank as the input.
 */
function resizeBilinear_(images, size, alignCorners) {
  var corners = alignCorners === void 0 ? false : alignCorners;
  var $images = convertToTensor(images, "images", "resizeBilinear");
  var inputRank = $images.rank;
  assert(inputRank === 3 || inputRank === 4, function() {
    return "Error in resizeBilinear: x must be rank 3 or 4, but got " + ("rank " + $images.rank + ".");
  });
  assert(size.length === 2, function() {
    return "Error in resizeBilinear: new shape must 2D, but got shape " + (size + ".");
  });
  var isUnbatched = inputRank === 3;
  var batched = isUnbatched ? reshape($images, [1].concat($images.shape)) : $images;
  var targetHeight = size[0];
  var targetWidth = size[1];
  var out = ENGINE.runKernelFunc(function(backend2, save) {
    save([batched]);
    return backend2.resizeBilinear(batched, targetHeight, targetWidth, corners);
  }, {images: batched}, null, ResizeBilinear, {alignCorners: corners, size: size});
  return isUnbatched ? reshape(out, out.shape.slice(1)) : out;
}
var resizeBilinear = op({resizeBilinear_});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Resizes images with nearest-neighbor sampling.
 *
 * Accepts a rank-3 or rank-4 tensor with dtype int32 or float32. A rank-3
 * input gets a leading dim of 1 for the kernel call, and that dim is removed
 * from the result afterwards.
 *
 * @param images Rank-3 or rank-4 tensor (or tensor-like).
 * @param size [newHeight, newWidth].
 * @param alignCorners Forwarded to the kernel; defaults to false.
 * @returns Resized tensor with the same rank as the input.
 */
function resizeNearestNeighbor_(images, size, alignCorners) {
  var corners = alignCorners === void 0 ? false : alignCorners;
  var $images = convertToTensor(images, "images", "resizeNearestNeighbor");
  var inputRank = $images.rank;
  assert(inputRank === 3 || inputRank === 4, function() {
    return "Error in resizeNearestNeighbor: x must be rank 3 or 4, but got " + ("rank " + $images.rank + ".");
  });
  assert(size.length === 2, function() {
    return "Error in resizeNearestNeighbor: new shape must 2D, but got shape " + (size + ".");
  });
  assert($images.dtype === "float32" || $images.dtype === "int32", function() {
    return "`images` must have `int32` or `float32` as dtype";
  });
  var isUnbatched = inputRank === 3;
  var batched = isUnbatched ? reshape($images, [1].concat($images.shape)) : $images;
  var targetHeight = size[0];
  var targetWidth = size[1];
  var out = ENGINE.runKernelFunc(function(backend2, save) {
    save([batched]);
    return backend2.resizeNearestNeighbor(batched, targetHeight, targetWidth, corners);
  }, {images: batched}, null, ResizeNearestNeighbor, {alignCorners: corners, size: size});
  return isUnbatched ? reshape(out, out.shape.slice(1)) : out;
}
var resizeNearestNeighbor = op({resizeNearestNeighbor_});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Keeps a diagonal band of every innermost [M, N] matrix and zeroes the rest.
 *
 * Element (i, j) survives when -numUpper <= i - j <= numLower. A negative
 * numLower (numUpper) keeps the whole lower (upper) triangle. Leading dims
 * are treated as a batch.
 *
 * @param a Tensor (or tensor-like) of rank >= 2.
 * @param numLower Integer number of sub-diagonals to keep (negative = all);
 *     must not exceed M.
 * @param numUpper Integer number of super-diagonals to keep (negative = all);
 *     must not exceed N.
 * @returns Tensor reshaped back to the shape of `a`.
 */
function bandPart_(a, numLower, numUpper) {
  assert(numLower % 1 === 0, function() {
    return "bandPart(): numLower must be an integer, got " + numLower + ".";
  });
  assert(numUpper % 1 === 0, function() {
    return "bandPart(): numUpper must be an integer, got " + numUpper + ".";
  });
  var $a = convertToTensor(a, "a", "bandPart");
  assert($a.rank >= 2, function() {
    return "bandPart(): Rank must be at least 2, got " + $a.rank + ".";
  });
  var originalShape = $a.shape;
  var rows = originalShape[originalShape.length - 2];
  var cols = originalShape[originalShape.length - 1];
  if (!(numLower <= rows)) {
    throw new Error("bandPart(): numLower (" + numLower + ")" + (" must not be greater than the number of rows (" + rows + ")."));
  }
  if (!(numUpper <= cols)) {
    throw new Error("bandPart(): numUpper (" + numUpper + ")" + (" must not be greater than the number of columns (" + cols + ")."));
  }
  var lower = numLower < 0 ? rows : numLower;
  var upper = numUpper < 0 ? cols : numUpper;
  // (i - j) offsets for every cell, built by broadcasting a column of row
  // indices against a row of column indices.
  var rowIdx = reshape(range(0, rows, 1, "int32"), [-1, 1]);
  var colIdx = range(0, cols, 1, "int32");
  var offsets = sub(rowIdx, colIdx);
  var withinLower = lessEqual(offsets, scalar(+lower, "int32"));
  var withinUpper = greaterEqual(offsets, scalar(-upper, "int32"));
  var mask = logicalAnd(withinLower, withinUpper);
  var zeroMatrix = zeros([rows, cols], $a.dtype);
  var matrices = unstack(reshape($a, [-1, rows, cols]));
  var banded = matrices.map(function(matrix) {
    return where(mask, matrix, zeroMatrix);
  });
  return reshape(stack(banded), originalShape);
}
var bandPart = op({bandPart_});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Orthonormalizes a set of vectors with the Gram-Schmidt process.
 *
 * Each vector has its projections onto the already-orthonormalized vectors
 * subtracted, and is then divided by its Euclidean norm.
 *
 * @param xs Either an array of 1-D tensors of equal length, or a 2-D tensor
 *     whose rows are the vectors.
 * @returns An array of 1-D tensors for array input, or a 2-D tensor (rows
 *     stacked on axis 0) for tensor input.
 */
function gramSchmidt_(xs) {
  var inputIsTensor2D = !Array.isArray(xs);
  if (inputIsTensor2D) {
    // Split the 2-D tensor into rows and squeeze each to 1-D.
    xs = split(xs, xs.shape[0], 0).map(function(row) {
      return squeeze(row, [0]);
    });
  } else {
    assert(xs != null && xs.length > 0, function() {
      return "Gram-Schmidt process: input must not be null, undefined, or empty";
    });
    var expectedDim = xs[0].shape[0];
    var checkDim = function(vec) {
      assert(vec.shape[0] === expectedDim, function() {
        return "Gram-Schmidt: Non-unique lengths found in the input vectors: " + ("(" + vec.shape[0] + " vs. " + expectedDim + ")");
      });
    };
    for (var k = 1; k < xs.length; ++k) {
      checkDim(xs[k]);
    }
  }
  assert(xs.length <= xs[0].shape[0], function() {
    return "Gram-Schmidt: Number of vectors (" + xs.length + ") exceeds " + ("number of dimensions (" + xs[0].shape[0] + ").");
  });
  var basis = [];
  // Each vector is processed in its own tidy scope so intermediates are freed.
  var orthonormalize = function(idx) {
    return ENGINE.tidy(function() {
      var v = xs[idx];
      for (var prev = 0; prev < idx; ++prev) {
        v = sub(v, mul(sum$1(mul(basis[prev], v)), basis[prev]));
      }
      return div(v, norm(v, "euclidean"));
    });
  };
  for (var idx = 0; idx < xs.length; ++idx) {
    basis.push(orthonormalize(idx));
  }
  return inputIsTensor2D ? stack(basis, 0) : basis;
}
var gramSchmidt = op({gramSchmidt_});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * QR decomposition of a matrix or a batch of matrices.
 *
 * Rank-2 input goes straight to qr2d. Higher-rank input is treated as a batch
 * over all leading dimensions:
 *   1. flatten to [batch, m, n];
 *   2. decompose each matrix;
 *   3. stack the per-matrix results and reshape them back under the original
 *      leading dimensions.
 *
 * @param x Tensor of rank >= 2 whose last two dims are [m, n].
 * @param fullMatrices Passed through to qr2d. With k = min(m, n):
 *     - false (default): q has trailing shape [m, k] and r has [k, n];
 *     - true: q is [m, m] and r is [m, n].
 * @returns [q, r]
 */
function qr_(x, fullMatrices) {
  if (fullMatrices === void 0) {
    fullMatrices = false;
  }
  assert(x.rank >= 2, function() {
    return "qr() requires input tensor to have a rank >= 2, but got rank " + x.rank;
  });
  if (x.rank === 2) {
    return qr2d(x, fullMatrices);
  }
  var rank = x.shape.length;
  var outerShape = x.shape.slice(0, rank - 2);
  var m = x.shape[rank - 2];
  var n = x.shape[rank - 1];
  var outerDimsProd = outerShape.reduce(function(value, prev) {
    return value * prev;
  });
  var x2ds = unstack(reshape(x, [outerDimsProd, m, n]), 0);
  var q2ds = [];
  var r2ds = [];
  x2ds.forEach(function(x2d) {
    var factors = qr2d(x2d, fullMatrices);
    q2ds.push(factors[0]);
    r2ds.push(factors[1]);
  });
  // q and r share x's trailing [m, n] shape only for square matrices, so
  // reshape each to the shape qr2d actually produces. Reshaping both to
  // x.shape would throw whenever m !== n.
  var k = Math.min(m, n);
  var qShape = outerShape.concat(fullMatrices ? [m, m] : [m, k]);
  var rShape = outerShape.concat(fullMatrices ? [m, n] : [k, n]);
  var q = reshape(stack(q2ds, 0), qShape);
  var r = reshape(stack(r2ds, 0), rShape);
  return [q, r];
}
/**
 * QR decomposition of a single 2-D tensor using Householder reflections.
 *
 * Setup: q = I (m x m) and r = x. Each of the min(m, n) iterations builds a
 * reflector H = I - tau * w * w^T from column j of r (rows j..m-1), then:
 *   - applies it to rows j.. of r;
 *   - applies it to columns j.. of q.
 *
 * @param x 2-D tensor of shape [m, n].
 * @param fullMatrices When false (default) and m > n, q is trimmed to [m, n]
 *     and r to [n, n]. Otherwise q is [m, m] and r is [m, n].
 * @returns [q, r]
 * NOTE(review): if column j of r is all zeros from row j down, normX and u1
 * are both 0 and the divisions below yield NaN. Confirm whether
 * rank-deficient inputs need handling.
 */
function qr2d(x, fullMatrices) {
if (fullMatrices === void 0) {
fullMatrices = false;
}
return ENGINE.tidy(function() {
assert(x.shape.length === 2, function() {
return "qr2d() requires a 2D Tensor, but got a " + x.shape.length + "D Tensor.";
});
var m = x.shape[0];
var n = x.shape[1];
var q = eye(m);
var r = clone(x);
var one2D = tensor2d([[1]], [1, 1]);
var w = clone(one2D);
// Number of reflections: min(m, n).
var iters = m >= n ? n : m;
var _loop_1 = function(j2) {
var _a;
// Keep handles to this iteration's inputs so they can be disposed once
// the inner tidy has produced replacements.
var rTemp = r;
var wTemp = w;
var qTemp = q;
_a = ENGINE.tidy(function() {
// Column j2 of r from the diagonal down, and its norm.
var rjEnd1 = slice(r, [j2, j2], [m - j2, 1]);
var normX = norm(rjEnd1);
var rjj = slice(r, [j2, j2], [1, 1]);
// s = -1 when the diagonal entry is positive, else 1.
var s = where(greater(rjj, 0), tensor2d([[-1]]), tensor2d([[1]]));
var u1 = sub(rjj, mul(s, normX));
var wPre = div(rjEnd1, u1);
// Householder vector w: first entry forced to 1, rest taken from wPre.
if (wPre.shape[0] === 1) {
w = clone(one2D);
} else {
w = concat([
one2D,
slice(wPre, [1, 0], [wPre.shape[0] - 1, wPre.shape[1]])
], 0);
}
var tau = neg(div(matMul(s, u1), normX));
// Apply the reflector to rows j2.. of r; rows above j2 are unchanged.
var rjEndAll = slice(r, [j2, 0], [m - j2, n]);
var tauTimesW = mul(tau, w);
var wT = transpose(w);
if (j2 === 0) {
r = sub(rjEndAll, matMul(tauTimesW, matMul(wT, rjEndAll)));
} else {
var rTimesTau = sub(rjEndAll, matMul(tauTimesW, matMul(wT, rjEndAll)));
r = concat([slice(r, [0, 0], [j2, n]), rTimesTau], 0);
}
// Accumulate the reflector into columns j2.. of q.
var tawTimesWT = transpose(tauTimesW);
var qAllJEnd = slice(q, [0, j2], [m, q.shape[1] - j2]);
if (j2 === 0) {
q = sub(qAllJEnd, matMul(matMul(qAllJEnd, w), tawTimesWT));
} else {
var qTimesTau = sub(qAllJEnd, matMul(matMul(qAllJEnd, w), tawTimesWT));
q = concat([slice(q, [0, 0], [m, j2]), qTimesTau], 1);
}
return [w, r, q];
}), w = _a[0], r = _a[1], q = _a[2];
dispose([rTemp, wTemp, qTemp]);
};
for (var j = 0; j < iters; ++j) {
_loop_1(j);
}
// Reduced ("economy") form for tall matrices.
if (!fullMatrices && m > n) {
q = slice(q, [0, 0], [m, n]);
r = slice(r, [0, 0], [n, n]);
}
return [q, r];
});
}
// Public op wrapper around qr_.
var qr = op({qr_});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
// Populates exports.Reduction as a TypeScript-style numeric enum with
// reverse mapping: Reduction.NONE === 0 and Reduction[0] === "NONE", etc.
(function(Reduction) {
  ["NONE", "MEAN", "SUM", "SUM_BY_NONZERO_WEIGHTS"].forEach(function(label, code) {
    Reduction[Reduction[label] = code] = label;
  });
})(exports.Reduction || (exports.Reduction = {}));
/**
 * Applies optional weights to a loss tensor and reduces it.
 *
 * Reduction modes:
 *   - NONE: the weighted losses unchanged.
 *   - SUM: their sum.
 *   - MEAN: without weights, the mean. With weights, sum(weighted) /
 *     sum(weights), further divided by losses.size / weights.size when that
 *     broadcast factor exceeds 1.
 *   - SUM_BY_NONZERO_WEIGHTS (default): sum(weighted) divided by the number of
 *     non-zero broadcast weights, or by losses.size when there are no weights.
 *
 * @param losses2 Loss tensor (or tensor-like).
 * @param weights Optional weights tensor (or tensor-like); null/undefined
 *     means unweighted.
 * @param reduction A value of exports.Reduction.
 * @throws Error for an unknown reduction value.
 */
function computeWeightedLoss_(losses2, weights, reduction) {
  if (reduction === void 0) {
    reduction = exports.Reduction.SUM_BY_NONZERO_WEIGHTS;
  }
  var $losses = convertToTensor(losses2, "losses", "computeWeightedLoss");
  var $weights = weights != null ? convertToTensor(weights, "weights", "computeWeightedLoss") : null;
  var weightedLoss = $weights == null ? $losses : mul($losses, $weights);
  var modes = exports.Reduction;
  switch (reduction) {
    case modes.NONE:
      return weightedLoss;
    case modes.SUM:
      return sum$1(weightedLoss);
    case modes.MEAN: {
      if ($weights == null) {
        return mean(weightedLoss);
      }
      var broadcastFactor = $losses.size / $weights.size;
      var averaged = div(sum$1(weightedLoss), sum$1($weights));
      return broadcastFactor > 1 ? div(averaged, scalar(broadcastFactor)) : averaged;
    }
    case modes.SUM_BY_NONZERO_WEIGHTS: {
      if ($weights == null) {
        return div(sum$1(weightedLoss), scalar($losses.size));
      }
      var broadcastedWeights = mul($weights, ones$1($losses.shape));
      var nonZeroCount = cast(sum$1(notEqual(broadcastedWeights, scalar(0))), "float32");
      return div(sum$1(weightedLoss), nonZeroCount);
    }
  }
  throw Error("Unknown reduction: " + reduction);
}
var computeWeightedLoss = op({computeWeightedLoss_});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Absolute-difference loss: |labels - predictions|, weighted and reduced by
 * computeWeightedLoss.
 *
 * @param labels Ground-truth tensor (or tensor-like).
 * @param predictions Tensor with the same shape as labels.
 * @param weights Optional weights.
 * @param reduction Defaults to Reduction.SUM_BY_NONZERO_WEIGHTS.
 */
function absoluteDifference_(labels, predictions, weights, reduction) {
  var mode = reduction === void 0 ? exports.Reduction.SUM_BY_NONZERO_WEIGHTS : reduction;
  var $labels = convertToTensor(labels, "labels", "absoluteDifference");
  var $predictions = convertToTensor(predictions, "predictions", "absoluteDifference");
  var $weights = weights == null ? null : convertToTensor(weights, "weights", "absoluteDifference");
  assertShapesMatch($labels.shape, $predictions.shape, "Error in absoluteDifference: ");
  var residual = sub($labels, $predictions);
  return computeWeightedLoss(abs(residual), $weights, mode);
}
var absoluteDifference = op({absoluteDifference_});
/**
 * Cosine-distance loss: 1 - sum(labels * predictions) along `axis`, keeping
 * dims, then weighted and reduced by computeWeightedLoss.
 *
 * The inputs are presumably unit-normalized along `axis`; this function does
 * not normalize them.
 *
 * @param labels Ground-truth tensor (or tensor-like).
 * @param predictions Tensor with the same shape as labels.
 * @param axis Axis for the dot product.
 * @param weights Optional weights.
 * @param reduction Defaults to Reduction.SUM_BY_NONZERO_WEIGHTS.
 */
function cosineDistance_(labels, predictions, axis, weights, reduction) {
  var mode = reduction === void 0 ? exports.Reduction.SUM_BY_NONZERO_WEIGHTS : reduction;
  var $labels = convertToTensor(labels, "labels", "cosineDistance");
  var $predictions = convertToTensor(predictions, "predictions", "cosineDistance");
  var $weights = weights == null ? null : convertToTensor(weights, "weights", "cosineDistance");
  assertShapesMatch($labels.shape, $predictions.shape, "Error in cosineDistance: ");
  var one = scalar(1);
  var dot = sum$1(mul($labels, $predictions), axis, true);
  return computeWeightedLoss(sub(one, dot), $weights, mode);
}
var cosineDistance = op({cosineDistance_});
/**
 * Hinge loss: relu(1 - (2 * labels - 1) * predictions), weighted and reduced
 * by computeWeightedLoss.
 *
 * Labels are mapped through 2 * labels - 1, which assumes they are given as
 * 0/1 (mapped to -1/+1). Confirm with callers.
 *
 * @param labels Ground-truth tensor (or tensor-like).
 * @param predictions Tensor with the same shape as labels.
 * @param weights Optional weights.
 * @param reduction Defaults to Reduction.SUM_BY_NONZERO_WEIGHTS.
 */
function hingeLoss_(labels, predictions, weights, reduction) {
  var mode = reduction === void 0 ? exports.Reduction.SUM_BY_NONZERO_WEIGHTS : reduction;
  var $labels = convertToTensor(labels, "labels", "hingeLoss");
  var $predictions = convertToTensor(predictions, "predictions", "hingeLoss");
  var $weights = weights == null ? null : convertToTensor(weights, "weights", "hingeLoss");
  assertShapesMatch($labels.shape, $predictions.shape, "Error in hingeLoss: ");
  var one = scalar(1);
  var signedLabels = sub(mul(scalar(2), $labels), one);
  var margin = sub(one, mul(signedLabels, $predictions));
  return computeWeightedLoss(relu(margin), $weights, mode);
}
var hingeLoss = op({hingeLoss_});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Huber loss, weighted and reduced by computeWeightedLoss.
 *
 * With e = |predictions - labels| and q = min(e, delta), the per-element loss
 * is 0.5 * q^2 + delta * (e - q): quadratic inside delta, linear outside.
 *
 * @param labels Ground-truth tensor (or tensor-like).
 * @param predictions Tensor with the same shape as labels.
 * @param weights Optional weights.
 * @param delta Transition point; defaults to 1.
 * @param reduction Defaults to Reduction.SUM_BY_NONZERO_WEIGHTS.
 */
function huberLoss_(labels, predictions, weights, delta, reduction) {
  var threshold = delta === void 0 ? 1 : delta;
  var mode = reduction === void 0 ? exports.Reduction.SUM_BY_NONZERO_WEIGHTS : reduction;
  var $labels = convertToTensor(labels, "labels", "huberLoss");
  var $predictions = convertToTensor(predictions, "predictions", "huberLoss");
  var $weights = weights == null ? null : convertToTensor(weights, "weights", "huberLoss");
  assertShapesMatch($labels.shape, $predictions.shape, "Error in huberLoss: ");
  var deltaScalar = scalar(threshold);
  var absError = abs(sub($predictions, $labels));
  var clipped = minimum(absError, deltaScalar);
  var excess = sub(absError, clipped);
  var quadraticPart = mul(scalar(0.5), square(clipped));
  var linearPart = mul(deltaScalar, excess);
  return computeWeightedLoss(add$1(quadraticPart, linearPart), $weights, mode);
}
var huberLoss = op({huberLoss_});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Log (binary cross-entropy) loss on probabilities, weighted and reduced by
 * computeWeightedLoss.
 *
 * Per element: -labels * log(predictions + eps)
 *              - (1 - labels) * log(1 - predictions + eps).
 *
 * @param labels Ground-truth tensor (or tensor-like).
 * @param predictions Tensor with the same shape as labels.
 * @param weights Optional weights.
 * @param epsilon Added inside the logs to avoid log(0); defaults to 1e-7.
 * @param reduction Defaults to Reduction.SUM_BY_NONZERO_WEIGHTS.
 */
function logLoss_(labels, predictions, weights, epsilon, reduction) {
  var eps = epsilon === void 0 ? 1e-7 : epsilon;
  var mode = reduction === void 0 ? exports.Reduction.SUM_BY_NONZERO_WEIGHTS : reduction;
  var $labels = convertToTensor(labels, "labels", "logLoss");
  var $predictions = convertToTensor(predictions, "predictions", "logLoss");
  var $weights = weights == null ? null : convertToTensor(weights, "weights", "logLoss");
  assertShapesMatch($labels.shape, $predictions.shape, "Error in logLoss: ");
  var one = scalar(1);
  var epsilonScalar = scalar(eps);
  var positiveTerm = neg(mul($labels, log(add$1($predictions, epsilonScalar))));
  var negativeLabels = sub(one, $labels);
  var negativeTerm = mul(negativeLabels, log(add$1(sub(one, $predictions), epsilonScalar)));
  return computeWeightedLoss(sub(positiveTerm, negativeTerm), $weights, mode);
}
var logLoss = op({logLoss_});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Squared-error loss: (labels - predictions)^2 via squaredDifference,
 * weighted and reduced by computeWeightedLoss.
 *
 * @param labels Ground-truth tensor (or tensor-like).
 * @param predictions Tensor with the same shape as labels.
 * @param weights Optional weights.
 * @param reduction Defaults to Reduction.SUM_BY_NONZERO_WEIGHTS.
 */
function meanSquaredError_(labels, predictions, weights, reduction) {
  var mode = reduction === void 0 ? exports.Reduction.SUM_BY_NONZERO_WEIGHTS : reduction;
  var $labels = convertToTensor(labels, "labels", "meanSquaredError");
  var $predictions = convertToTensor(predictions, "predictions", "meanSquaredError");
  var $weights = weights == null ? null : convertToTensor(weights, "weights", "meanSquaredError");
  assertShapesMatch($labels.shape, $predictions.shape, "Error in meanSquaredError: ");
  return computeWeightedLoss(squaredDifference($labels, $predictions), $weights, mode);
}
var meanSquaredError = op({meanSquaredError_});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Element-wise sigmoid cross-entropy computed from logits.
 *
 * Uses max(x, 0) - x * z + log1p(exp(-|x|)). This avoids exponentiating
 * large positive logits.
 *
 * @param labels Targets z (tensor or tensor-like).
 * @param logits Logits x with the same shape as labels.
 * @returns Unreduced per-element loss tensor.
 */
function sigmoidCrossEntropyWithLogits_(labels, logits) {
  var $labels = convertToTensor(labels, "labels", "sigmoidCrossEntropyWithLogits");
  var $logits = convertToTensor(logits, "logits", "sigmoidCrossEntropyWithLogits");
  assertShapesMatch($labels.shape, $logits.shape, "Error in sigmoidCrossEntropyWithLogits: ");
  var positivePart = relu($logits);
  var crossTerm = mul($logits, $labels);
  var softplusTail = log1p(exp(neg(abs($logits))));
  var linearPart = sub(positivePart, crossTerm);
  return add$1(linearPart, softplusTail);
}
/**
 * Sigmoid cross-entropy loss for (multi-label) targets.
 *
 * @param multiClassLabels Targets; must have the same shape as `logits`.
 * @param logits Unnormalised predictions.
 * @param weights Optional weights forwarded to computeWeightedLoss
 *     (null/undefined for none).
 * @param labelSmoothing When > 0, targets become z * (1 - s) + 0.5 * s.
 *     Defaults to 0.
 * @param reduction Reduction forwarded to computeWeightedLoss; defaults to
 *     Reduction.SUM_BY_NONZERO_WEIGHTS.
 */
function sigmoidCrossEntropy_(multiClassLabels, logits, weights, labelSmoothing, reduction) {
  if (labelSmoothing === void 0) {
    labelSmoothing = 0;
  }
  if (reduction === void 0) {
    reduction = exports.Reduction.SUM_BY_NONZERO_WEIGHTS;
  }
  var targets = convertToTensor(multiClassLabels, "multiClassLabels", "sigmoidCrossEntropy");
  var logitsT = convertToTensor(logits, "logits", "sigmoidCrossEntropy");
  var weightsT = weights == null ? null : convertToTensor(weights, "weights", "sigmoidCrossEntropy");
  assertShapesMatch(targets.shape, logitsT.shape, "Error in sigmoidCrossEntropy: ");
  if (labelSmoothing > 0) {
    var smoothing = scalar(labelSmoothing);
    var one = scalar(1);
    var half = scalar(0.5);
    // Pull every target toward 0.5 by a fraction `labelSmoothing`.
    var keptShare = mul(targets, sub(one, smoothing));
    var spreadShare = mul(half, smoothing);
    targets = add$1(keptShare, spreadShare);
  }
  return computeWeightedLoss(sigmoidCrossEntropyWithLogits_(targets, logitsT), weightsT, reduction);
}
var sigmoidCrossEntropy = op({sigmoidCrossEntropy_});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Softmax cross-entropy along `dim`, wrapped in customGrad.
 *
 * Forward:
 *   logProbs = cast(logits, "float32") - logSumExp(logits, [dim], keepDims = true)
 *   value    = sum(-(logProbs * labels), [dim])
 * Gradients returned (dy is reshaped to keep `dim`):
 *   for labels: dy * (cast(labels) - exp(logProbs))
 *   for logits: dy * (exp(logProbs) - cast(labels))
 *
 * @param labels Target distribution tensor.
 * @param logits Logit tensor.
 * @param dim Axis to reduce; omitted or -1 means the last axis.
 * @throws Error when `dim` is not the last axis (only that case is supported).
 */
function softmaxCrossEntropyWithLogits_(labels, logits, dim) {
  var lastAxis = logits.rank - 1;
  if (dim === void 0 || dim === -1) {
    dim = lastAxis;
  }
  if (dim !== lastAxis) {
    throw Error("Softmax cross entropy along a non-last dimension is not yet supported. Labels / logits was rank " + logits.rank + " and dim was " + dim);
  }
  var forwardAndBackward = function(labelsIn, logitsIn, save) {
    var logNormalizer = logSumExp(logitsIn, [dim], true);
    var logProbs = sub(cast(logitsIn, "float32"), logNormalizer);
    save([labelsIn, logProbs]);
    var perElementLoss = neg(mul(logProbs, labelsIn));
    var loss = sum$1(perElementLoss, [dim]);
    function backward(dy, saved) {
      var savedLabels = saved[0];
      var savedLogProbs = saved[1];
      // Re-insert the reduced axis so dy broadcasts against the full shape.
      var broadcastShape = expandShapeToKeepDim(dy.shape, [dim]);
      var labelsGrad = mul(reshape(dy, broadcastShape), sub(cast(savedLabels, "float32"), exp(savedLogProbs)));
      var logitsGrad = mul(reshape(dy, broadcastShape), sub(exp(savedLogProbs), cast(savedLabels, "float32")));
      return [labelsGrad, logitsGrad];
    }
    return {value: loss, gradFunc: backward};
  };
  return customGrad(forwardAndBackward)(labels, logits);
}
/**
 * Softmax cross-entropy loss for one-hot targets.
 *
 * @param onehotLabels One-hot targets; must have the same shape as `logits`.
 * @param logits Unnormalised predictions; classes lie on the last axis
 *     (the only axis softmaxCrossEntropyWithLogits_ supports).
 * @param weights Optional weights forwarded to computeWeightedLoss
 *     (null/undefined for none).
 * @param labelSmoothing When > 0, targets become y * (1 - s) + s / numClasses,
 *     where numClasses is the size of the last axis. Defaults to 0.
 * @param reduction Reduction forwarded to computeWeightedLoss; defaults to
 *     Reduction.SUM_BY_NONZERO_WEIGHTS.
 */
function softmaxCrossEntropy_(onehotLabels, logits, weights, labelSmoothing, reduction) {
  if (labelSmoothing === void 0) {
    labelSmoothing = 0;
  }
  if (reduction === void 0) {
    reduction = exports.Reduction.SUM_BY_NONZERO_WEIGHTS;
  }
  var $onehotLabels = convertToTensor(onehotLabels, "onehotLabels", "softmaxCrossEntropy");
  var $logits = convertToTensor(logits, "logits", "softmaxCrossEntropy");
  var $weights = null;
  if (weights != null) {
    $weights = convertToTensor(weights, "weights", "softmaxCrossEntropy");
  }
  assertShapesMatch($onehotLabels.shape, $logits.shape, "Error in softmaxCrossEntropy: ");
  if (labelSmoothing > 0) {
    var labelSmoothingScalar = scalar(labelSmoothing);
    var one = scalar(1);
    // The class axis is the last one (the reduction below only supports
    // dim = rank - 1). Hard-coding shape[1] was only right for rank-2 input:
    // rank 1 produced scalar(undefined) and rank > 2 used the wrong size.
    var labelsShape = $onehotLabels.shape;
    var numClasses = scalar(labelsShape[labelsShape.length - 1]);
    $onehotLabels = add$1(mul($onehotLabels, sub(one, labelSmoothingScalar)), div(labelSmoothingScalar, numClasses));
  }
  var losses2 = softmaxCrossEntropyWithLogits_($onehotLabels, $logits);
  return computeWeightedLoss(losses2, $weights, reduction);
}
var softmaxCrossEntropy = op({softmaxCrossEntropy_});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
// Namespace objects grouping ops defined elsewhere in this bundle.
// Property order is significant for enumeration and matches the op lists.

// Fourier transforms.
var spectral = {
  fft: fft,
  ifft: ifft,
  rfft: rfft,
  irfft: irfft
};
// Windowing / framing helpers for signal processing.
var signal = {
  hammingWindow: hammingWindow,
  hannWindow: hannWindow,
  frame: frame,
  stft: stft
};
// Image ops: flips, resizes, crops and non-max-suppression variants.
var image = {
  flipLeftRight: flipLeftRight,
  resizeNearestNeighbor: resizeNearestNeighbor,
  resizeBilinear: resizeBilinear,
  rotateWithOffset: rotateWithOffset,
  cropAndResize: cropAndResize,
  nonMaxSuppression: nonMaxSuppression,
  nonMaxSuppressionAsync: nonMaxSuppressionAsync,
  nonMaxSuppressionWithScore: nonMaxSuppressionWithScore,
  nonMaxSuppressionWithScoreAsync: nonMaxSuppressionWithScoreAsync,
  nonMaxSuppressionPadded: nonMaxSuppressionPadded,
  nonMaxSuppressionPaddedAsync: nonMaxSuppressionPaddedAsync
};
// Linear-algebra ops.
var linalg = {
  bandPart: bandPart,
  gramSchmidt: gramSchmidt,
  qr: qr
};
// Loss functions (including the ones defined just above).
var losses = {
  absoluteDifference: absoluteDifference,
  computeWeightedLoss: computeWeightedLoss,
  cosineDistance: cosineDistance,
  hingeLoss: hingeLoss,
  huberLoss: huberLoss,
  logLoss: logLoss,
  meanSquaredError: meanSquaredError,
  sigmoidCrossEntropy: sigmoidCrossEntropy,
  softmaxCrossEntropy: softmaxCrossEntropy
};
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Base class for the optimizers that follow. Subclasses provide
* applyGradients (and getConfig / getWeights / setWeights as needed); this
* class supplies minimize(), a plain-number step counter `iterations_`, and
* helpers that save/restore that counter as a weight entry.
* The async methods are transpiled generator bodies: under the tslib
* __generator helper (defined elsewhere) `[4, p]` awaits p and `[2, v]`
* returns v — assumed from the helper's usual contract.
*/
var Optimizer = function(_super) {
__extends(Optimizer2, _super);
function Optimizer2() {
return _super !== null && _super.apply(this, arguments) || this;
}
/**
* Computes gradients of `f` via computeGradients, applies them with the
* subclass's applyGradients, then disposes the gradients.
* @param f Function producing the cost to minimize.
* @param returnCost If true, the cost tensor is returned; otherwise it is
*     disposed and null is returned. Defaults to false.
* @param varList Optional variables to differentiate against. When given,
*     gradients are passed to applyGradients as [{name, tensor}] in varList
*     order; otherwise the name -> gradient map is passed as-is.
*/
Optimizer2.prototype.minimize = function(f, returnCost, varList) {
if (returnCost === void 0) {
returnCost = false;
}
var _a = this.computeGradients(f, varList), value = _a.value, grads2 = _a.grads;
if (varList != null) {
var gradArray = varList.map(function(v) {
return {name: v.name, tensor: grads2[v.name]};
});
this.applyGradients(gradArray);
} else {
this.applyGradients(grads2);
}
dispose(grads2);
if (returnCost) {
return value;
} else {
value.dispose();
return null;
}
};
// Read-only step counter; lazily initialises `iterations_` to 0 on first read.
Object.defineProperty(Optimizer2.prototype, "iterations", {
get: function() {
if (this.iterations_ == null) {
this.iterations_ = 0;
}
return this.iterations_;
},
enumerable: true,
configurable: true
});
// Advances the step counter by one (subclasses call this after each applyGradients).
Optimizer2.prototype.incrementIterations = function() {
this.iterations_ = this.iterations + 1;
};
// Returns {value, grads} as produced by variableGrads (defined elsewhere).
Optimizer2.prototype.computeGradients = function(f, varList) {
return variableGrads(f, varList);
};
// NOTE(review): `iterations_` is a plain number here, so this dispose() call
// presumably does nothing — confirm how dispose() treats non-tensors.
Optimizer2.prototype.dispose = function() {
if (this.iterations_ != null) {
dispose(this.iterations_);
}
};
// Resolves to the weight entry {name: "iter", tensor: int32 scalar of the step count}.
Optimizer2.prototype.saveIterations = function() {
return __awaiter(this, void 0, void 0, function() {
return __generator(this, function(_a) {
if (this.iterations_ == null) {
this.iterations_ = 0;
}
return [2, {
name: "iter",
tensor: scalar(this.iterations_, "int32")
}];
});
});
};
// Default implementation: always throws; subclasses with state override it.
Optimizer2.prototype.getWeights = function() {
return __awaiter(this, void 0, void 0, function() {
return __generator(this, function(_a) {
throw new Error("getWeights() is not implemented for this optimizer yet.");
});
});
};
// Default implementation: always throws, naming the concrete class.
Optimizer2.prototype.setWeights = function(weightValues) {
return __awaiter(this, void 0, void 0, function() {
return __generator(this, function(_a) {
throw new Error("setWeights() is not implemented for this optimizer class " + ("" + this.getClassName()));
});
});
};
/**
* Inverse of saveIterations: downloads the first entry's tensor data, stores
* its element 0 as `iterations_`, and resolves to the remaining entries.
*/
Optimizer2.prototype.extractIterations = function(weightValues) {
return __awaiter(this, void 0, void 0, function() {
var _a;
return __generator(this, function(_b) {
switch (_b.label) {
case 0:
_a = this;
return [4, weightValues[0].tensor.data()];
case 1:
_a.iterations_ = _b.sent()[0];
return [2, weightValues.slice(1)];
}
});
});
};
return Optimizer2;
}(Serializable);
// Duck-typed `instanceof Optimizer`: any object exposing minimize,
// computeGradients and applyGradients is accepted, not just subclasses.
Object.defineProperty(Optimizer, Symbol.hasInstance, {
value: function(instance) {
return instance.minimize != null && instance.computeGradients != null && instance.applyGradients != null;
}
});
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Adadelta optimizer (Zeiler, 2012). Per variable it keeps two non-trainable
 * accumulators, matched to variables by position:
 *   accum_grad  E[g^2]  - decaying average of squared gradients
 *   accum_var   E[dx^2] - decaying average of squared updates
 * and on each step computes
 *   E[g^2]_t  = rho * E[g^2]_{t-1} + (1 - rho) * g^2
 *   dx        = sqrt(E[dx^2]_{t-1} + eps) / sqrt(E[g^2]_t + eps) * g
 *   E[dx^2]_t = rho * E[dx^2]_{t-1} + (1 - rho) * dx^2
 *   value     = value - learningRate * dx
 */
var AdadeltaOptimizer = function(_super) {
  __extends(AdadeltaOptimizer2, _super);
  /**
   * @param {number} learningRate Multiplier applied to each computed update.
   * @param {number} rho Decay rate of both running averages.
   * @param {number} [epsilon] Stabiliser added under the square roots; when
   *     null/omitted, ENGINE.backend.epsilon() is used.
   */
  function AdadeltaOptimizer2(learningRate, rho, epsilon) {
    if (epsilon === void 0) {
      epsilon = null;
    }
    var _this = _super.call(this) || this;
    _this.learningRate = learningRate;
    _this.rho = rho;
    _this.epsilon = epsilon;
    _this.accumulatedGrads = [];
    _this.accumulatedUpdates = [];
    if (epsilon == null) {
      _this.epsilon = ENGINE.backend.epsilon();
    }
    return _this;
  }
  /**
   * Applies one Adadelta step.
   * @param variableGradients Either an array of {name, tensor} or an object
   *     mapping variable name -> gradient. Accumulators are indexed by position
   *     `i`, so callers should keep the variable order stable across calls.
   *     Variables whose gradient is null get accumulators but are not updated.
   */
  AdadeltaOptimizer2.prototype.applyGradients = function(variableGradients) {
    var _this = this;
    var variableNames = Array.isArray(variableGradients) ? variableGradients.map(function(item) {
      return item.name;
    }) : Object.keys(variableGradients);
    variableNames.forEach(function(name, i) {
      var value = ENGINE.registeredVariables[name];
      var trainable = false;
      if (_this.accumulatedGrads[i] == null) {
        _this.accumulatedGrads[i] = {
          originalName: name + "/accum_grad",
          variable: tidy(function() {
            return zerosLike(value).variable(trainable);
          })
        };
      }
      if (_this.accumulatedUpdates[i] == null) {
        _this.accumulatedUpdates[i] = {
          originalName: name + "/accum_var",
          variable: tidy(function() {
            return zerosLike(value).variable(trainable);
          })
        };
      }
      var gradient = Array.isArray(variableGradients) ? variableGradients[i].tensor : variableGradients[name];
      if (gradient == null) {
        return;
      }
      var accumulatedGrad = _this.accumulatedGrads[i].variable;
      var accumulatedUpdate = _this.accumulatedUpdates[i].variable;
      tidy(function() {
        var newAccumulatedGrad = add$1(mul(accumulatedGrad, _this.rho), mul(square(gradient), 1 - _this.rho));
        // The denominator is RMS[g]_t: it must use the accumulator that already
        // includes this step's gradient (Zeiler 2012, Alg. 1; TF ApplyAdadelta).
        // Using the pre-update accumulator made the first step (accumulator = 0)
        // equal to the raw gradient.
        var updates = mul(div(sqrt(add$1(accumulatedUpdate, _this.epsilon)), sqrt(add$1(newAccumulatedGrad, _this.epsilon))), gradient);
        var newAccumulatedUpdate = add$1(mul(accumulatedUpdate, _this.rho), mul(square(updates), 1 - _this.rho));
        accumulatedGrad.assign(newAccumulatedGrad);
        accumulatedUpdate.assign(newAccumulatedUpdate);
        var newValue = add$1(mul(updates, -_this.learningRate), value);
        value.assign(newValue);
      });
    });
    this.incrementIterations();
  };
  // Releases both sets of accumulator variables.
  AdadeltaOptimizer2.prototype.dispose = function() {
    if (this.accumulatedUpdates != null) {
      dispose(this.accumulatedGrads.map(function(v) {
        return v.variable;
      }));
      dispose(this.accumulatedUpdates.map(function(v) {
        return v.variable;
      }));
    }
  };
  // Resolves to [iterations entry, ...accum_grad entries, ...accum_var entries] as {name, tensor}.
  AdadeltaOptimizer2.prototype.getWeights = function() {
    return __awaiter(this, void 0, void 0, function() {
      var variables;
      return __generator(this, function(_a) {
        switch (_a.label) {
          case 0:
            variables = this.accumulatedGrads.concat(this.accumulatedUpdates);
            return [4, this.saveIterations()];
          case 1:
            return [2, [_a.sent()].concat(variables.map(function(v) {
              return {name: v.originalName, tensor: v.variable};
            }))];
        }
      });
    });
  };
  // Inverse of getWeights: restores the step count, then splits the remaining
  // entries in half into accum_grad and accum_var (non-trainable variables).
  AdadeltaOptimizer2.prototype.setWeights = function(weightValues) {
    return __awaiter(this, void 0, void 0, function() {
      var variableCount, trainable;
      return __generator(this, function(_a) {
        switch (_a.label) {
          case 0:
            return [4, this.extractIterations(weightValues)];
          case 1:
            weightValues = _a.sent();
            variableCount = weightValues.length / 2;
            trainable = false;
            this.accumulatedGrads = weightValues.slice(0, variableCount).map(function(v) {
              return {
                originalName: v.name,
                variable: v.tensor.variable(trainable)
              };
            });
            this.accumulatedUpdates = weightValues.slice(variableCount, variableCount * 2).map(function(v) {
              return {
                originalName: v.name,
                variable: v.tensor.variable(trainable)
              };
            });
            return [2];
        }
      });
    });
  };
  AdadeltaOptimizer2.prototype.getConfig = function() {
    return {
      learningRate: this.learningRate,
      rho: this.rho,
      epsilon: this.epsilon
    };
  };
  AdadeltaOptimizer2.fromConfig = function(cls, config) {
    return new cls(config["learningRate"], config["rho"], config["epsilon"]);
  };
  AdadeltaOptimizer2.className = "Adadelta";
  return AdadeltaOptimizer2;
}(Optimizer);
registerClass(AdadeltaOptimizer);
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Adagrad optimizer. Keeps one non-trainable accumulator per variable,
* filled with `initialAccumulatorValue`, and on each step computes
*   accum = accum + g^2
*   value = value - learningRate * g / sqrt(accum + eps)
* with eps = ENGINE.backend.epsilon().
*/
var AdagradOptimizer = function(_super) {
__extends(AdagradOptimizer2, _super);
/**
* @param learningRate Step size.
* @param initialAccumulatorValue Initial value of every accumulator
*     (defaults to 0.1).
*/
function AdagradOptimizer2(learningRate, initialAccumulatorValue) {
if (initialAccumulatorValue === void 0) {
initialAccumulatorValue = 0.1;
}
var _this = _super.call(this) || this;
_this.learningRate = learningRate;
_this.initialAccumulatorValue = initialAccumulatorValue;
_this.accumulatedGrads = [];
return _this;
}
/**
* Applies one Adagrad step.
* @param variableGradients Either an array of {name, tensor} or an object
*     mapping variable name -> gradient. Accumulators are indexed by position
*     `i`, so callers should keep the variable order stable across calls.
*/
AdagradOptimizer2.prototype.applyGradients = function(variableGradients) {
var _this = this;
var variableNames = Array.isArray(variableGradients) ? variableGradients.map(function(item) {
return item.name;
}) : Object.keys(variableGradients);
variableNames.forEach(function(name, i) {
var value = ENGINE.registeredVariables[name];
// Lazily create the accumulator (this happens even if the gradient below is null).
if (_this.accumulatedGrads[i] == null) {
var trainable_1 = false;
_this.accumulatedGrads[i] = {
originalName: name + "/accumulator",
variable: tidy(function() {
return fill(value.shape, _this.initialAccumulatorValue).variable(trainable_1);
})
};
}
var gradient = Array.isArray(variableGradients) ? variableGradients[i].tensor : variableGradients[name];
// Variables without a gradient are left untouched.
if (gradient == null) {
return;
}
var accumulatedGrad = _this.accumulatedGrads[i].variable;
tidy(function() {
var newAccumulatedGrad = add$1(accumulatedGrad, square(gradient));
accumulatedGrad.assign(newAccumulatedGrad);
var newValue = add$1(mul(div(gradient, sqrt(add$1(newAccumulatedGrad, ENGINE.backend.epsilon()))), -_this.learningRate), value);
value.assign(newValue);
});
});
this.incrementIterations();
};
// Releases the accumulator variables.
AdagradOptimizer2.prototype.dispose = function() {
if (this.accumulatedGrads != null) {
dispose(this.accumulatedGrads.map(function(v) {
return v.variable;
}));
}
};
// Resolves to [iterations entry, ...accumulators as {name, tensor}].
// (Transpiled async body; in the __generator protocol [4, p] awaits p and
// [2, v] returns v — assumed from the usual tslib helper defined elsewhere.)
AdagradOptimizer2.prototype.getWeights = function() {
return __awaiter(this, void 0, void 0, function() {
return __generator(this, function(_a) {
switch (_a.label) {
case 0:
return [4, this.saveIterations()];
case 1:
return [2, [_a.sent()].concat(this.accumulatedGrads.map(function(v) {
return {name: v.originalName, tensor: v.variable};
}))];
}
});
});
};
// Inverse of getWeights: restores the step count, then wraps every remaining
// entry as a non-trainable accumulator variable.
AdagradOptimizer2.prototype.setWeights = function(weightValues) {
return __awaiter(this, void 0, void 0, function() {
var trainable;
return __generator(this, function(_a) {
switch (_a.label) {
case 0:
return [4, this.extractIterations(weightValues)];
case 1:
weightValues = _a.sent();
trainable = false;
this.accumulatedGrads = weightValues.map(function(v) {
return {originalName: v.name, variable: v.tensor.variable(trainable)};
});
return [2];
}
});
});
};
AdagradOptimizer2.prototype.getConfig = function() {
return {
learningRate: this.learningRate,
initialAccumulatorValue: this.initialAccumulatorValue
};
};
AdagradOptimizer2.fromConfig = function(cls, config) {
return new cls(config["learningRate"], config["initialAccumulatorValue"]);
};
AdagradOptimizer2.className = "Adagrad";
return AdagradOptimizer2;
}(Optimizer);
// registerClass is defined elsewhere (presumably the serialization registry keyed by className).
registerClass(AdagradOptimizer);
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Adam optimizer. Per variable it keeps first- and second-moment
* accumulators (non-trainable, matched by position). `accBeta1`/`accBeta2`
* start at beta1/beta2 and are multiplied by beta1/beta2 after every step,
* so they hold beta^t for the bias correction. Each step:
*   m = beta1 * m + (1 - beta1) * g
*   v = beta2 * v + (1 - beta2) * g^2
*   value = value - learningRate * (m / (1 - accBeta1)) / (sqrt(v / (1 - accBeta2)) + eps)
*/
var AdamOptimizer = function(_super) {
__extends(AdamOptimizer2, _super);
/**
* @param learningRate Step size.
* @param beta1 Decay rate of the first-moment average.
* @param beta2 Decay rate of the second-moment average.
* @param epsilon Stabiliser in the denominator; when null/omitted,
*     ENGINE.backend.epsilon() is used.
*/
function AdamOptimizer2(learningRate, beta1, beta2, epsilon) {
if (epsilon === void 0) {
epsilon = null;
}
var _this = _super.call(this) || this;
_this.learningRate = learningRate;
_this.beta1 = beta1;
_this.beta2 = beta2;
_this.epsilon = epsilon;
_this.accumulatedFirstMoment = [];
_this.accumulatedSecondMoment = [];
tidy(function() {
_this.accBeta1 = scalar(beta1).variable();
_this.accBeta2 = scalar(beta2).variable();
});
if (epsilon == null) {
_this.epsilon = ENGINE.backend.epsilon();
}
return _this;
}
/**
* Applies one Adam step inside a single tidy().
* @param variableGradients Either an array of {name, tensor} or an object
*     mapping variable name -> gradient. Moments are indexed by position `i`,
*     so callers should keep the variable order stable across calls.
*/
AdamOptimizer2.prototype.applyGradients = function(variableGradients) {
var _this = this;
var varNames = Array.isArray(variableGradients) ? variableGradients.map(function(v) {
return v.name;
}) : Object.keys(variableGradients);
tidy(function() {
// Bias-correction denominators, shared by every variable in this step.
var oneMinusAccBeta1 = sub(1, _this.accBeta1);
var oneMinusAccBeta2 = sub(1, _this.accBeta2);
varNames.forEach(function(name, i) {
var value = ENGINE.registeredVariables[name];
var trainable = false;
if (_this.accumulatedFirstMoment[i] == null) {
_this.accumulatedFirstMoment[i] = {
originalName: name + "/m",
variable: tidy(function() {
return zerosLike(value).variable(trainable);
})
};
}
if (_this.accumulatedSecondMoment[i] == null) {
_this.accumulatedSecondMoment[i] = {
originalName: name + "/v",
variable: tidy(function() {
return zerosLike(value).variable(trainable);
})
};
}
var gradient = Array.isArray(variableGradients) ? variableGradients[i].tensor : variableGradients[name];
// Variables without a gradient are left untouched.
if (gradient == null) {
return;
}
var firstMoment = _this.accumulatedFirstMoment[i].variable;
var secondMoment = _this.accumulatedSecondMoment[i].variable;
var newFirstMoment = add$1(mul(firstMoment, _this.beta1), mul(gradient, 1 - _this.beta1));
var newSecondMoment = add$1(mul(secondMoment, _this.beta2), mul(square(gradient), 1 - _this.beta2));
var biasCorrectedFirstMoment = div(newFirstMoment, oneMinusAccBeta1);
var biasCorrectedSecondMoment = div(newSecondMoment, oneMinusAccBeta2);
firstMoment.assign(newFirstMoment);
secondMoment.assign(newSecondMoment);
var newValue = add$1(mul(div(biasCorrectedFirstMoment, add$1(sqrt(biasCorrectedSecondMoment), _this.epsilon)), -_this.learningRate), value);
value.assign(newValue);
});
// Advance beta^t for the next step.
_this.accBeta1.assign(mul(_this.accBeta1, _this.beta1));
_this.accBeta2.assign(mul(_this.accBeta2, _this.beta2));
});
this.incrementIterations();
};
// Releases the beta^t variables and both moment sets.
AdamOptimizer2.prototype.dispose = function() {
this.accBeta1.dispose();
this.accBeta2.dispose();
if (this.accumulatedFirstMoment != null) {
dispose(this.accumulatedFirstMoment.map(function(v) {
return v.variable;
}));
}
if (this.accumulatedSecondMoment != null) {
dispose(this.accumulatedSecondMoment.map(function(v) {
return v.variable;
}));
}
};
// Resolves to [iterations entry, ...first moments, ...second moments] as {name, tensor}.
// (Transpiled async body; [4, p] awaits p and [2, v] returns v under the
// tslib __generator protocol — assumed, helper defined elsewhere.)
AdamOptimizer2.prototype.getWeights = function() {
return __awaiter(this, void 0, void 0, function() {
var variables;
return __generator(this, function(_a) {
switch (_a.label) {
case 0:
variables = this.accumulatedFirstMoment.concat(this.accumulatedSecondMoment);
return [4, this.saveIterations()];
case 1:
return [2, [_a.sent()].concat(variables.map(function(v) {
return {name: v.originalName, tensor: v.variable};
}))];
}
});
});
};
// Inverse of getWeights: restores the step count, recomputes accBeta1/2 as
// beta^(iterations_ + 1), then splits the remaining entries in half into
// first and second moments (non-trainable variables).
AdamOptimizer2.prototype.setWeights = function(weightValues) {
return __awaiter(this, void 0, void 0, function() {
var variableCount, trainable;
var _this = this;
return __generator(this, function(_a) {
switch (_a.label) {
case 0:
return [4, this.extractIterations(weightValues)];
case 1:
weightValues = _a.sent();
tidy(function() {
_this.accBeta1.assign(pow(_this.beta1, _this.iterations_ + 1));
_this.accBeta2.assign(pow(_this.beta2, _this.iterations_ + 1));
});
variableCount = weightValues.length / 2;
trainable = false;
this.accumulatedFirstMoment = weightValues.slice(0, variableCount).map(function(v) {
return {
originalName: v.name,
variable: v.tensor.variable(trainable)
};
});
this.accumulatedSecondMoment = weightValues.slice(variableCount, variableCount * 2).map(function(v) {
return {
originalName: v.name,
variable: v.tensor.variable(trainable)
};
});
return [2];
}
});
});
};
AdamOptimizer2.prototype.getConfig = function() {
return {
learningRate: this.learningRate,
beta1: this.beta1,
beta2: this.beta2,
epsilon: this.epsilon
};
};
AdamOptimizer2.fromConfig = function(cls, config) {
return new cls(config["learningRate"], config["beta1"], config["beta2"], config["epsilon"]);
};
AdamOptimizer2.className = "Adam";
return AdamOptimizer2;
}(Optimizer);
// registerClass is defined elsewhere (presumably the serialization registry keyed by className).
registerClass(AdamOptimizer);
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Adamax optimizer (Adam variant using an infinity-norm second moment).
* Per variable it keeps a first moment `m` and a weighted infinity norm `u`
* (non-trainable, matched by position). Each step:
*   lr = -learningRate / (1 + iteration * decay)
*   m = beta1 * m + (1 - beta1) * g
*   u = max(beta2 * u, |g|)
*   value = value + lr / (1 - accBeta1) * m / (u + eps)
* then `iteration` is incremented and accBeta1 is multiplied by beta1.
*/
var AdamaxOptimizer = function(_super) {
__extends(AdamaxOptimizer2, _super);
/**
* @param learningRate Base step size.
* @param beta1 Decay rate of the first moment.
* @param beta2 Decay factor applied to the infinity norm.
* @param epsilon Stabiliser in the denominator; when null/omitted,
*     ENGINE.backend.epsilon() is used.
* @param decay Learning-rate decay per iteration (defaults to 0).
*/
function AdamaxOptimizer2(learningRate, beta1, beta2, epsilon, decay) {
if (epsilon === void 0) {
epsilon = null;
}
if (decay === void 0) {
decay = 0;
}
var _this = _super.call(this) || this;
_this.learningRate = learningRate;
_this.beta1 = beta1;
_this.beta2 = beta2;
_this.epsilon = epsilon;
_this.decay = decay;
_this.accumulatedFirstMoment = [];
_this.accumulatedWeightedInfNorm = [];
// Tensor-valued step counter (separate from the base class's numeric iterations_).
tidy(function() {
_this.iteration = scalar(0).variable();
_this.accBeta1 = scalar(beta1).variable();
});
if (epsilon == null) {
_this.epsilon = ENGINE.backend.epsilon();
}
return _this;
}
/**
* Applies one Adamax step inside a single tidy().
* @param variableGradients Either an array of {name, tensor} or an object
*     mapping variable name -> gradient. State is indexed by position `i`.
*/
AdamaxOptimizer2.prototype.applyGradients = function(variableGradients) {
var _this = this;
var variableNames = Array.isArray(variableGradients) ? variableGradients.map(function(item) {
return item.name;
}) : Object.keys(variableGradients);
tidy(function() {
var oneMinusAccBeta1 = sub(1, _this.accBeta1);
// Decayed, negated learning rate shared by all variables this step.
var lr = div(-_this.learningRate, add$1(mul(_this.iteration, _this.decay), 1));
variableNames.forEach(function(name, i) {
var value = ENGINE.registeredVariables[name];
var trainable = false;
// NOTE(review): unlike the other optimizers these accumulators are not
// wrapped in their own tidy(); this relies on variables surviving the
// enclosing tidy() — confirm.
if (_this.accumulatedFirstMoment[i] == null) {
_this.accumulatedFirstMoment[i] = {
originalName: name + "/m",
variable: zerosLike(value).variable(trainable)
};
}
if (_this.accumulatedWeightedInfNorm[i] == null) {
_this.accumulatedWeightedInfNorm[i] = {
originalName: name + "/v",
variable: zerosLike(value).variable(trainable)
};
}
var gradient = Array.isArray(variableGradients) ? variableGradients[i].tensor : variableGradients[name];
// Variables without a gradient are left untouched.
if (gradient == null) {
return;
}
var firstMoment = _this.accumulatedFirstMoment[i].variable;
var weightedInfNorm = _this.accumulatedWeightedInfNorm[i].variable;
var newFirstMoment = add$1(mul(firstMoment, _this.beta1), mul(gradient, 1 - _this.beta1));
var ut0 = mul(weightedInfNorm, _this.beta2);
var ut1 = abs(gradient);
var newWeightedInfNorm = maximum(ut0, ut1);
firstMoment.assign(newFirstMoment);
weightedInfNorm.assign(newWeightedInfNorm);
var newValue = add$1(mul(div(lr, oneMinusAccBeta1), div(newFirstMoment, add$1(newWeightedInfNorm, _this.epsilon))), value);
value.assign(newValue);
});
_this.iteration.assign(add$1(_this.iteration, 1));
_this.accBeta1.assign(mul(_this.accBeta1, _this.beta1));
});
this.incrementIterations();
};
// Releases the step/beta variables and both state sets.
AdamaxOptimizer2.prototype.dispose = function() {
this.accBeta1.dispose();
this.iteration.dispose();
if (this.accumulatedFirstMoment != null) {
dispose(this.accumulatedFirstMoment.map(function(v) {
return v.variable;
}));
}
if (this.accumulatedWeightedInfNorm != null) {
dispose(this.accumulatedWeightedInfNorm.map(function(v) {
return v.variable;
}));
}
};
// Weight (de)serialisation is not supported for Adamax; both always throw.
AdamaxOptimizer2.prototype.getWeights = function() {
return __awaiter(this, void 0, void 0, function() {
return __generator(this, function(_a) {
throw new Error("getWeights() is not implemented for Adamax yet.");
});
});
};
AdamaxOptimizer2.prototype.setWeights = function(weightValues) {
return __awaiter(this, void 0, void 0, function() {
return __generator(this, function(_a) {
throw new Error("setWeights() is not implemented for Adamax yet.");
});
});
};
AdamaxOptimizer2.prototype.getConfig = function() {
return {
learningRate: this.learningRate,
beta1: this.beta1,
beta2: this.beta2,
epsilon: this.epsilon,
decay: this.decay
};
};
AdamaxOptimizer2.fromConfig = function(cls, config) {
return new cls(config["learningRate"], config["beta1"], config["beta2"], config["epsilon"], config["decay"]);
};
AdamaxOptimizer2.className = "Adamax";
return AdamaxOptimizer2;
}(Optimizer);
// registerClass is defined elsewhere (presumably the serialization registry keyed by className).
registerClass(AdamaxOptimizer);
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Plain stochastic gradient descent: value = value + c * g, where `c` is a
* cached scalar holding -learningRate. Keeps no per-variable state.
*/
var SGDOptimizer = function(_super) {
__extends(SGDOptimizer2, _super);
/**
* @param learningRate Step size.
*/
function SGDOptimizer2(learningRate) {
var _this = _super.call(this) || this;
_this.learningRate = learningRate;
_this.setLearningRate(learningRate);
return _this;
}
/**
* Applies one SGD step.
* @param variableGradients Either an array of {name, tensor} or an object
*     mapping variable name -> gradient; null gradients are skipped.
*/
SGDOptimizer2.prototype.applyGradients = function(variableGradients) {
var _this = this;
var varNames = Array.isArray(variableGradients) ? variableGradients.map(function(v) {
return v.name;
}) : Object.keys(variableGradients);
varNames.forEach(function(name, i) {
var gradient = Array.isArray(variableGradients) ? variableGradients[i].tensor : variableGradients[name];
if (gradient == null) {
return;
}
var value = ENGINE.registeredVariables[name];
tidy(function() {
var newValue = add$1(mul(_this.c, gradient), value);
value.assign(newValue);
});
});
this.incrementIterations();
};
// Replaces the cached -learningRate scalar, disposing the previous one.
// keep() is used so that an enclosing tidy() does not dispose it (per the
// documented tf.keep contract).
SGDOptimizer2.prototype.setLearningRate = function(learningRate) {
this.learningRate = learningRate;
if (this.c != null) {
this.c.dispose();
}
this.c = keep(scalar(-learningRate));
};
SGDOptimizer2.prototype.dispose = function() {
this.c.dispose();
};
// Resolves to [iterations entry]; SGD has no other weights.
SGDOptimizer2.prototype.getWeights = function() {
return __awaiter(this, void 0, void 0, function() {
return __generator(this, function(_a) {
switch (_a.label) {
case 0:
return [4, this.saveIterations()];
case 1:
return [2, [_a.sent()]];
}
});
});
};
// Restores the step count; any further entries are rejected because SGD has no other state.
SGDOptimizer2.prototype.setWeights = function(weightValues) {
return __awaiter(this, void 0, void 0, function() {
return __generator(this, function(_a) {
switch (_a.label) {
case 0:
return [4, this.extractIterations(weightValues)];
case 1:
weightValues = _a.sent();
if (weightValues.length !== 0) {
throw new Error("SGD optimizer does not have settable weights.");
}
return [2];
}
});
});
};
SGDOptimizer2.prototype.getConfig = function() {
return {learningRate: this.learningRate};
};
SGDOptimizer2.fromConfig = function(cls, config) {
return new cls(config["learningRate"]);
};
SGDOptimizer2.className = "SGD";
return SGDOptimizer2;
}(Optimizer);
// registerClass is defined elsewhere (presumably the serialization registry keyed by className).
registerClass(SGDOptimizer);
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * SGD with momentum (optionally Nesterov). Extends SGDOptimizer, whose
 * constructor creates the scalar `c = -learningRate` used below.
 *
 * Per variable it keeps a non-trainable `accumulation` (matched by position)
 * and on each step computes
 *   accumulation = m * accumulation + g
 *   value = value + c * accumulation                (classic)
 *   value = value + c * (g + m * accumulation)      (useNesterov)
 */
var MomentumOptimizer = function(_super) {
  __extends(MomentumOptimizer2, _super);
  /**
   * @param {number} learningRate Step size (forwarded to SGDOptimizer).
   * @param {number} momentum Momentum coefficient.
   * @param {boolean} [useNesterov=false] Use the Nesterov update.
   */
  function MomentumOptimizer2(learningRate, momentum, useNesterov) {
    if (useNesterov === void 0) {
      useNesterov = false;
    }
    var _this = _super.call(this, learningRate) || this;
    _this.learningRate = learningRate;
    _this.momentum = momentum;
    _this.useNesterov = useNesterov;
    _this.accumulations = [];
    // keep() it like SGDOptimizer keeps `c`, so constructing the optimizer
    // inside a tidy() does not dispose the scalar.
    _this.m = keep(scalar(_this.momentum));
    return _this;
  }
  /**
   * Applies one momentum step.
   * @param variableGradients Either an array of {name, tensor} or an object
   *     mapping variable name -> gradient. Accumulations are indexed by
   *     position `i`; variables with a null gradient are not updated.
   */
  MomentumOptimizer2.prototype.applyGradients = function(variableGradients) {
    var _this = this;
    var variableNames = Array.isArray(variableGradients) ? variableGradients.map(function(item) {
      return item.name;
    }) : Object.keys(variableGradients);
    variableNames.forEach(function(name, i) {
      var value = ENGINE.registeredVariables[name];
      if (_this.accumulations[i] == null) {
        var trainable_1 = false;
        _this.accumulations[i] = {
          originalName: name + "/momentum",
          variable: tidy(function() {
            return zerosLike(value).variable(trainable_1);
          })
        };
      }
      var accumulation = _this.accumulations[i].variable;
      var gradient = Array.isArray(variableGradients) ? variableGradients[i].tensor : variableGradients[name];
      if (gradient == null) {
        return;
      }
      tidy(function() {
        var newValue;
        var newAccumulation = add$1(mul(_this.m, accumulation), gradient);
        if (_this.useNesterov) {
          newValue = add$1(mul(_this.c, add$1(gradient, mul(newAccumulation, _this.m))), value);
        } else {
          newValue = add$1(mul(_this.c, newAccumulation), value);
        }
        accumulation.assign(newAccumulation);
        value.assign(newValue);
      });
    });
    this.incrementIterations();
  };
  // Releases the momentum scalar, the accumulations, and SGDOptimizer's `c`.
  MomentumOptimizer2.prototype.dispose = function() {
    this.m.dispose();
    if (this.accumulations != null) {
      dispose(this.accumulations.map(function(v) {
        return v.variable;
      }));
    }
    // Overriding dispose without delegating leaked the inherited `c` scalar.
    _super.prototype.dispose.call(this);
  };
  // Updates the momentum coefficient. applyGradients reads the cached scalar
  // `m`, not `momentum`, so it is rebuilt here (mirroring setLearningRate).
  MomentumOptimizer2.prototype.setMomentum = function(momentum) {
    this.momentum = momentum;
    if (this.m != null) {
      this.m.dispose();
    }
    this.m = keep(scalar(momentum));
  };
  // Resolves to [iterations entry, ...accumulations as {name, tensor}].
  MomentumOptimizer2.prototype.getWeights = function() {
    return __awaiter(this, void 0, void 0, function() {
      return __generator(this, function(_a) {
        switch (_a.label) {
          case 0:
            return [4, this.saveIterations()];
          case 1:
            return [2, [_a.sent()].concat(this.accumulations.map(function(v) {
              return {name: v.originalName, tensor: v.variable};
            }))];
        }
      });
    });
  };
  // Inverse of getWeights: restores the step count, then wraps every
  // remaining entry as a non-trainable accumulation variable.
  MomentumOptimizer2.prototype.setWeights = function(weightValues) {
    return __awaiter(this, void 0, void 0, function() {
      var trainable;
      return __generator(this, function(_a) {
        switch (_a.label) {
          case 0:
            return [4, this.extractIterations(weightValues)];
          case 1:
            weightValues = _a.sent();
            trainable = false;
            this.accumulations = weightValues.map(function(v) {
              return {originalName: v.name, variable: v.tensor.variable(trainable)};
            });
            return [2];
        }
      });
    });
  };
  MomentumOptimizer2.prototype.getConfig = function() {
    return {
      learningRate: this.learningRate,
      momentum: this.momentum,
      useNesterov: this.useNesterov
    };
  };
  MomentumOptimizer2.fromConfig = function(cls, config) {
    return new cls(config["learningRate"], config["momentum"], config["useNesterov"]);
  };
  MomentumOptimizer2.className = "Momentum";
  return MomentumOptimizer2;
}(SGDOptimizer);
registerClass(MomentumOptimizer);
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * RMSProp optimizer, optionally centered and with optional momentum.
 *
 * Per variable it keeps non-trainable accumulator variables, indexed by the
 * position of the variable in the gradients passed to applyGradients:
 *   ms  ("<name>/rms")       running mean of squared gradients
 *   mom ("<name>/momentum")  momentum accumulator
 *   mg  ("<name>/mg")        running mean of gradients (centered mode only)
 *
 * Update rule, with d = decay and lr = learningRate:
 *   ms  <- d * ms + (1 - d) * g^2
 *   mg  <- d * mg + (1 - d) * g                              (centered only)
 *   mom <- momentum * mom + lr * g / sqrt(ms + eps)            (plain)
 *   mom <- momentum * mom + lr * g / sqrt(ms - mg^2 + eps)     (centered)
 *   var <- var - mom
 */
var RMSPropOptimizer = function(_super) {
  __extends(RMSPropOptimizer2, _super);
  /**
   * @param {number} learningRate Step size; required (throws when null/undefined).
   * @param {number} [decay=0.9] Discount factor for the running averages.
   * @param {number} [momentum=0] Momentum coefficient.
   * @param {?number} [epsilon=null] Added under the square root; when null it is
   *     taken from ENGINE.backend.epsilon().
   * @param {boolean} [centered=false] Normalize by the centered estimate
   *     ms - mg^2 instead of ms.
   * @throws {Error} When learningRate is null or undefined.
   */
  function RMSPropOptimizer2(learningRate, decay, momentum, epsilon, centered) {
    if (decay === void 0) {
      decay = 0.9;
    }
    if (momentum === void 0) {
      momentum = 0;
    }
    if (epsilon === void 0) {
      epsilon = null;
    }
    if (centered === void 0) {
      centered = false;
    }
    var _this = _super.call(this) || this;
    _this.learningRate = learningRate;
    _this.decay = decay;
    _this.momentum = momentum;
    _this.epsilon = epsilon;
    _this.accumulatedMeanSquares = [];
    _this.accumulatedMoments = [];
    _this.accumulatedMeanGrads = [];
    _this.centered = centered;
    if (epsilon == null) {
      _this.epsilon = ENGINE.backend.epsilon();
    }
    if (learningRate == null) {
      throw new Error("learningRate for RMSPropOptimizer must be defined.");
    }
    return _this;
  }
  /**
   * Applies one RMSProp step to every named variable.
   * @param variableGradients Either an array of {name, tensor} or an object
   *     mapping variable name to gradient tensor. Accumulators are created on
   *     first use; entries with a null/undefined gradient are then skipped.
   */
  RMSPropOptimizer2.prototype.applyGradients = function(variableGradients) {
    var _this = this;
    var variableNames = Array.isArray(variableGradients) ? variableGradients.map(function(item) {
      return item.name;
    }) : Object.keys(variableGradients);
    variableNames.forEach(function(name, i) {
      var value = ENGINE.registeredVariables[name];
      var trainable = false;
      // Zero-initialized, non-trainable accumulator slot shaped like `value`.
      function zeroSlot(suffix) {
        return {
          originalName: name + suffix,
          variable: tidy(function() {
            return zerosLike(value).variable(trainable);
          })
        };
      }
      if (_this.accumulatedMeanSquares[i] == null) {
        _this.accumulatedMeanSquares[i] = zeroSlot("/rms");
      }
      if (_this.accumulatedMoments[i] == null) {
        _this.accumulatedMoments[i] = zeroSlot("/momentum");
      }
      if (_this.accumulatedMeanGrads[i] == null && _this.centered) {
        _this.accumulatedMeanGrads[i] = zeroSlot("/mg");
      }
      var gradient = Array.isArray(variableGradients) ? variableGradients[i].tensor : variableGradients[name];
      if (gradient == null) {
        return;
      }
      var accumulatedMeanSquare = _this.accumulatedMeanSquares[i].variable;
      var accumulatedMoments = _this.accumulatedMoments[i].variable;
      tidy(function() {
        // The mean-square update is the same in both modes, so compute it once.
        var newAccumulatedMeanSquare = add$1(mul(accumulatedMeanSquare, _this.decay), mul(square(gradient), 1 - _this.decay));
        var accumulatedMeanGrad = null;
        var newAccumulatedMeanGrad = null;
        var secondMoment = newAccumulatedMeanSquare;
        if (_this.centered) {
          accumulatedMeanGrad = _this.accumulatedMeanGrads[i].variable;
          newAccumulatedMeanGrad = add$1(mul(accumulatedMeanGrad, _this.decay), mul(gradient, 1 - _this.decay));
          // Centered estimate ms - mg^2. Epsilon must be ADDED to it (below):
          // sqrt(ms - (mg^2 + eps)) is the root of a negative number (NaN)
          // whenever ms - mg^2 < eps, e.g. for an all-zero gradient.
          secondMoment = sub(newAccumulatedMeanSquare, square(newAccumulatedMeanGrad));
        }
        var gradContribution = div(mul(gradient, _this.learningRate), sqrt(add$1(secondMoment, _this.epsilon)));
        var newAccumulatedMoments = add$1(mul(accumulatedMoments, _this.momentum), gradContribution);
        accumulatedMeanSquare.assign(newAccumulatedMeanSquare);
        if (accumulatedMeanGrad != null) {
          accumulatedMeanGrad.assign(newAccumulatedMeanGrad);
        }
        accumulatedMoments.assign(newAccumulatedMoments);
        value.assign(sub(value, newAccumulatedMoments));
      });
    });
    this.incrementIterations();
  };
  /** Disposes every accumulator variable held by this optimizer. */
  RMSPropOptimizer2.prototype.dispose = function() {
    if (this.accumulatedMeanSquares != null) {
      dispose(this.accumulatedMeanSquares.map(function(v) {
        return v.variable;
      }));
    }
    if (this.accumulatedMeanGrads != null && this.centered) {
      dispose(this.accumulatedMeanGrads.map(function(v) {
        return v.variable;
      }));
    }
    if (this.accumulatedMoments != null) {
      dispose(this.accumulatedMoments.map(function(v) {
        return v.variable;
      }));
    }
  };
  /**
   * Async: resolves to [iterations entry, ...mean squares, ...moments,
   * ...mean grads (centered only)], each accumulator as {name, tensor}.
   */
  RMSPropOptimizer2.prototype.getWeights = function() {
    return __awaiter(this, void 0, void 0, function() {
      var variables;
      return __generator(this, function(_a) {
        switch (_a.label) {
          case 0:
            variables = this.accumulatedMeanSquares.concat(this.accumulatedMoments);
            if (this.centered) {
              variables.push.apply(variables, this.accumulatedMeanGrads);
            }
            return [4, this.saveIterations()];
          case 1:
            return [2, [_a.sent()].concat(variables.map(function(v) {
              return {name: v.originalName, tensor: v.variable};
            }))];
        }
      });
    });
  };
  /**
   * Async: inverse of getWeights. The weights left after extractIterations are
   * split evenly into 2 groups (or 3 when centered) in getWeights order.
   * NOTE(review): previously held accumulator variables are not disposed here.
   */
  RMSPropOptimizer2.prototype.setWeights = function(weightValues) {
    return __awaiter(this, void 0, void 0, function() {
      var variableCount, trainable;
      return __generator(this, function(_a) {
        switch (_a.label) {
          case 0:
            return [4, this.extractIterations(weightValues)];
          case 1:
            weightValues = _a.sent();
            variableCount = this.centered ? weightValues.length / 3 : weightValues.length / 2;
            trainable = false;
            this.accumulatedMeanSquares = weightValues.slice(0, variableCount).map(function(v) {
              return {
                originalName: v.name,
                variable: v.tensor.variable(trainable)
              };
            });
            this.accumulatedMoments = weightValues.slice(variableCount, variableCount * 2).map(function(v) {
              return {
                originalName: v.name,
                variable: v.tensor.variable(trainable)
              };
            });
            if (this.centered) {
              this.accumulatedMeanGrads = weightValues.slice(variableCount * 2, variableCount * 3).map(function(v) {
                return {
                  originalName: v.name,
                  variable: v.tensor.variable(trainable)
                };
              });
            }
            return [2];
        }
      });
    });
  };
  /** Serializable hyper-parameters; the inverse of fromConfig. */
  RMSPropOptimizer2.prototype.getConfig = function() {
    return {
      learningRate: this.learningRate,
      decay: this.decay,
      momentum: this.momentum,
      epsilon: this.epsilon,
      centered: this.centered
    };
  };
  RMSPropOptimizer2.fromConfig = function(cls, config) {
    return new cls(config["learningRate"], config["decay"], config["momentum"], config["epsilon"], config["centered"]);
  };
  RMSPropOptimizer2.className = "RMSProp";
  return RMSPropOptimizer2;
}(Optimizer);
registerClass(RMSPropOptimizer);
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Static factory namespace for the built-in optimizers (never instantiated).
 * Each factory fills in documented defaults for omitted arguments and forwards
 * everything positionally to the matching optimizer constructor.
 */
var OptimizerConstructors = function() {
  // Substitutes `fallback` only when the argument is undefined (omitted);
  // an explicit null is forwarded unchanged.
  function orDefault(value, fallback) {
    return value === void 0 ? fallback : value;
  }
  function OptimizerConstructors2() {
  }
  // Plain SGD; learningRate has no default here.
  OptimizerConstructors2.sgd = function(learningRate) {
    return new SGDOptimizer(learningRate);
  };
  // useNesterov defaults to false; momentum has no default here.
  OptimizerConstructors2.momentum = function(learningRate, momentum, useNesterov) {
    return new MomentumOptimizer(learningRate, momentum, orDefault(useNesterov, false));
  };
  // Defaults: decay 0.9, momentum 0, epsilon null, centered false.
  OptimizerConstructors2.rmsprop = function(learningRate, decay, momentum, epsilon, centered) {
    return new RMSPropOptimizer(
      learningRate,
      orDefault(decay, 0.9),
      orDefault(momentum, 0),
      orDefault(epsilon, null),
      orDefault(centered, false)
    );
  };
  // Defaults: learningRate 1e-3, beta1 0.9, beta2 0.999, epsilon null.
  OptimizerConstructors2.adam = function(learningRate, beta1, beta2, epsilon) {
    return new AdamOptimizer(
      orDefault(learningRate, 1e-3),
      orDefault(beta1, 0.9),
      orDefault(beta2, 0.999),
      orDefault(epsilon, null)
    );
  };
  // Defaults: learningRate 1e-3, rho 0.95, epsilon null.
  OptimizerConstructors2.adadelta = function(learningRate, rho, epsilon) {
    return new AdadeltaOptimizer(
      orDefault(learningRate, 1e-3),
      orDefault(rho, 0.95),
      orDefault(epsilon, null)
    );
  };
  // Defaults: learningRate 2e-3, beta1 0.9, beta2 0.999, epsilon null, decay 0.
  OptimizerConstructors2.adamax = function(learningRate, beta1, beta2, epsilon, decay) {
    return new AdamaxOptimizer(
      orDefault(learningRate, 2e-3),
      orDefault(beta1, 0.9),
      orDefault(beta2, 0.999),
      orDefault(epsilon, null),
      orDefault(decay, 0)
    );
  };
  // initialAccumulatorValue defaults to 0.1; learningRate has no default here.
  OptimizerConstructors2.adagrad = function(learningRate, initialAccumulatorValue) {
    return new AdagradOptimizer(learningRate, orDefault(initialAccumulatorValue, 0.1));
  };
  return OptimizerConstructors2;
}();
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
// Public `train` namespace: re-exports each OptimizerConstructors factory under
// its own name. Key insertion order is part of the observable shape, so the
// list below is kept in the original order.
var train = ["sgd", "momentum", "adadelta", "adagrad", "rmsprop", "adamax", "adam"].reduce(function(namespace, factoryName) {
  namespace[factoryName] = OptimizerConstructors[factoryName];
  return namespace;
}, {});
/**
* @license
* Copyright 2017 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
// Scheduler picked once at load time: requestAnimationFrame when available
// (browsers), else setImmediate (Node), else a shim that runs the callback
// synchronously.
var delayCallback = function() {
  var hasAnimationFrame = typeof requestAnimationFrame !== "undefined";
  if (hasAnimationFrame) {
    return requestAnimationFrame;
  }
  var hasSetImmediate = typeof setImmediate !== "undefined";
  return hasSetImmediate ? setImmediate : function(f) {
    return f();
  };
}();
/**
 * Returns a promise that resolves with undefined once delayCallback runs its
 * callback. The callback is wrapped rather than passing `resolve` directly,
 * so any argument the scheduler supplies (e.g. a frame timestamp) is dropped.
 */
function nextFrame() {
  return new Promise(function(resolve) {
    var settle = function() {
      return resolve();
    };
    return delayCallback(settle);
  });
}
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Converts a relative center into pixel coordinates.
 * @param center A single fraction used for both axes, or an [x, y] pair of fractions.
 * @returns [centerX, centerY] = [imageWidth * fx, imageHeight * fy].
 */
function getImageCenter(center, imageHeight, imageWidth) {
  var sameForBothAxes = typeof center === "number";
  var fractionX = sameForBothAxes ? center : center[0];
  var fractionY = sameForBothAxes ? center : center[1];
  return [imageWidth * fractionX, imageHeight * fractionY];
}
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Intermediate shape used by batchToSpaceND / spaceToBatchND.
 * batchToSpace (default true): [...blockShape, inputShape[0] / prod2, ...inputShape[1:]].
 * otherwise: [inputShape[0], in[1]/b[0], b[0], in[2]/b[1], b[1], ..., ...remaining dims].
 */
function getReshaped(inputShape, blockShape, prod2, batchToSpace) {
  var toBatch = batchToSpace === void 0 ? true : batchToSpace;
  if (toBatch) {
    return [].concat(blockShape.slice(0), [inputShape[0] / prod2], inputShape.slice(1));
  }
  var spatialRank = blockShape.length;
  var shape = [].concat(inputShape[0]);
  for (var d = 0; d < spatialRank; ++d) {
    // Each spatial dim is split into (outer, block) factors.
    shape.push(inputShape[d + 1] / blockShape[d], blockShape[d]);
  }
  return shape.concat(inputShape.slice(spatialRank + 1));
}
/**
 * Axis permutation applied to the getReshaped result.
 * batchToSpace (default true): starts with blockShapeRank, then interleaves each
 * later axis i <= 2*blockShapeRank with i - (blockShapeRank + 1); later axes follow as-is.
 * otherwise: even axes below 2*blockShapeRank+1 first, then 0, then the odd axes
 * and every axis >= 2*blockShapeRank+1.
 */
function getPermuted(reshapedRank, blockShapeRank, batchToSpace) {
  var toBatch = batchToSpace === void 0 ? true : batchToSpace;
  if (toBatch) {
    var order = [blockShapeRank];
    for (var axis = blockShapeRank + 1; axis < reshapedRank; ++axis) {
      order.push(axis);
      if (axis <= 2 * blockShapeRank) {
        order.push(axis - (blockShapeRank + 1));
      }
    }
    return order;
  }
  var leading = [];
  var trailing = [];
  for (var j = 1; j < reshapedRank; ++j) {
    var goesAfterBatch = j >= blockShapeRank * 2 + 1 || j % 2 === 1;
    (goesAfterBatch ? trailing : leading).push(j);
  }
  return leading.concat([0], trailing);
}
/**
 * Final shape after the permutation. The batch dim is divided (batchToSpace,
 * default) or multiplied by prod2. Spatial dims 1..blockShape.length are scaled
 * by the matching block size (multiplied for batchToSpace, divided otherwise).
 * Remaining dims are copied.
 */
function getReshapedPermuted(inputShape, blockShape, prod2, batchToSpace) {
  var toBatch = batchToSpace === void 0 ? true : batchToSpace;
  var shape = [toBatch ? inputShape[0] / prod2 : inputShape[0] * prod2];
  for (var dim = 1; dim < inputShape.length; ++dim) {
    if (dim > blockShape.length) {
      shape.push(inputShape[dim]);
      continue;
    }
    var factor = blockShape[dim - 1];
    shape.push(toBatch ? factor * inputShape[dim] : inputShape[dim] / factor);
  }
  return shape;
}
/**
 * Slice begin coordinates: 0 for the batch dim, then the leading crop of each
 * of the first `blockShape` spatial dims (`blockShape` is a count here).
 */
function getSliceBeginCoords(crops, blockShape) {
  var coords = [0];
  var d = 0;
  while (d < blockShape) {
    coords.push(crops[d][0]);
    d++;
  }
  return coords;
}
/**
 * Slice sizes: the batch dim unchanged, then each of the first `blockShape`
 * spatial dims shrunk by its leading and trailing crop. Trailing dims are not
 * included.
 */
function getSliceSize(uncroppedShape, crops, blockShape) {
  var sizes = uncroppedShape.slice(0, 1);
  var d = 0;
  while (d < blockShape) {
    var crop = crops[d];
    sizes.push(uncroppedShape[d + 1] - crop[0] - crop[1]);
    d++;
  }
  return sizes;
}
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
// SELU activation constants (consumers are elsewhere in the file).
// SELU_SCALE is the output scale (lambda ~= 1.0507). As its name indicates,
// SELU_SCALEALPHA is lambda * alpha, with alpha ~= 1.6733
// (1.0507009873554805 * 1.6732632423543772 = 1.7580993408473768).
var SELU_SCALEALPHA = 1.7580993408473768;
var SELU_SCALE = 1.0507009873554805;
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
// Coefficients for a rational/polynomial approximation of erf(x). The values
// match Abramowitz & Stegun formula 7.1.26:
//   erf(x) ~= 1 - (a1*t + a2*t^2 + a3*t^3 + a4*t^4 + a5*t^5) * exp(-x^2),
//   t = 1 / (1 + p*x), for x >= 0.
// NOTE(review): the kernels that read these live elsewhere — confirm they use this form.
var ERF_P = 0.3275911;
var ERF_A1 = 0.254829592;
var ERF_A2 = -0.284496736;
var ERF_A3 = 1.421413741;
var ERF_A4 = -1.453152027;
var ERF_A5 = 1.061405429;
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * console.warn wrapper. It forwards all arguments unchanged and is silenced
 * while the IS_TEST environment flag is set.
 */
function warn() {
  var msg = Array.prototype.slice.call(arguments);
  if (env().getBool("IS_TEST")) {
    return;
  }
  console.warn.apply(console, msg);
}
/**
 * console.log wrapper (exported as backend_util.log). It forwards all arguments
 * unchanged and is silenced while the IS_TEST environment flag is set.
 */
function log$1() {
  var msg = Array.prototype.slice.call(arguments);
  if (env().getBool("IS_TEST")) {
    return;
  }
  console.log.apply(console, msg);
}
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Interleaves separate real and imaginary parts into [re0, im0, re1, im1, ...].
 * @returns {Float32Array} of length 2 * real2.length.
 * @throws {Error} When the two inputs differ in length.
 */
function mergeRealAndImagArrays(real2, imag2) {
  var count = real2.length;
  if (count !== imag2.length) {
    throw new Error("Cannot merge real and imag arrays of different lengths. real:" + (count + ", imag: " + imag2.length + "."));
  }
  var interleaved = new Float32Array(count * 2);
  for (var k = 0; k < count; k++) {
    interleaved[2 * k] = real2[k];
    interleaved[2 * k + 1] = imag2[k];
  }
  return interleaved;
}
/**
 * Inverse of mergeRealAndImagArrays: de-interleaves [re0, im0, re1, im1, ...]
 * into {real, imag} Float32Arrays of half the input length.
 */
function splitRealAndImagArrays(complex2) {
  var half = complex2.length / 2;
  var realPart = new Float32Array(half);
  var imagPart = new Float32Array(half);
  for (var k = 0; 2 * k < complex2.length; k++) {
    realPart[k] = complex2[2 * k];
    imagPart[k] = complex2[2 * k + 1];
  }
  return {real: realPart, imag: imagPart};
}
/**
 * From an interleaved complex array, keeps the complex elements at even
 * positions (0, 2, 4, ...). That is raw indices 4k (real) and 4k+1 (imag).
 * Output length is ceil(length / 4).
 */
function complexWithEvenIndex(complex2) {
  var count = Math.ceil(complex2.length / 4);
  var realPart = new Float32Array(count);
  var imagPart = new Float32Array(count);
  for (var k = 0; 4 * k < complex2.length; k++) {
    realPart[k] = complex2[4 * k];
    imagPart[k] = complex2[4 * k + 1];
  }
  return {real: realPart, imag: imagPart};
}
/**
 * From an interleaved complex array, keeps the complex elements at odd
 * positions (1, 3, 5, ...). That is raw indices 4k+2 (real) and 4k+3 (imag).
 * Output length is floor(length / 4).
 */
function complexWithOddIndex(complex2) {
  var count = Math.floor(complex2.length / 4);
  var realPart = new Float32Array(count);
  var imagPart = new Float32Array(count);
  for (var k = 0; 4 * k + 2 < complex2.length; k++) {
    realPart[k] = complex2[4 * k + 2];
    imagPart[k] = complex2[4 * k + 3];
  }
  return {real: realPart, imag: imagPart};
}
/** Reads the complex element at `index` of an interleaved [re, im, ...] array. */
function getComplexWithIndex(complex2, index) {
  var base = index * 2;
  return {real: complex2[base], imag: complex2[base + 1]};
}
/** Writes (real2, imag2) into slot `index` of an interleaved [re, im, ...] array, in place. */
function assignToTypedArray(data, real2, imag2, index) {
  var base = index * 2;
  data[base] = real2;
  data[base + 1] = imag2;
}
/**
 * Twiddle factors exp(∓2πi·k/n) for k in [0, ceil(n/2)), as {real, imag} Float32Arrays.
 * The sign is negative for the forward transform and positive when `inverse`.
 * The arrays have length n / 2 (truncated), so for odd n the last computed
 * factor is not stored.
 */
function exponents(n, inverse) {
  var sign = inverse ? 2 : -2;
  var steps = Math.ceil(n / 2);
  var cosines = new Float32Array(n / 2);
  var sines = new Float32Array(n / 2);
  for (var k = 0; k < steps; k++) {
    var angle = sign * Math.PI * (k / n);
    cosines[k] = Math.cos(angle);
    sines[k] = Math.sin(angle);
  }
  return {real: cosines, imag: sines};
}
/** Single twiddle factor exp(∓2πi·k/n) as {real, imag}; the sign is positive when `inverse`. */
function exponent(k, n, inverse) {
  var turns = k / n;
  var angle = (inverse ? 2 : -2) * Math.PI * turns;
  return {real: Math.cos(angle), imag: Math.sin(angle)};
}
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Casts tensor `x` to `dtype` using `backend2` for kernels that need it.
 * - to complex64: clone if already complex, else pair cast(x, float32) with zeros.
 * - lossless per hasEncodingLoss: reuse x.dataId with the new dtype.
 * - from complex64: take the real part, then cast it.
 * - to int32 / bool: backend int / notEqual-with-zero.
 * Intermediate tensors created here are disposed before returning.
 * @throws {Error} For any other lossy combination.
 */
function castTensor(x, dtype, backend2) {
  if (dtype === "complex64") {
    if (x.dtype === "complex64") {
      return x.clone();
    }
    var imagZeros = zeros(x.shape);
    var realPart = cast(x, "float32");
    var complexOut = backend2.complex(realPart, imagZeros);
    imagZeros.dispose();
    realPart.dispose();
    return complexOut;
  }
  if (!hasEncodingLoss(x.dtype, dtype)) {
    return ENGINE.makeTensorFromDataId(x.dataId, x.shape, dtype);
  }
  if (x.dtype === "complex64") {
    var realOnly = backend2.real(x);
    var casted = cast(realOnly, dtype);
    realOnly.dispose();
    return casted;
  }
  switch (dtype) {
    case "int32":
      return backend2.int(x);
    case "bool": {
      var zero = scalar(0, x.dtype);
      var nonZero = backend2.notEqual(x, zero);
      zero.dispose();
      return nonZero;
    }
    default:
      throw new Error("Error in Cast: failed to cast " + x.dtype + " to " + dtype);
  }
}
/** Reshape without copying: a new tensor over x's dataId with the given shape and x's dtype. */
function reshapeTensor(x, shape) {
  var dataId = x.dataId;
  var dtype = x.dtype;
  return ENGINE.makeTensorFromDataId(dataId, shape, dtype);
}
/**
 * CPU implementation of linspace: `num` evenly spaced float32 values from
 * `start` to `stop` inclusive.
 *
 * Each value is computed directly as start + i * step rather than by repeated
 * addition. Accumulating through the float32 buffer rounds at every step, so
 * the error grows with `num` and the last element can miss `stop`. The last
 * element is pinned to `stop` exactly when num > 1.
 *
 * @param {number} start First value.
 * @param {number} stop Last value (included when num > 1).
 * @param {number} num Number of samples; num === 1 yields [start].
 * @returns 1-D float32 tensor of length num.
 */
function linspaceImpl(start, stop, num) {
  var step2 = (stop - start) / (num - 1);
  var values = makeZerosTypedArray(num, "float32");
  values[0] = start;
  for (var i = 1; i < values.length; i++) {
    values[i] = start + i * step2;
  }
  if (values.length > 1) {
    values[values.length - 1] = stop;
  }
  return tensor1d(values, "float32");
}
// Namespace object bundling backend helper functions and constants. The
// `__proto__: null` literal entry gives it a null prototype, so it carries no
// inherited Object.prototype members. Some entries are defined in this part of
// the file (castTensor, reshapeTensor, linspaceImpl, the complex helpers, warn,
// log$1, the SELU/ERF constants, ...); the rest are defined elsewhere. Two
// entries re-expose bundler-renamed locals under their public names:
// computeOutShape (computeOutShape$1) and log (log$1).
var backend_util = {
__proto__: null,
slice_util,
segment_util,
castTensor,
reshapeTensor,
linspaceImpl,
upcastType,
axesAreInnerMostDims,
combineLocations,
computeOutAndReduceShapes,
expandShapeToKeepDim,
assertAxesAreInnerMostDims,
getAxesPermutation,
getUndoAxesPermutation,
getInnerMostAxes,
getBroadcastDims,
getReductionAxes,
assertAndGetBroadcastShape,
assertParamsConsistent,
computeOutShape: computeOutShape$1,
computeDilation2DInfo,
computePool2DInfo,
computePool3DInfo,
computeConv2DInfo,
computeConv3DInfo,
computeDefaultPad,
tupleValuesAreOne,
eitherStridesOrDilationsAreOne,
convertConv2DDataFormat,
getFusedDyActivation,
getFusedBiasGradient,
applyActivation,
shouldFuse,
PARALLELIZE_THRESHOLD,
computeOptimalWindowSize,
getImageCenter,
getReshaped,
getPermuted,
getReshapedPermuted,
getSliceBeginCoords,
getSliceSize,
prepareAndValidate,
validateUpdateShape,
validateInput,
calculateShapes,
SELU_SCALEALPHA,
SELU_SCALE,
ERF_P,
ERF_A1,
ERF_A2,
ERF_A3,
ERF_A4,
ERF_A5,
warn,
log: log$1,
mergeRealAndImagArrays,
splitRealAndImagArrays,
complexWithEvenIndex,
complexWithOddIndex,
getComplexWithIndex,
assignToTypedArray,
exponents,
exponent,
prepareSplitSize
};
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Splits tensor `x` along `axis` into consecutive slices whose extents along
 * that axis are given by `sizeSplits`. Every other dim keeps its full size.
 * @returns An array of slices, in order.
 */
function split$1(x, sizeSplits, axis) {
  var start = new Array(x.rank).fill(0);
  var fullSize = x.shape.slice();
  var pieces = [];
  for (var n = 0; n < sizeSplits.length; n++) {
    var extent = sizeSplits[n];
    var pieceSize = fullSize.slice();
    pieceSize[axis] = extent;
    pieces.push(slice(x, start, pieceSize));
    // Advance the cursor along the split axis for the next piece.
    start[axis] += extent;
  }
  return pieces;
}
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * CPU tile: repeats buffer `xBuf` `reps[d]` times along each dim d.
 * Each output location reads the input at (location modulo input shape).
 * @returns The tiled result converted via toTensor().
 */
function tile$1(xBuf, reps) {
  var rank = xBuf.rank;
  var outShape = [];
  for (var d = 0; d < rank; d++) {
    outShape.push(xBuf.shape[d] * reps[d]);
  }
  var out = buffer(outShape, xBuf.dtype);
  var outValues = out.values;
  for (var flat = 0; flat < outValues.length; ++flat) {
    var outLoc = out.indexToLoc(flat);
    var srcLoc = [];
    for (var axis = 0; axis < rank; axis++) {
      srcLoc.push(outLoc[axis] % xBuf.shape[axis]);
    }
    outValues[flat] = xBuf.values[xBuf.locToIndex(srcLoc)];
  }
  return out.toTensor();
}
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * CPU top-k along the last dimension.
 * @param x Flat typed array of values, laid out as rows of length lastDim.
 * @param xShape Logical shape of x.
 * @param k Number of largest entries to keep per row.
 * @param sorted Unused: results are always in descending order. Ties keep
 *     ascending original index, because Array.prototype.sort is stable.
 * @returns [values tensor, int32 indices tensor], both of xShape with the last dim set to k.
 */
function topkImpl(x, xShape, xDtype, k, sorted) {
  var lastDim = xShape[xShape.length - 1];
  var batch = x.length / lastDim;
  var allTopKVals = getTypedArrayFromDType(xDtype, batch * k);
  var allTopKIndices = getTypedArrayFromDType("int32", batch * k);
  for (var row = 0; row < batch; row++) {
    var inStart = row * lastDim;
    var ranked = Array.from(x.subarray(inStart, inStart + lastDim), function(value, index) {
      return {value: value, index: index};
    });
    ranked.sort(function(lhs, rhs) {
      return rhs.value - lhs.value;
    });
    var outStart = row * k;
    for (var j = 0; j < k; j++) {
      allTopKVals[outStart + j] = ranked[j].value;
      allTopKIndices[outStart + j] = ranked[j].index;
    }
  }
  var outputShape = xShape.slice();
  outputShape[outputShape.length - 1] = k;
  return [
    tensor(allTopKVals, outputShape, xDtype),
    tensor(allTopKIndices, outputShape, "int32")
  ];
}
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
// Namespace object (null prototype via the `__proto__: null` literal entry)
// exposing shared CPU kernel implementations. split$1 and tile$1 are
// re-exposed under their public names `split` and `tile`. split$1, tile$1 and
// topkImpl are defined above in this file; the nonMaxSuppression*Impl
// functions and whereImpl are defined elsewhere.
var kernel_impls = {
__proto__: null,
nonMaxSuppressionV3Impl,
nonMaxSuppressionV4Impl,
nonMaxSuppressionV5Impl,
split: split$1,
tile: tile$1,
topkImpl,
whereImpl
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
// Gradient of Abs: dy * step(x, -1), i.e. dy where x > 0 and -dy elsewhere
// (assuming step(x, alpha) yields 1 for x > 0 and alpha otherwise — defined elsewhere).
var absGradConfig = {
  kernelName: Abs,
  inputsToSave: ["x"],
  gradFunc: function(dy, saved) {
    var input = saved[0];
    var dx = function() {
      var slope = step(cast(input, "float32"), -1);
      return mul(dy, slope);
    };
    return {x: dx};
  }
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
// Gradient of Acos: -dy / sqrt(1 - x^2).
var acosGradConfig = {
  kernelName: Acos,
  inputsToSave: ["x"],
  gradFunc: function(dy, saved) {
    var input = saved[0];
    var dx = function() {
      var xSquared = square(cast(input, "float32"));
      var denominator = sqrt(sub(scalar(1), xSquared));
      return neg(div(dy, denominator));
    };
    return {x: dx};
  }
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
// Gradient of Acosh: dy / sqrt(x^2 - 1).
var acoshGradConfig = {
  kernelName: Acosh,
  inputsToSave: ["x"],
  gradFunc: function(dy, saved) {
    var input = saved[0];
    var dx = function() {
      var denominator = sqrt(sub(square(cast(input, "float32")), 1));
      return div(dy, denominator);
    };
    return {x: dx};
  }
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
// Gradient of Add: dy flows to both inputs. For each input, dy is summed over
// the axes that broadcasting added, then reshaped back to that input's shape.
var addGradConfig = {
  kernelName: Add,
  inputsToSave: ["a", "b"],
  gradFunc: function(dy, saved) {
    var a = saved[0];
    var b = saved[1];
    // Evaluated eagerly: throws here if the saved shapes are not broadcast-compatible.
    var outShape = assertAndGetBroadcastShape(a.shape, b.shape);
    // Returns a thunk producing dy reduced to `input`'s shape.
    function reducedTo(input) {
      return function() {
        var reduceAxes = getReductionAxes(input.shape, outShape);
        var grad = reduceAxes.length > 0 ? sum$1(dy, reduceAxes) : dy;
        return reshape(grad, input.shape);
      };
    }
    return {a: reducedTo(a), b: reducedTo(b)};
  }
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Gradient of `AddN` (the sum of N tensors).
 *
 * Every saved input, keyed by its index, receives its own clone of the
 * upstream gradient `dy`.
 */
var addNGradConfig = {
  kernelName: AddN,
  saveAllInputs: true,
  gradFunc(dy, saved) {
    const ders = {};
    for (let index = 0; index < saved.length; index++) {
      ders[index] = () => dy.clone();
    }
    return ders;
  }
};
/**
* @license
* Copyright 2020 Google Inc. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Gradient of `ArgMax`: the gradient with respect to `x` is a zero tensor
 * shaped like `x` (built with `zerosLike`); `dy` is ignored.
 */
var argMaxGradConfig = {
  kernelName: ArgMax,
  inputsToSave: ["x"],
  gradFunc(dy, saved) {
    const [x] = saved;
    return {x: () => zerosLike(x)};
  }
};
/**
* @license
* Copyright 2020 Google Inc. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Gradient of `ArgMin`: the gradient with respect to `x` is a zero tensor
 * shaped like `x` (built with `zerosLike`); `dy` is ignored.
 */
var argMinGradConfig = {
  kernelName: ArgMin,
  inputsToSave: ["x"],
  gradFunc(dy, saved) {
    const [x] = saved;
    return {x: () => zerosLike(x)};
  }
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Gradient of `Asin`: dx = dy / sqrt(1 - x^2), with `x` cast to float32.
 */
var asinGradConfig = {
  kernelName: Asin,
  inputsToSave: ["x"],
  gradFunc(dy, saved) {
    const [x] = saved;
    return {
      x: () => {
        const one = scalar(1);
        const xSquared = square(cast(x, "float32"));
        const root = sqrt(sub(one, xSquared));
        return div(dy, root);
      }
    };
  }
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Gradient of `Asinh`: dx = dy / sqrt(1 + x^2), with `x` cast to float32.
 */
var asinhGradConfig = {
  kernelName: Asinh,
  inputsToSave: ["x"],
  gradFunc(dy, saved) {
    const [x] = saved;
    return {
      x: () => {
        const one = scalar(1);
        const root = sqrt(add$1(one, square(cast(x, "float32"))));
        return div(dy, root);
      }
    };
  }
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Gradient of `Atan2(a, b)` (with broadcasting).
 *
 * d/da = b / (a^2 + b^2) and d/db = -a / (a^2 + b^2). Each partial is scaled
 * by `dy`, summed over the operand's broadcast axes, and reshaped back to the
 * operand's shape.
 */
var atan2GradConfig = {
  kernelName: Atan2,
  inputsToSave: ["a", "b"],
  gradFunc(dy, saved) {
    const [a, b] = saved;
    const outShape = assertAndGetBroadcastShape(a.shape, b.shape);
    // Reduces a gradient computed at `outShape` back down to `shape`.
    const unbroadcast = (grad, shape) => {
      const axes = getReductionAxes(shape, outShape);
      return reshape(axes.length > 0 ? sum$1(grad, axes) : grad, shape);
    };
    return {
      a: () => {
        const denom = add$1(square(a), square(b));
        return unbroadcast(mul(dy, div(b, denom)), a.shape);
      },
      b: () => {
        const denom = add$1(square(a), square(b));
        return unbroadcast(neg(mul(dy, div(a, denom))), b.shape);
      }
    };
  }
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Gradient of `Atan`: dx = dy / (x^2 + 1), with `x` cast to float32.
 */
var atanGradConfig = {
  kernelName: Atan,
  inputsToSave: ["x"],
  gradFunc(dy, saved) {
    const [x] = saved;
    return {
      x: () => {
        const denominator = add$1(square(cast(x, "float32")), 1);
        return div(dy, denominator);
      }
    };
  }
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Gradient of `Atanh`: dx = dy / (1 - x^2), with `x` cast to float32.
 */
var atanhGradConfig = {
  kernelName: Atanh,
  inputsToSave: ["x"],
  gradFunc(dy, saved) {
    const [x] = saved;
    return {
      x: () => {
        const one = scalar(1);
        const denominator = sub(one, square(cast(x, "float32")));
        return div(dy, denominator);
      }
    };
  }
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Backprop of a 3D average pool: given the upstream gradient `dy` and the
 * pooling op's original `input`, returns the gradient with respect to `input`.
 *
 * A rank-4 `input` is handled by reshaping both `input` and `dy` to rank 5
 * with a leading dimension of size 1. The result is reshaped back to rank 4.
 *
 * @param dy Upstream gradient (rank 5, or rank 4 when `input` is rank 4).
 * @param input Original input of the pooling op (rank 5 or rank 4).
 * @param filterSize Pooling window size.
 * @param strides Pooling strides.
 * @param dilations Dilation rates; `undefined` falls back to `[1, 1, 1]`.
 *     Either the strides or the dilations must be 1.
 * @param pad2 Padding; must be an integer whenever `dimRoundingMode` is set.
 * @param dimRoundingMode Optional rounding mode forwarded to
 *     `computePool3DInfo`.
 * @returns The gradient with respect to `input`, reshaped back to rank 4 when
 *     `input` was rank 4.
 * @throws Via `assert` when ranks, strides/dilations or padding are invalid.
 */
function avgPool3dBackprop_(dy, input, filterSize, strides, dilations, pad2, dimRoundingMode) {
  // Only an omitted/undefined value falls back to unit dilations.
  const $dilations = dilations === void 0 ? [1, 1, 1] : dilations;
  const $dy = convertToTensor(dy, "dy", "avgPool3dBackprop");
  const $input = convertToTensor(input, "input", "avgPool3dBackprop");
  let dy5D = $dy;
  let input5D = $input;
  let reshapedTo5D = false;
  if ($input.rank === 4) {
    reshapedTo5D = true;
    dy5D = reshape($dy, [1, $dy.shape[0], $dy.shape[1], $dy.shape[2], $dy.shape[3]]);
    input5D = reshape($input, [1, $input.shape[0], $input.shape[1], $input.shape[2], $input.shape[3]]);
  }
  assert(dy5D.rank === 5, () => `Error in avgPool3dBackprop: dy must be rank 5 but got rank ${dy5D.rank}.`);
  assert(input5D.rank === 5, () => `Error in avgPool3dBackprop: input must be rank 5 but got rank ${input5D.rank}.`);
  assert(eitherStridesOrDilationsAreOne(strides, $dilations), () => `Error in avgPool3dBackprop: Either strides or dilations must be 1. Got strides ${strides} and dilations '${$dilations}'`);
  if (dimRoundingMode != null) {
    // The message must name this op (avgPool3dBackprop), not maxPool3dBackprop.
    assert(isInt(pad2), () => `Error in avgPool3dBackprop: pad must be an integer when using, dimRoundingMode ${dimRoundingMode} but got pad ${pad2}.`);
  }
  const forward = (backend2) => {
    const convInfo = computePool3DInfo(input5D.shape, filterSize, strides, $dilations, pad2, dimRoundingMode);
    return backend2.avgPool3dBackprop(dy5D, input5D, convInfo);
  };
  const inputs = {dy: dy5D, input: input5D};
  const attrs = {filterSize, strides, dilations: $dilations, pad: pad2, dimRoundingMode};
  const res = ENGINE.runKernelFunc(forward, inputs, null, AvgPool3DBackprop, attrs);
  if (reshapedTo5D) {
    return reshape(res, [res.shape[1], res.shape[2], res.shape[3], res.shape[4]]);
  }
  return res;
}
var avgPool3dBackprop = op({avgPool3dBackprop_});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Gradient of `AvgPool3D`, delegating to `avgPool3dBackprop` with the
 * forward op's attributes. A null/undefined `dilations` attribute is
 * replaced by `[1, 1, 1]`.
 */
var avgPool3DGradConfig = {
  kernelName: AvgPool3D,
  inputsToSave: ["x"],
  gradFunc(dy, saved, attrs) {
    const [x] = saved;
    const {filterSize, strides, dilations, pad: pad2, dimRoundingMode} = attrs;
    const effectiveDilations = dilations == null ? [1, 1, 1] : dilations;
    return {
      x: () => avgPool3dBackprop(dy, x, filterSize, strides, effectiveDilations, pad2, dimRoundingMode)
    };
  }
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Backprop of a 2D average pool: returns the gradient with respect to `input`
 * given the upstream gradient `dy`.
 *
 * `input` and `dy` must have the same rank. Rank-3 tensors get a leading
 * dimension of size 1 for the kernel call, and that dimension is dropped from
 * the result.
 *
 * @param dy Upstream gradient (rank 4, or rank 3).
 * @param input Original input of the pooling op (same rank as `dy`).
 * @param filterSize Pooling window size.
 * @param strides Pooling strides.
 * @param pad2 Padding passed to `computePool2DInfo`.
 * @returns The gradient with respect to `input`.
 */
function avgPoolBackprop_(dy, input, filterSize, strides, pad2) {
  const $dy = convertToTensor(dy, "dy", "avgPoolBackprop");
  const $input = convertToTensor(input, "input", "avgPoolBackprop");
  assert($input.rank === $dy.rank, () => `Rank of input (${$input.rank}) does not match rank of dy (${$dy.rank})`);

  const addBatchDim = $input.rank === 3;
  let input4D = $input;
  let dy4D = $dy;
  if (addBatchDim) {
    const [i0, i1, i2] = $input.shape;
    input4D = reshape($input, [1, i0, i1, i2]);
    const [d0, d1, d2] = $dy.shape;
    dy4D = reshape($dy, [1, d0, d1, d2]);
  }
  assert(dy4D.rank === 4, () => `Error in avgPoolBackprop: dy must be rank 4 but got rank ${dy4D.rank}.`);
  assert(input4D.rank === 4, () => `Error in avgPoolBackprop: input must be rank 4 but got rank ${input4D.rank}.`);

  const forward = (backend2) => {
    const convInfo = computePool2DInfo(input4D.shape, filterSize, strides, 1, pad2);
    return backend2.avgPoolBackprop(dy4D, input4D, convInfo);
  };
  const res = ENGINE.runKernelFunc(forward, {dy: dy4D, input: input4D}, null, AvgPoolBackprop, {filterSize, strides, pad: pad2});
  if (!addBatchDim) {
    return res;
  }
  const [, r1, r2, r3] = res.shape;
  return reshape(res, [r1, r2, r3]);
}
var avgPoolBackprop = op({avgPoolBackprop_});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Gradient of `AvgPool`, delegating to `avgPoolBackprop` with the forward op's
 * `filterSize`, `strides` and `pad` attributes.
 */
var avgPoolGradConfig = {
  kernelName: AvgPool,
  inputsToSave: ["x"],
  gradFunc(dy, saved, attrs) {
    const [x] = saved;
    const {filterSize, strides, pad: pad2} = attrs;
    return {x: () => avgPoolBackprop(dy, x, filterSize, strides, pad2)};
  }
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Gradient of `BatchMatMul` for every combination of the `transposeA` /
 * `transposeB` attributes.
 *
 * With C = op(A) * op(B), each operand's gradient is another `matMul` of `dy`
 * with the other operand, using the transpose flags shown per branch.
 */
var batchMatMulGradConfig = {
  kernelName: BatchMatMul,
  inputsToSave: ["a", "b"],
  gradFunc(dy, saved, attrs) {
    const [a, b] = saved;
    const {transposeA, transposeB} = attrs;
    if (transposeA) {
      if (transposeB) {
        // C = A^T B^T
        return {
          a: () => matMul(b, dy, true, true),
          b: () => matMul(dy, a, true, true)
        };
      }
      // C = A^T B
      return {
        a: () => matMul(b, dy, false, true),
        b: () => matMul(a, dy, false, false)
      };
    }
    if (transposeB) {
      // C = A B^T
      return {
        a: () => matMul(dy, b, false, false),
        b: () => matMul(dy, a, true, false)
      };
    }
    // C = A B
    return {
      a: () => matMul(dy, b, false, true),
      b: () => matMul(a, dy, true, false)
    };
  }
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Gradient of `BatchToSpaceND`: `spaceToBatchND` applied to `dy` with the
 * forward op's `blockShape` and `crops`.
 */
var batchToSpaceNDGradConfig = {
  kernelName: BatchToSpaceND,
  gradFunc: (dy, saved, attrs) => {
    const {blockShape, crops} = attrs;
    return {x: () => spaceToBatchND(dy, blockShape, crops)};
  }
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Gradient of `BroadcastTo`.
 *
 * Axes where the input dimension differs from the output dimension were
 * expanded by the broadcast. `dy` is summed over those axes with
 * `keepDims = true`.
 *
 * @throws Error if an input dimension is neither equal to the output dimension
 *     at the same index nor 1.
 */
var broadcastToGradConfig = {
  kernelName: BroadcastTo,
  gradFunc(dy, saved, attrs) {
    const {inputShape, shape: outputShape} = attrs;
    // reps[i] > 1 marks an axis the broadcast expanded.
    const reps = Array.from(outputShape);
    for (const [i, inDim] of inputShape.entries()) {
      if (inDim === outputShape[i]) {
        reps[i] = 1;
      } else if (inDim !== 1) {
        throw new Error(`broadcastTo(): [${inputShape}] cannot be broadcast to [${outputShape}].`);
      }
    }
    const axes = reps.reduce((acc, rep, i) => {
      if (rep > 1) {
        acc.push(i);
      }
      return acc;
    }, []);
    return {x: () => sum$1(dy, axes, true)};
  }
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Gradient of `Cast`: returns a clone of `dy` unchanged.
 */
var castGradConfig = {
  kernelName: Cast,
  gradFunc: (dy) => ({x: () => dy.clone()})
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Gradient of `Ceil`: a zero tensor shaped like `dy` (built with `zerosLike`).
 */
var ceilGradConfig = {
  kernelName: Ceil,
  gradFunc: (dy) => ({x: () => zerosLike(dy)})
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Gradient of `ClipByValue`: `dy` passes through where
 * clipValueMin <= x <= clipValueMax, and is zero elsewhere.
 */
var clipByValueGradConfig = {
  kernelName: ClipByValue,
  inputsToSave: ["x"],
  gradFunc(dy, saved, attrs) {
    const [x] = saved;
    const {clipValueMin, clipValueMax} = attrs;
    return {
      x: () => {
        const inRange = logicalAnd(greaterEqual(x, clipValueMin), lessEqual(x, clipValueMax));
        return where(inRange, dy, zerosLike(dy));
      }
    };
  }
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Gradient of `Concat`.
 *
 * `dy` is split along the (normalized) concat axis into pieces sized like each
 * saved input along that axis. The i-th piece is returned as the i-th input's
 * gradient.
 */
var concatGradConfig = {
  kernelName: Concat,
  saveAllInputs: true,
  gradFunc(dy, saved, attrs) {
    const [$axis] = parseAxisParam(attrs.axis, saved[0].shape);
    const sizeSplits = saved.map((tensor) => tensor.shape[$axis]);
    const pieces = split(dy, sizeSplits, $axis);
    return pieces.map((piece) => () => piece);
  }
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Gradient of `Conv2D` with respect to its input and its filter.
 *
 * Only dilations of all ones are supported (asserted).
 */
var conv2DGradConfig = {
  kernelName: Conv2D,
  inputsToSave: ["x", "filter"],
  gradFunc(dy, saved, attrs) {
    const [x4D, $filter] = saved;
    const {dilations, strides, pad: pad2, dataFormat} = attrs;
    assert(tupleValuesAreOne(dilations), () => `Error in gradient of conv2D: dilation rates greater than 1 are not yet supported in gradients. Got dilations '${dilations}'`);
    return {
      x: () => conv2DBackpropInput(x4D.shape, dy, $filter, strides, pad2, dataFormat),
      filter: () => conv2DBackpropFilter(x4D, dy, $filter.shape, strides, pad2, dataFormat)
    };
  }
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Gradient of `Conv2DBackpropInput` (used for second-order gradients of
 * `Conv2D`).
 *
 * - The gradient with respect to `dy` is a forward `conv2d` of the incoming
 *   gradient `ddx` with the saved filter.
 * - The gradient with respect to `filter` comes from `conv2DBackpropFilter`.
 */
var conv2DBackpropInputGradConfig = {
  kernelName: Conv2DBackpropInput,
  inputsToSave: ["dy", "filter"],
  gradFunc(ddx, saved, attrs) {
    const [dy, filter] = saved;
    const {strides, pad: pad2, dataFormat, dimRoundingMode} = attrs;
    return {
      dy: () => conv2d(ddx, filter, strides, pad2, dataFormat, 1, dimRoundingMode),
      filter: () => conv2DBackpropFilter(ddx, dy, filter.shape, strides, pad2, dataFormat, dimRoundingMode)
    };
  }
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Computes the gradient of a 3D convolution with respect to its filter.
 *
 * Rank-4 `x` or `dy` gets a leading dimension of size 1 so the backend always
 * sees rank-5 tensors. The computation always uses a dilation of 1.
 *
 * @param x The convolution input (rank 5, or rank 4).
 * @param dy The upstream gradient (rank 5, or rank 4).
 * @param filterShape Shape of the filter; must have length 5.
 *     `filterShape[3]` must equal the last dimension of `x` and
 *     `filterShape[4]` the last dimension of `dy`.
 * @param strides Convolution strides.
 * @param pad2 Convolution padding.
 * @returns The filter gradient computed by the backend's `conv3dDerFilter`.
 * @throws Via `assert` when ranks or depths are inconsistent.
 */
function conv3DBackpropFilter_(x, dy, filterShape, strides, pad2) {
  const x5D = x.rank === 4 ? reshape(x, [1, x.shape[0], x.shape[1], x.shape[2], x.shape[3]]) : x;
  const dy5D = dy.rank === 4 ? reshape(dy, [1, dy.shape[0], dy.shape[1], dy.shape[2], dy.shape[3]]) : dy;
  assert(x5D.rank === 5, () => `Error in conv3dDerFilter: input must be rank 5, but got shape ${x5D.shape}.`);
  assert(dy5D.rank === 5, () => `Error in conv3dDerFilter: dy must be rank 5, but got shape ${dy5D.shape}.`);
  assert(filterShape.length === 5, () => `Error in conv3dDerFilter: filterShape must be length 5, but got ${filterShape}.`);
  // Balanced parentheses: previously rendered as "depth of input 3) must match ... (5."
  assert(x5D.shape[4] === filterShape[3], () => `Error in conv3dDerFilter: depth of input (${x5D.shape[4]}) must match input depth in filter (${filterShape[3]}).`);
  assert(dy5D.shape[4] === filterShape[4], () => `Error in conv3dDerFilter: depth of dy (${dy5D.shape[4]}) must match output depth for filter (${filterShape[4]}).`);
  const forward = (backend2) => {
    const dilations = 1;
    const convInfo = computeConv3DInfo(x5D.shape, filterShape, strides, dilations, pad2);
    return backend2.conv3dDerFilter(x5D, dy5D, convInfo);
  };
  const inputs = {x: x5D, dy: dy5D};
  const attrs = {strides, pad: pad2, filterShape};
  return ENGINE.runKernelFunc(forward, inputs, null, Conv3DBackpropFilterV2, attrs);
}
var conv3DBackpropFilter = op({conv3DBackpropFilter_});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Gradient of `Conv3D` with respect to its input and its filter.
 *
 * Only dilations of all ones are supported (asserted before the saved tensors
 * are read).
 */
var conv3DGradConfig = {
  kernelName: Conv3D,
  inputsToSave: ["x", "filter"],
  gradFunc(dy, saved, attrs) {
    const {dilations, strides, pad: pad2} = attrs;
    assert(tupleValuesAreOne(dilations), () => `Error in gradient of conv3D: dilation rates greater than 1 are not yet supported in gradients. Got dilations '${dilations}'`);
    const [x5D, $filter] = saved;
    return {
      x: () => conv3DBackpropInput(x5D.shape, dy, $filter, strides, pad2),
      filter: () => conv3DBackpropFilter(x5D, dy, $filter.shape, strides, pad2)
    };
  }
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Gradient of `Cos`: dx = -sin(x) * dy, with `x` cast to float32.
 */
var cosGradConfig = {
  kernelName: Cos,
  inputsToSave: ["x"],
  gradFunc(dy, saved) {
    const [x] = saved;
    return {
      x: () => {
        const negSin = neg(sin(cast(x, "float32")));
        return mul(negSin, dy);
      }
    };
  }
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Gradient of `Cosh`: dx = sinh(x) * dy, with `x` cast to float32.
 */
var coshGradConfig = {
  kernelName: Cosh,
  inputsToSave: ["x"],
  gradFunc(dy, saved) {
    const [x] = saved;
    return {
      x: () => {
        const sinhX = sinh(cast(x, "float32"));
        return mul(sinhX, dy);
      }
    };
  }
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Gradient of `Cumsum`.
 *
 * If y = cumsum(x) along `axis`, each x[j] contributes to every output on one
 * side of j. The gradient is therefore the cumulative sum of `dy` along the
 * same axis in the opposite direction (`reverse` flipped), with the same
 * `exclusive` setting.
 *
 * cumsum preserves shape, so `cumsum(dy, ...)` already has the shape and
 * layout of `x`. No transpose is applied. Transposing by
 * `getAxesPermutation([axis], rank)` for a non-innermost axis would return a
 * gradient whose shape (or element order, for square inputs) no longer
 * matches `x`.
 */
var cumsumGradConfig = {
  kernelName: Cumsum,
  inputsToSave: ["x"],
  gradFunc(dy, saved, attrs) {
    const {axis, exclusive, reverse: reverse2} = attrs;
    return {
      x: () => cumsum(dy, axis, exclusive, !reverse2)
    };
  }
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Gradient of `DepthwiseConv2dNative` with respect to its input `x` and its
 * `filter`.
 *
 * A missing (null/undefined) `dilations` attribute is normalized to `[1, 1]`.
 * That normalized value is both asserted (only unit dilations are supported)
 * and forwarded to the backprop ops, so they never receive `null`.
 *
 * @throws Via `assert` on unsupported dilations, non-rank-4 tensors,
 *     mismatched channel counts, or a non-integer `pad` combined with
 *     `dimRoundingMode`.
 */
var depthwiseConv2dNativeGradConfig = {
  kernelName: DepthwiseConv2dNative,
  inputsToSave: ["x", "filter"],
  gradFunc: function(dy, saved, attrs) {
    const {dilations, strides, pad: pad2, dimRoundingMode} = attrs;
    const $dilations = dilations == null ? [1, 1] : dilations;
    assert(tupleValuesAreOne($dilations), function() {
      return "Error in gradient of depthwiseConv2dNative: dilation rates greater than 1 are not yet supported. Got dilations " + ("'" + $dilations + "'");
    });
    const x = saved[0];
    const filter = saved[1];
    assert(x.rank === 4, function() {
      return "Error in gradient of depthwiseConv2dNative: input must be " + ("rank 4, but got rank " + x.rank + ".");
    });
    assert(filter.rank === 4, function() {
      return "Error in gradient of depthwiseConv2dNative: filter must be " + ("rank 4, but got rank " + filter.rank + ".");
    });
    assert(x.shape[3] === filter.shape[2], function() {
      return "Error in gradient of depthwiseConv2d: number of input " + ("channels (" + x.shape[3] + ") must match the inChannels dimension ") + ("in filter " + filter.shape[2] + ".");
    });
    assert(eitherStridesOrDilationsAreOne(strides, $dilations), function() {
      return "Error in gradient of depthwiseConv2d: Either strides or " + ("dilations must be 1. Got strides " + strides + " and dilations ") + ("'" + $dilations + "'.");
    });
    if (dimRoundingMode != null) {
      assert(isInt(pad2), function() {
        return "Error in depthwiseConv2d: pad must be an integer when using, " + ("dimRoundingMode " + dimRoundingMode + " but got pad " + pad2 + ".");
      });
    }
    return {
      // Forward the normalized dilations (not the raw, possibly-null attribute).
      x: function() {
        return depthwiseConv2dNativeBackpropInput(x.shape, dy, filter, strides, pad2, $dilations, dimRoundingMode);
      },
      filter: function() {
        return depthwiseConv2dNativeBackpropFilter(x, dy, filter.shape, strides, pad2, $dilations, dimRoundingMode);
      }
    };
  }
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Gradient config for Dilation2D.
 *
 * Both gradients are computed by dedicated backprop kernels run through the
 * engine with the forward op's attrs. Each kernel receives its own
 * `{x, filter, dy}` inputs object, built when gradFunc runs.
 */
var dilation2dGradConfig = {
  kernelName: Dilation2D,
  inputsToSave: ["x", "filter"],
  gradFunc: (dy, saved, attrs) => {
    const x = saved[0];
    const filter = saved[1];
    const backpropInputArgs = {x, filter, dy};
    const backpropFilterArgs = {x, filter, dy};
    const inputGrad = () => ENGINE.runKernel(Dilation2DBackpropInput, backpropInputArgs, attrs);
    const filterGrad = () => ENGINE.runKernel(Dilation2DBackpropFilter, backpropFilterArgs, attrs);
    return {x: inputGrad, filter: filterGrad};
  }
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Gradient config for Div (a / b).
 *
 * d/da = dy / b and d/db = -dy * a / b^2. Each partial is summed over the
 * axes that getReductionAxes reports for the input versus the broadcast
 * output shape, then reshaped back to that input's shape.
 */
var divGradConfig = {
  kernelName: Div,
  inputsToSave: ["a", "b"],
  gradFunc: (dy, saved) => {
    const a = saved[0];
    const b = saved[1];
    const outShape = assertAndGetBroadcastShape(a.shape, b.shape);
    // Collapse a broadcast-shaped gradient back to `targetShape` when needed.
    const unbroadcast = (grad, targetShape) => {
      const axes = getReductionAxes(targetShape, outShape);
      return axes.length > 0 ? reshape(sum$1(grad, axes), targetShape) : grad;
    };
    return {
      a: () => unbroadcast(div(dy, cast(b, "float32")), a.shape),
      b: () => {
        const reduced = unbroadcast(mul(dy, cast(a, "float32")), b.shape);
        return neg(div(reduced, cast(square(b), "float32")));
      }
    };
  }
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Gradient config for Elu.
 *
 * Uses the saved forward output `y` and runs the backend's `eluDer(dy, y)`
 * through the engine under the EluGrad kernel name.
 */
var eluGradConfig = {
  kernelName: Elu,
  outputsToSave: [true],
  gradFunc: (dy, saved) => {
    const y = saved[0];
    const eluDerivative = (activeBackend) => activeBackend.eluDer(dy, y);
    const kernelInputs = {dy, y};
    return {
      x: () => ENGINE.runKernelFunc(eluDerivative, kernelInputs, null, EluGrad)
    };
  }
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Gradient config for Erf.
 *
 * d/dx erf(x) = 2 / sqrt(pi) * exp(-x^2). The derivative factor is built
 * when gradFunc runs; the returned thunk multiplies it by `dy`.
 */
var erfGradConfig = {
  kernelName: Erf,
  inputsToSave: ["x"],
  gradFunc: (dy, saved) => {
    const x = saved[0];
    const twoOverSqrtPi = 2 / Math.sqrt(Math.PI);
    const derivative = mul(exp(neg(square(x))), twoOverSqrtPi);
    return {x: () => mul(dy, derivative)};
  }
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Gradient config for Exp.
 *
 * d/dx exp(x) = exp(x), so the saved forward output `y` is reused: dx = dy * y.
 */
var expGradConfig = {
  kernelName: Exp,
  outputsToSave: [true],
  gradFunc: (dy, saved) => {
    const expOfX = saved[0];
    return {x: () => mul(dy, expOfX)};
  }
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Gradient config for Expm1 (exp(x) - 1).
 *
 * d/dx (exp(x) - 1) = exp(x), computed from the saved input: dx = dy * exp(x).
 */
var expm1GradConfig = {
  kernelName: Expm1,
  inputsToSave: ["x"],
  gradFunc: (dy, saved) => {
    const input = saved[0];
    return {x: () => mul(dy, exp(input))};
  }
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Gradient config for Floor.
 *
 * The gradient is defined as all zeros, shaped like `dy`.
 */
var floorGradConfig = {
  kernelName: Floor,
  gradFunc: (dy) => ({x: () => zerosLike(dy)})
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Gradient config for FloorDiv.
 *
 * Uses the same formulas as real division: d/da = dy / b and
 * d/db = -dy * a / b^2. Each partial is summed over the broadcast axes
 * reported by getReductionAxes and reshaped back to its input's shape.
 */
var floorDivGradConfig = {
  kernelName: FloorDiv,
  inputsToSave: ["a", "b"],
  gradFunc: (dy, saved) => {
    const numerator = saved[0];
    const denominator = saved[1];
    const outShape = assertAndGetBroadcastShape(numerator.shape, denominator.shape);
    // Sum out broadcast axes (if any) so the gradient matches `targetShape`.
    const reduceToShape = (grad, targetShape) => {
      const axes = getReductionAxes(targetShape, outShape);
      if (axes.length === 0) {
        return grad;
      }
      return reshape(sum$1(grad, axes), targetShape);
    };
    const gradA = () => reduceToShape(div(dy, cast(denominator, "float32")), numerator.shape);
    const gradB = () => {
      const partial = reduceToShape(mul(dy, cast(numerator, "float32")), denominator.shape);
      const denomSquared = square(denominator);
      return neg(div(partial, cast(denomSquared, "float32")));
    };
    return {a: gradA, b: gradB};
  }
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Gradient config for FusedBatchNorm.
 *
 * The x gradient is dy * scale / sqrt(variance + eps). This code has no
 * terms for mean/variance depending on x. The mean, variance, scale and
 * offset gradients are full-size products of dy. When mean is 1-D they are
 * summed over getReductionAxes(mean.shape, x.shape); in every case they are
 * then reshaped to mean's shape. A missing scale is treated as scalar(1).
 */
var fusedBatchNormGradConfig = {
  kernelName: FusedBatchNorm,
  inputsToSave: ["x", "mean", "variance", "scale"],
  gradFunc: (dy, saved, attrs) => {
    const epsilon = attrs.varianceEpsilon;
    const x = saved[0];
    const mean2 = saved[1];
    const variance = saved[2];
    const scale = saved[3];
    const scaleValue = scale != null ? scale : scalar(1);
    const reductionAxes = getReductionAxes(mean2.shape, x.shape);
    const perChannel = mean2.rank === 1;
    // For 1-D mean: tile multiples [x.shape[0..rank-2]..., 1], used to spread the
    // per-channel inverse std across x. NOTE(review): the [1, 1, 1, C] reshape
    // in the x gradient assumes x is rank 4 — confirm callers always pass 4-D x.
    const tileShape = perChannel ? x.shape.slice(0, -1).concat([1]) : [];
    const xMinusMean = sub(x, mean2);
    const dyTimesScaleValue = mul(dy, scaleValue);
    const oneOverSqrtVariance = rsqrt(add$1(variance, scalar(epsilon)));
    const rsqrtCubed = mul(mul(oneOverSqrtVariance, oneOverSqrtVariance), oneOverSqrtVariance);
    const minusHalfRCube = mul(rsqrtCubed, scalar(-0.5));
    // Collapse a full-size gradient into mean's shape (summing first if 1-D).
    const toParamShape = (grad) => reshape(perChannel ? sum$1(grad, reductionAxes) : grad, mean2.shape);
    return {
      x: () => {
        const invStd = perChannel
          ? tile(reshape(oneOverSqrtVariance, [1, 1, 1, mean2.shape[0]]), tileShape)
          : oneOverSqrtVariance;
        return reshape(mul(mul(dy, invStd), scaleValue), x.shape);
      },
      mean: () => toParamShape(mul(mul(oneOverSqrtVariance, scalar(-1)), dyTimesScaleValue)),
      variance: () => toParamShape(mul(mul(minusHalfRCube, xMinusMean), dyTimesScaleValue)),
      scale: () => toParamShape(mul(dy, mul(xMinusMean, oneOverSqrtVariance))),
      offset: () => toParamShape(dy)
    };
  }
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Gradient config for GatherV2.
 *
 * The x gradient scatters `dy` back into x's shape:
 *   1. view dy as [outer..., indices.size, inner...];
 *   2. move the gathered axis to the front;
 *   3. unsortedSegmentSum by the flattened indices (so repeated indices
 *      accumulate) into x.shape[axis] segments;
 *   4. undo the transpose.
 * The `indices` entry returns the saved indices tensor unchanged.
 */
var gatherGradConfig = {
  kernelName: GatherV2,
  inputsToSave: ["x", "indices"],
  gradFunc: function(dy, saved, attrs) {
    var x = saved[0], indices = saved[1];
    // Normalize the (possibly negative) axis once and use it everywhere below.
    var parsedAxis = parseAxisParam(attrs.axis, x.shape)[0];
    var derX = function() {
      var paramsShape = x.shape;
      var indicesSize = indices.size;
      var outerShape = paramsShape.slice(0, parsedAxis);
      var outerDims = outerShape.length;
      // Dimensions after the gathered axis. The previous code sliced with the
      // raw attr axis, which only worked for negative axes because
      // Array.prototype.slice counts negative starts from the end.
      var innerShape = paramsShape.slice(parsedAxis + 1);
      var innerDims = innerShape.length;
      var outerAxesIndices = arrayRange(0, outerDims);
      var innerAxesIndices = arrayRange(outerDims + 1, outerDims + 1 + innerDims);
      var valuesShape = arrayConcat([outerShape, [indicesSize], innerShape]);
      var values = reshape(dy, valuesShape);
      var reshapedIndices = reshape(indices, [indicesSize]);
      // Bring the gathered axis to position 0 so segment-sum runs along it.
      var transposeDims = arrayConcat([[outerDims], outerAxesIndices, innerAxesIndices]);
      var valuesTranspose = transpose(values, transposeDims);
      var paramsGrad = unsortedSegmentSum(valuesTranspose, reshapedIndices, x.shape[parsedAxis]);
      var invertTransposeDims = getUndoAxesPermutation(transposeDims);
      return transpose(paramsGrad, invertTransposeDims);
    };
    return {x: derX, indices: function() {
      return indices;
    }};
  }
};
/**
 * Returns the consecutive values [start, start + 1, ..., stop - 1].
 * The result is empty when start >= stop.
 */
function arrayRange(start, stop) {
  const range = [];
  let current = start;
  while (current < stop) {
    range.push(current);
    current++;
  }
  return range;
}
/**
 * Flattens one level: appends the elements of each array in `arrays`, in
 * order, to a new array. Nested arrays inside the parts are kept as-is.
 */
function arrayConcat(arrays) {
  const flat = [];
  for (const part of arrays) {
    for (let k = 0; k < part.length; k++) {
      flat.push(part[k]);
    }
  }
  return flat;
}
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Gradient config for GreaterEqual.
 *
 * Both inputs receive all-zero gradients shaped like the saved inputs.
 */
var greaterEqualGradConfig = {
  kernelName: GreaterEqual,
  inputsToSave: ["a", "b"],
  gradFunc: (dy, saved) => {
    const lhs = saved[0];
    const rhs = saved[1];
    return {
      a: () => zerosLike(lhs),
      b: () => zerosLike(rhs)
    };
  }
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Gradient config for Identity: passes `dy` through, cast to float32.
 */
var identityGradConfig = {
  kernelName: Identity,
  gradFunc: (dy) => ({x: () => cast(dy, "float32")})
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Gradient config for IsFinite.
 *
 * The gradient is all zeros, shaped like `dy`.
 */
var isFiniteGradConfig = {
  kernelName: IsFinite,
  gradFunc: (dy) => {
    const zeroGrad = () => zerosLike(dy);
    return {x: zeroGrad};
  }
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Gradient config for IsInf.
 *
 * The gradient is all zeros, shaped like `dy`.
 */
var isInfGradConfig = {
  kernelName: IsInf,
  gradFunc: (dy) => {
    const zeroGrad = () => zerosLike(dy);
    return {x: zeroGrad};
  }
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Gradient config for IsNan.
 *
 * The gradient is all zeros, shaped like `dy`.
 */
var isNanGradConfig = {
  kernelName: IsNan,
  gradFunc: (dy) => {
    const zeroGrad = () => zerosLike(dy);
    return {x: zeroGrad};
  }
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Gradient config for Log1p (log(1 + x)).
 *
 * d/dx log(1 + x) = 1 / (1 + x), so dx = dy / (x + 1).
 */
var log1pGradConfig = {
  kernelName: Log1p,
  inputsToSave: ["x"],
  gradFunc: (dy, saved) => {
    const input = saved[0];
    return {x: () => div(dy, add$1(input, 1))};
  }
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Gradient config for Log.
 *
 * d/dx log(x) = 1 / x, so dx = dy / x (x cast to float32).
 */
var logGradConfig = {
  kernelName: Log,
  inputsToSave: ["x"],
  gradFunc: (dy, saved) => {
    const input = saved[0];
    return {x: () => div(dy, cast(input, "float32"))};
  }
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Gradient config for LogSoftmax.
 *
 * With the saved output s = logSoftmax(logits), softmax = exp(s) and
 * d(logits) = dy - sum(dy, axis, keepDims) * softmax.
 */
var logSoftmaxGradConfig = {
  kernelName: LogSoftmax,
  inputsToSave: [],
  outputsToSave: [true],
  gradFunc: (dy, saved, attrs) => {
    const logProbs = saved[0];
    const {axis} = attrs;
    return {
      logits: () => {
        const probs = exp(logProbs);
        const dySum = sum$1(dy, axis, true);
        return sub(dy, mul(dySum, probs));
      }
    };
  }
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Backprop for local response normalization.
 *
 * Runs the backend's `LRNGrad(dy, x, y, depthRadius, bias, alpha, beta)` via
 * the engine under the LRNBackprop kernel name. Omitted (undefined) options
 * default to depthRadius = 5, bias = 1, alpha = 1, beta = 0.5.
 */
function localResponseNormalizationBackprop_(x, y, dy, depthRadius, bias, alpha, beta) {
  const radius = depthRadius === void 0 ? 5 : depthRadius;
  const offset = bias === void 0 ? 1 : bias;
  const scaleFactor = alpha === void 0 ? 1 : alpha;
  const exponent = beta === void 0 ? 0.5 : beta;
  const runOnBackend = (activeBackend) => activeBackend.LRNGrad(dy, x, y, radius, offset, scaleFactor, exponent);
  const kernelInputs = {x, y, dy};
  const kernelAttrs = {depthRadius: radius, bias: offset, alpha: scaleFactor, beta: exponent};
  return ENGINE.runKernelFunc(runOnBackend, kernelInputs, null, LRNBackprop, kernelAttrs);
}
var localResponseNormalizationBackprop = op({localResponseNormalizationBackprop_});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Gradient config for LRN.
 *
 * Uses the saved input `x` and output `y`, and forwards the op's
 * depthRadius/bias/alpha/beta attrs to localResponseNormalizationBackprop.
 */
var lrnGradConfig = {
  kernelName: LRN,
  inputsToSave: ["x"],
  outputsToSave: [true],
  gradFunc: (dy, saved, attrs) => {
    const input = saved[0];
    const output = saved[1];
    const {depthRadius, bias, alpha, beta} = attrs;
    return {
      x: () => localResponseNormalizationBackprop(input, output, dy, depthRadius, bias, alpha, beta)
    };
  }
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Shared gradient for min/max reductions.
 *
 * If the reduced output `y` or the incoming `dy` has lower rank than the
 * original input, it is reshaped via expandShapeToKeepDim(shape, origAxes).
 * The x gradient is then dy * cast(xOrig == y, dy.dtype), so dy flows only to
 * input positions equal to the reduced value.
 */
function gradForMinAndMax(dy, y, xOrig, origAxes) {
  const inputRank = xOrig.rank;
  const keepDims = (t) => reshape(t, expandShapeToKeepDim(t.shape, origAxes));
  const yKept = y.rank < inputRank ? keepDims(y) : y;
  const dyKept = dy.rank < inputRank ? keepDims(dy) : dy;
  return {
    x: () => {
      const mask = cast(equal(xOrig, yKept), dyKept.dtype);
      return mul(dyKept, mask);
    }
  };
}
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Gradient config for Max.
 *
 * Parses `reductionIndices` against x's shape and delegates to
 * gradForMinAndMax. That helper is called when gradFunc runs; only its `x`
 * thunk is deferred.
 */
var maxGradConfig = {
  kernelName: Max,
  inputsToSave: ["x"],
  outputsToSave: [true],
  gradFunc: (dy, saved, attrs) => {
    const input = saved[0];
    const reduced = saved[1];
    const axes = parseAxisParam(attrs.reductionIndices, input.shape);
    const grads = gradForMinAndMax(dy, reduced, input, axes);
    return {x: () => grads.x()};
  }
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Gradient config for Maximum.
 *
 * dy goes to `a` where a >= b and to `b` where a < b, so ties are credited
 * entirely to `a`.
 */
var maximumGradConfig = {
  kernelName: Maximum,
  inputsToSave: ["a", "b"],
  gradFunc: (dy, saved) => {
    const a = saved[0];
    const b = saved[1];
    return {
      a: () => mul(dy, cast(greaterEqual(a, b), "float32")),
      b: () => mul(dy, cast(less(a, b), "float32"))
    };
  }
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Backprop for 3-D max pooling.
 *
 * When `input` is rank 4, `dy`, `input` and `output` are each reshaped to
 * rank 5 with a leading dimension of 1; the result is reshaped back to rank 4
 * afterwards. Validates ranks, strides/dilations and (when dimRoundingMode is
 * set) that pad is an integer, then runs the backend's maxPool3dBackprop with
 * the pool info from computePool3DInfo. `dilations` defaults to [1, 1, 1]
 * when undefined.
 */
function maxPool3dBackprop_(dy, input, output, filterSize, strides, dilations, pad2, dimRoundingMode) {
  const dil = dilations === void 0 ? [1, 1, 1] : dilations;
  const $dy = convertToTensor(dy, "dy", "maxPool3dBackprop");
  const $input = convertToTensor(input, "input", "maxPool3dBackprop");
  const $output = convertToTensor(output, "output", "maxPool3dBackprop");
  const reshapedTo5D = $input.rank === 4;
  const withLeadingOne = (t) => reshape(t, [1, t.shape[0], t.shape[1], t.shape[2], t.shape[3]]);
  const dy5D = reshapedTo5D ? withLeadingOne($dy) : $dy;
  const input5D = reshapedTo5D ? withLeadingOne($input) : $input;
  const output5D = reshapedTo5D ? withLeadingOne($output) : $output;
  assert(dy5D.rank === 5, () => "Error in maxPool3dBackprop: dy must be rank 5 but got rank " + dy5D.rank + ".");
  assert(input5D.rank === 5, () => "Error in maxPool3dBackprop: input must be rank 5 but got rank " + input5D.rank + ".");
  assert(output5D.rank === 5, () => "Error in maxPool3dBackprop: output must be rank 5 but got rank " + output5D.rank + ".");
  assert(eitherStridesOrDilationsAreOne(strides, dil), () => "Error in maxPool3dBackprop: Either strides or dilations " + "must be 1. Got strides " + strides + " and dilations '" + dil + "'");
  if (dimRoundingMode != null) {
    assert(isInt(pad2), () => "Error in maxPool3dBackprop: pad must be an integer when " + "using, dimRoundingMode " + dimRoundingMode + " but got pad " + pad2 + ".");
  }
  const runOnBackend = (activeBackend) => {
    const poolInfo = computePool3DInfo(input5D.shape, filterSize, strides, dil, pad2, dimRoundingMode);
    return activeBackend.maxPool3dBackprop(dy5D, input5D, output5D, poolInfo);
  };
  const kernelInputs = {dy: dy5D, input: input5D, output: output5D};
  const kernelAttrs = {filterSize, strides, dilations: dil, pad: pad2, dimRoundingMode};
  const result = ENGINE.runKernelFunc(runOnBackend, kernelInputs, null, MaxPool3DBackprop, kernelAttrs);
  if (!reshapedTo5D) {
    return result;
  }
  return reshape(result, [result.shape[1], result.shape[2], result.shape[3], result.shape[4]]);
}
var maxPool3dBackprop = op({maxPool3dBackprop_});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Gradient config for MaxPool3D.
 *
 * Uses the saved input `x` and output `y` and forwards the pooling attrs to
 * maxPool3dBackprop. A null/undefined `dilations` attr becomes [1, 1, 1].
 */
var maxPool3DGradConfig = {
  kernelName: MaxPool3D,
  inputsToSave: ["x"],
  outputsToSave: [true],
  gradFunc: (dy, saved, attrs) => {
    const input = saved[0];
    const output = saved[1];
    const {filterSize, strides, pad: padding, dimRoundingMode} = attrs;
    const dilations = attrs.dilations == null ? [1, 1, 1] : attrs.dilations;
    return {
      x: () => maxPool3dBackprop(dy, input, output, filterSize, strides, dilations, padding, dimRoundingMode)
    };
  }
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Backprop for 2-D max pooling.
 *
 * Requires `dy` and `input` to be rank 4 with matching ranks, and (when
 * dimRoundingMode is set) an integer pad. Runs the backend's maxPoolBackprop
 * with pool info from computePool2DInfo, using a fixed dilation of 1.
 */
function maxPoolBackprop_(dy, input, output, filterSize, strides, pad2, dimRoundingMode) {
  const $dy = convertToTensor(dy, "dy", "maxPoolBackprop");
  const $input = convertToTensor(input, "input", "maxPoolBackprop");
  const $output = convertToTensor(output, "output", "maxPoolBackprop");
  assert($input.rank === $dy.rank, () => "Rank of input (" + $input.rank + ") does not match rank of dy " + "(" + $dy.rank + ")");
  assert($dy.rank === 4, () => "Error in maxPoolBackprop: dy must be rank 4 but got rank " + $dy.rank + ".");
  assert($input.rank === 4, () => "Error in maxPoolBackprop: input must be rank 4 but got rank " + $input.rank + ".");
  if (dimRoundingMode != null) {
    assert(isInt(pad2), () => "Error in maxPoolBackprop: pad must be an integer when using, " + "dimRoundingMode " + dimRoundingMode + " but got pad " + pad2 + ".");
  }
  const runOnBackend = (activeBackend) => {
    const poolInfo = computePool2DInfo($input.shape, filterSize, strides, 1, pad2, dimRoundingMode);
    return activeBackend.maxPoolBackprop($dy, $input, $output, poolInfo);
  };
  const kernelInputs = {dy: $dy, input: $input, output: $output};
  const kernelAttrs = {filterSize, strides, pad: pad2, dimRoundingMode};
  return ENGINE.runKernelFunc(runOnBackend, kernelInputs, null, MaxPoolBackprop, kernelAttrs);
}
var maxPoolBackprop = op({maxPoolBackprop_});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Gradient config for the 2D MaxPool kernel.
 *
 * Keeps the input `x` and the forward output `y`, then delegates to
 * `maxPoolBackprop` with the forward op's pooling attributes.
 *
 * `dimRoundingMode` is forwarded too. maxPoolBackprop accepts it and passes
 * it into computePool2DInfo, and maxPool3DGradConfig forwards its own. If it
 * were dropped, a forward pass run with a non-default rounding mode would be
 * differentiated against a differently sized pooled output. When the
 * attribute is absent, undefined is passed, which behaves exactly like
 * omitting the argument.
 */
var maxPoolGradConfig = {
  kernelName: MaxPool,
  inputsToSave: ["x"],
  outputsToSave: [true],
  gradFunc: function(dy, saved, attrs) {
    const [x, y] = saved;
    const { filterSize, strides, pad: pad2, dimRoundingMode } = attrs;
    return {
      x: function() {
        return maxPoolBackprop(dy, x, y, filterSize, strides, pad2, dimRoundingMode);
      }
    };
  }
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Gradient config for the Min reduction kernel.
 *
 * Normalizes the `axis` attribute against the input shape, then delegates to
 * the shared `gradForMinAndMax` helper. Only its `x` entry is exposed.
 */
var minGradConfig = {
  kernelName: Min,
  inputsToSave: ["x"],
  outputsToSave: [true],
  gradFunc(dy, saved, attrs) {
    const { axis } = attrs;
    const [x, y] = saved;
    const origAxes = parseAxisParam(axis, x.shape);
    const grads = gradForMinAndMax(dy, y, x, origAxes);
    return { x: () => grads["x"]() };
  }
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Gradient config for the element-wise Minimum kernel.
 *
 * The upstream gradient goes to `a` where a <= b and to `b` where a > b, so
 * ties are credited to `a` only.
 *
 * Minimum broadcasts its operands. Each partial derivative is therefore
 * summed over the broadcast axes and reshaped back to its operand's shape,
 * the same scheme the Mod/Multiply/Pow gradients use. Without this step, a
 * broadcast operand (e.g. a scalar `b`) would get a gradient shaped like the
 * output. For equal-shaped operands, nothing is reduced and the result is
 * unchanged.
 */
var minimumGradConfig = {
  kernelName: Minimum,
  inputsToSave: ["a", "b"],
  gradFunc: function(dy, saved) {
    const [a, b] = saved;
    const outShape = assertAndGetBroadcastShape(a.shape, b.shape);
    // Sums `grad` over the axes that were broadcast to reach outShape.
    const reduceTo = function(grad, shape) {
      const reduceAxes = getReductionAxes(shape, outShape);
      if (reduceAxes.length > 0) {
        return reshape(sum$1(grad, reduceAxes), shape);
      }
      return grad;
    };
    const derA = function() {
      return reduceTo(mul(dy, cast(lessEqual(a, b), "float32")), a.shape);
    };
    const derB = function() {
      return reduceTo(mul(dy, cast(greater(a, b), "float32")), b.shape);
    };
    return {a: derA, b: derB};
  }
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Gradient config for the MirrorPad kernel.
 *
 * Cuts the region of `dy` that lines up with the unpadded input, starting at
 * each axis's leading padding amount and spanning `x.shape`.
 *
 * NOTE(review): gradient that lands on the mirrored border is dropped, not
 * folded back onto the input elements it was copied from. Confirm that this
 * approximation is intended.
 */
var mirrorPadGradConfig = {
  kernelName: MirrorPad,
  inputsToSave: ["x"],
  gradFunc(dy, saved, attrs) {
    const [x] = saved;
    const begin = attrs.paddings.map((pair) => pair[0]);
    return { x: () => slice(dy, begin, x.shape) };
  }
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Gradient config for the element-wise Mod kernel.
 *
 * d/da = dy and d/db = -dy * floor(a / b). Each result is summed over the
 * broadcast axes and reshaped back to its operand's shape.
 */
var modGradConfig = {
  kernelName: Mod,
  inputsToSave: ["a", "b"],
  gradFunc(dy, saved) {
    const [a, b] = saved;
    const outShape = assertAndGetBroadcastShape(a.shape, b.shape);
    return {
      a: () => {
        const axes = getReductionAxes(a.shape, outShape);
        return axes.length === 0 ? dy : reshape(sum$1(dy, axes), a.shape);
      },
      b: () => {
        const grad = mul(dy, neg(floor(div(a, b))));
        const axes = getReductionAxes(b.shape, outShape);
        return axes.length === 0 ? grad : reshape(sum$1(grad, axes), b.shape);
      }
    };
  }
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Gradient config for the element-wise Multiply kernel.
 *
 * Each operand's gradient is dy times the other operand (cast to float32).
 * It is then summed over the broadcast axes and reshaped to that operand's
 * shape.
 */
var multiplyGradConfig = {
  kernelName: Multiply,
  inputsToSave: ["a", "b"],
  gradFunc(dy, saved) {
    const [a, b] = saved;
    const outShape = assertAndGetBroadcastShape(a.shape, b.shape);
    // Builds the lazy gradient for `input`, whose partner factor is `other`.
    const gradFor = (input, other) => () => {
      const scaled = mul(dy, cast(other, "float32"));
      const axes = getReductionAxes(input.shape, outShape);
      return axes.length > 0 ? reshape(sum$1(scaled, axes), input.shape) : scaled;
    };
    return { a: gradFor(a, b), b: gradFor(b, a) };
  }
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Gradient config for the Negate kernel: d(-x)/dx = -1, so dx = -dy.
 */
var negateGradConfig = {
  kernelName: Negate,
  gradFunc: (dy) => ({ x: () => neg(dy) })
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Gradient config for the OneHot kernel.
 *
 * No gradient flows back to the indices, so a float32 zero tensor shaped
 * like `indices` is returned.
 */
var oneHotGradConfig = {
  kernelName: OneHot,
  inputsToSave: ["indices"],
  gradFunc(dy, saved) {
    const [indices] = saved;
    return { indices: () => zeros(indices.shape, "float32") };
  }
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Gradient config for the OnesLike kernel. The output does not depend on the
 * input's values, so the gradient is all zeros.
 */
var onesLikeGradConfig = {
  kernelName: OnesLike,
  gradFunc: (dy) => ({ x: () => zerosLike(dy) })
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Gradient config for the PadV2 kernel.
 *
 * Padding only adds constant border values, so the input's gradient is the
 * slice of `dy` that starts at each axis's leading padding and spans
 * `x.shape`.
 */
var padV2GradConfig = {
  kernelName: PadV2,
  inputsToSave: ["x"],
  gradFunc(dy, saved, attrs) {
    const [x] = saved;
    const begin = attrs.paddings.map((pair) => pair[0]);
    return { x: () => slice(dy, begin, x.shape) };
  }
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Gradient config for the Pow kernel (y = base ** exponent).
 *
 * d/dbase = dy * exponent * base ** (exponent - 1).
 * d/dexponent = dy * y * log(base), with log(base) replaced by 0 where
 * base <= 0.
 * Both results are summed over the broadcast axes and always reshaped to
 * their operand's shape.
 */
var powGradConfig = {
  kernelName: Pow,
  inputsToSave: ["a", "b"],
  outputsToSave: [true],
  gradFunc(dy, saved) {
    const [base, exponent, y] = saved;
    const outShape = assertAndGetBroadcastShape(base.shape, exponent.shape);
    // Sums `grad` over broadcast axes (if any), then reshapes to `shape`.
    const reduceTo = (grad, shape) => {
      const axes = getReductionAxes(shape, outShape);
      return reshape(axes.length > 0 ? sum$1(grad, axes) : grad, shape);
    };
    const derBase = () => {
      const expFloat = cast(exponent, "float32");
      const grad = mul(dy, mul(expFloat, pow(base, sub(expFloat, scalar(1)))));
      return reduceTo(grad, base.shape);
    };
    const derExp = () => {
      const positive = greater(base, 0);
      const logBase = where(positive, log(base), zerosLike(base));
      return reduceTo(mul(dy, mul(y, logBase)), exponent.shape);
    };
    return { a: derBase, b: derExp };
  }
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Gradient config for the Prelu kernel.
 *
 * Where x > 0, dx = dy and dalpha = 0. Elsewhere, dx = dy * alpha and
 * dalpha = dy * x. dalpha is summed over the axes that alpha was broadcast
 * along and reshaped to alpha's shape.
 */
var preluGradConfig = {
  kernelName: Prelu,
  inputsToSave: ["x", "alpha"],
  gradFunc(dy, saved) {
    const [x, alpha] = saved;
    const positive = greater(x, 0);
    const dx = () => where(positive, dy, mul(dy, alpha));
    const dAlpha = () => {
      const masked = where(positive, zerosLike(dy), mul(dy, x));
      const axes = getReductionAxes(alpha.shape, dy.shape);
      const reduced = axes.length > 0 ? sum$1(masked, axes) : masked;
      return reshape(reduced, alpha.shape);
    };
    return { x: dx, alpha: dAlpha };
  }
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Gradient config for the Reciprocal kernel: d(1/x)/dx = -1 / x^2.
 */
var reciprocalGradConfig = {
  kernelName: Reciprocal,
  inputsToSave: ["x"],
  gradFunc(dy, saved) {
    const [x] = saved;
    return {
      x: () => {
        const negSquared = neg(square(x));
        return div(dy, negSquared);
      }
    };
  }
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Gradient config for the Relu6 kernel.
 *
 * The mask is built eagerly from lessEqual(x, 6) times step(x). Presumably
 * step(x) gates out x <= 0; confirm against the step op. dx is dy times that
 * mask cast to float32.
 */
var relu6GradConfig = {
  kernelName: Relu6,
  inputsToSave: ["x"],
  gradFunc(dy, saved) {
    const [x] = saved;
    const atMostSix = lessEqual(x, 6);
    const stepX = step(x);
    const mask = mul(atMostSix, stepX);
    return { x: () => mul(dy, cast(mask, "float32")) };
  }
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Gradient config for the Relu kernel: dx = dy * step(x), with step(x)
 * cast to float32.
 */
var reluGradConfig = {
  kernelName: Relu,
  inputsToSave: ["x"],
  gradFunc(dy, saved) {
    const [x] = saved;
    return {
      x: () => {
        const gate = cast(step(x), "float32");
        return mul(dy, gate);
      }
    };
  }
};
/**
* @license
* Copyright 2020 Google Inc. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Gradient config for the Reshape kernel: the gradient is `dy` reshaped
 * back to the original input's shape.
 */
var reshapeGradConfig = {
  kernelName: Reshape,
  inputsToSave: ["x"],
  gradFunc(dy, saved) {
    const [x] = saved;
    return { x: () => reshape(dy, x.shape) };
  }
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Gradient config for the ResizeBilinear kernel.
 *
 * Runs the ResizeBilinearGrad kernel through ENGINE.runKernelFunc. The
 * forward function calls backend.resizeBilinearBackprop with the attrs'
 * `alignCorners`.
 *
 * NOTE(review): only `images` is recorded as a kernel input; `dy` is
 * captured by the closure. Confirm that no registered ResizeBilinearGrad
 * kernel expects `dy` in its inputs map.
 */
var resizeBilinearGradConfig = {
  kernelName: ResizeBilinear,
  inputsToSave: ["images"],
  gradFunc(dy, saved, attrs) {
    const [images] = saved;
    const backPropKernelFunc = (backend2) =>
      backend2.resizeBilinearBackprop(dy, images, attrs.alignCorners);
    const inputs = { images };
    return {
      images: () => ENGINE.runKernelFunc(backPropKernelFunc, inputs, null, ResizeBilinearGrad, attrs)
    };
  }
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Gradient config for the ResizeNearestNeighbor kernel.
 *
 * Runs the ResizeNearestNeighborGrad kernel through ENGINE.runKernelFunc.
 * The forward function calls backend.resizeNearestNeighborBackprop with the
 * attrs' `alignCorners`.
 *
 * NOTE(review): only `images` is recorded as a kernel input; `dy` is
 * captured by the closure. Confirm that no registered kernel expects `dy`
 * in its inputs map.
 */
var resizeNearestNeighborGradConfig = {
  kernelName: ResizeNearestNeighbor,
  inputsToSave: ["images"],
  gradFunc(dy, saved, attrs) {
    const [images] = saved;
    const backPropKernelFunc = (backend2) =>
      backend2.resizeNearestNeighborBackprop(dy, images, attrs.alignCorners);
    const inputs = { images };
    return {
      images: () => ENGINE.runKernelFunc(backPropKernelFunc, inputs, null, ResizeNearestNeighborGrad, attrs)
    };
  }
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Gradient config for the Reverse kernel: the gradient is `dy` reversed
 * along the same axes as the forward op.
 */
var reverseGradConfig = {
  kernelName: Reverse,
  gradFunc(dy, saved, attrs) {
    const axes = parseAxisParam(attrs.dims, dy.shape);
    return { x: () => reverse(dy, axes) };
  }
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Gradient config for the Round kernel: the gradient is treated as zero
 * everywhere.
 */
var roundGradConfig = {
  kernelName: Round,
  gradFunc: (dy) => ({ x: () => zerosLike(dy) })
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Gradient config for the Rsqrt kernel: d(x^-1/2)/dx = -1 / (2 * x^1.5).
 */
var rsqrtGradConfig = {
  kernelName: Rsqrt,
  inputsToSave: ["x"],
  gradFunc(dy, saved) {
    const [x] = saved;
    return {
      x: () => {
        const denom = mul(pow(x, 1.5), 2);
        return neg(div(dy, denom));
      }
    };
  }
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Gradient config for the SelectV2 (where) kernel.
 *
 * `t` gets dy where the condition holds; `e` gets dy where it does not.
 * The boolean condition gets a float32 zero gradient, as a placeholder for
 * "no gradient".
 */
var selectV2PoolGradConfig = {
  kernelName: SelectV2,
  inputsToSave: ["condition"],
  gradFunc(dy, saved) {
    const [condition] = saved;
    // Keeps dy where `cond` is true and zeroes it elsewhere.
    const maskedDy = (cond) => mul(dy, cast(cond, dy.dtype));
    return {
      condition: () => cast(zerosLike(condition), "float32"),
      t: () => maskedDy(condition),
      e: () => maskedDy(logicalNot(condition))
    };
  }
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Gradient config for the Selu kernel.
 *
 * Where x > 0, dx = dy * SELU_SCALE. Elsewhere, dx = dy * SELU_SCALEALPHA *
 * exp(x), with x cast to float32.
 */
var seluGradConfig = {
  kernelName: Selu,
  inputsToSave: ["x"],
  gradFunc(dy, saved) {
    const [x] = saved;
    const computeDx = () => {
      const positive = greater(x, scalar(0));
      const scaleAlpha = scalar(SELU_SCALEALPHA);
      const scale = scalar(SELU_SCALE);
      const positiveBranch = mul(dy, scale);
      const scaledDy = mul(dy, scaleAlpha);
      const expX = exp(cast(x, "float32"));
      const nonPositiveBranch = mul(scaledDy, expX);
      return where(positive, positiveBranch, nonPositiveBranch);
    };
    return { x: computeDx };
  }
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Gradient config for the Sigmoid kernel. It uses the saved output y:
 * dx = dy * y * (1 - y).
 */
var sigmoidGradConfig = {
  kernelName: Sigmoid,
  outputsToSave: [true],
  gradFunc(dy, saved) {
    const [y] = saved;
    return {
      x: () => {
        const oneMinusY = sub(scalar(1), y);
        return mul(dy, mul(y, oneMinusY));
      }
    };
  }
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Gradient config for the Sign kernel: the gradient is treated as zero
 * everywhere.
 */
var signGradConfig = {
  kernelName: Sign,
  gradFunc: (dy) => ({ x: () => zerosLike(dy) })
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Gradient config for the Sin kernel: dx = cos(x) * dy, with x cast to
 * float32.
 */
var sinGradConfig = {
  kernelName: Sin,
  inputsToSave: ["x"],
  gradFunc(dy, saved) {
    const [x] = saved;
    return {
      x: () => {
        const cosX = cos(cast(x, "float32"));
        return mul(cosX, dy);
      }
    };
  }
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Gradient config for the Sinh kernel: dx = cosh(x) * dy, with x cast to
 * float32.
 */
var sinhGradConfig = {
  kernelName: Sinh,
  inputsToSave: ["x"],
  gradFunc(dy, saved) {
    const [x] = saved;
    return {
      x: () => {
        const coshX = cosh(cast(x, "float32"));
        return mul(coshX, dy);
      }
    };
  }
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Gradient config for the Slice kernel.
 *
 * Zero-pads `dy` back to the input's shape. Along each axis, the leading
 * pad is the slice start and the trailing pad is whatever remains after
 * start + size.
 */
var sliceGradConfig = {
  kernelName: Slice,
  inputsToSave: ["x"],
  gradFunc(dy, saved, attrs) {
    const [x] = saved;
    const { begin, size } = attrs;
    const inputShape = x.shape;
    const [startIdx, sliceSize] = parseSliceParams(x, begin, size);
    const paddings = Array.from({ length: dy.rank }, (_, axis) =>
      [startIdx[axis], inputShape[axis] - startIdx[axis] - sliceSize[axis]]);
    return { x: () => pad(dy, paddings) };
  }
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Gradient config for the Softmax kernel. It uses the saved output y:
 * dlogits = dy*y - sum(dy*y, dim, keepDims=true) * y.
 */
var softmaxGradConfig = {
  kernelName: Softmax,
  outputsToSave: [true],
  gradFunc(dy, saved, attrs) {
    const [y] = saved;
    const { dim } = attrs;
    const dyTimesY = mul(dy, y);
    return {
      logits: () => {
        const rowSums = sum$1(dyTimesY, [dim], true);
        return sub(dyTimesY, mul(rowSums, y));
      }
    };
  }
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Gradient config for the Softplus kernel: dx = dy * sigmoid(x).
 */
var softplusGradConfig = {
  kernelName: Softplus,
  inputsToSave: ["x"],
  gradFunc(dy, saved) {
    const [x] = saved;
    return { x: () => mul(dy, sigmoid(x)) };
  }
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Gradient config for the SpaceToBatchND kernel. It applies batchToSpaceND
 * to `dy` with the forward op's block shape and paddings.
 */
var spaceToBatchNDGradConfig = {
  kernelName: SpaceToBatchND,
  gradFunc(dy, saved, attrs) {
    const { blockShape, paddings } = attrs;
    return { x: () => batchToSpaceND(dy, blockShape, paddings) };
  }
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Gradient of SplitV: concatenates the incoming gradient(s) `dy` along the
 * split axis. NOTE(review): `dy` is presumably the list of per-piece
 * gradients; confirm against the engine's multi-output handling.
 */
var splitVGradConfig = {
  kernelName: SplitV,
  gradFunc: function(dy, saved, attrs) {
    var splitAxis = attrs.axis;
    return {
      x: function() {
        return concat(dy, splitAxis);
      }
    };
  }
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Gradient of Sqrt: dy / (2 * sqrt(x)), with x cast to float32 first.
 */
var sqrtGradConfig = {
  kernelName: Sqrt,
  inputsToSave: ["x"],
  gradFunc: function(dy, saved) {
    var input = saved[0];
    var dx = function() {
      var denominator = mul(sqrt(cast(input, "float32")), 2);
      return div(dy, denominator);
    };
    return {x: dx};
  }
};
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Gradient of Square: dy * (2 * x), with x cast to float32 first.
 */
var squareGradConfig = {
  kernelName: Square,
  inputsToSave: ["x"],
  gradFunc: function(dy, saved) {
    var input = saved[0];
    var dx = function() {
      var twoX = mul(cast(input, "float32"), 2);
      return mul(dy, twoX);
    };
    return {x: dx};
  }
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Gradient of SquaredDifference (a - b)^2:
 *   d/da = dy * 2 * (a - b),  d/db = dy * 2 * (b - a).
 * The constant 2 is created once, when gradFunc runs, and shared by both
 * closures.
 */
var squaredDifferenceGradConfig = {
  kernelName: SquaredDifference,
  inputsToSave: ["a", "b"],
  gradFunc: function(dy, saved) {
    var lhs = saved[0];
    var rhs = saved[1];
    var two = scalar(2);
    // dy * 2 * (p - q); the operand order is the only difference between a and b.
    var scaledDiff = function(p, q) {
      return mul(dy, mul(two, sub(p, q)));
    };
    return {
      a: function() {
        return scaledDiff(lhs, rhs);
      },
      b: function() {
        return scaledDiff(rhs, lhs);
      }
    };
  }
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Gradient of Step: the gradient w.r.t. x is zerosLike(dy).
 */
var stepGradConfig = {
  kernelName: Step,
  gradFunc: function(dy) {
    var dx = function() {
      return zerosLike(dy);
    };
    return {x: dx};
  }
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Gradient of Sub (a - b) with broadcasting.
 *
 * dy is first summed over the axes that broadcasting introduced for each
 * input, as reported by getReductionAxes. It is then reshaped back to that
 * input's shape. The gradient for b is additionally negated.
 */
var subGradConfig = {
  kernelName: Sub,
  inputsToSave: ["a", "b"],
  gradFunc: function(dy, saved) {
    var a = saved[0];
    var b = saved[1];
    var outShape = assertAndGetBroadcastShape(a.shape, b.shape);
    // Sum dy over the axes that were broadcast for an input of `inShape`.
    var reduceFor = function(inShape) {
      var axes = getReductionAxes(inShape, outShape);
      return axes.length > 0 ? sum$1(dy, axes) : dy;
    };
    return {
      a: function() {
        return reshape(reduceFor(a.shape), a.shape);
      },
      b: function() {
        return reshape(neg(reduceFor(b.shape)), b.shape);
      }
    };
  }
};
/**
* @license
* Copyright 2020 Google Inc. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Gradient of Sum.
 *
 * dy is reshaped to x's shape with every reduced axis set to 1. It is then
 * multiplied by a float32 ones tensor of x's shape, which broadcasts it back
 * over the reduced axes. The result is computed eagerly, when gradFunc runs.
 */
var sumGradConfig = {
  kernelName: Sum,
  inputsToSave: ["x"],
  gradFunc: function(dy, saved, attrs) {
    var x = saved[0];
    // Copy, so x.shape itself is never modified.
    var keptDimsShape = x.shape.slice();
    var reducedAxes = parseAxisParam(attrs.axis, x.shape);
    for (var i = 0; i < reducedAxes.length; i++) {
      keptDimsShape[reducedAxes[i]] = 1;
    }
    var derX = mul(reshape(dy, keptDimsShape), ones$1(x.shape, "float32"));
    return {
      x: function() {
        return derX;
      }
    };
  }
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Gradient of Tan: dy / cos(x)^2.
 */
var tanGradConfig = {
  kernelName: Tan,
  inputsToSave: ["x"],
  gradFunc: function(dy, saved) {
    var input = saved[0];
    return {
      x: function() {
        var cosSquared = square(cos(input));
        return div(dy, cosSquared);
      }
    };
  }
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Gradient of Tanh, written in terms of the saved forward output y:
 * (1 - y^2) * dy.
 */
var tanhGradConfig = {
  kernelName: Tanh,
  outputsToSave: [true],
  gradFunc: function(dy, saved) {
    var y = saved[0];
    return {
      x: function() {
        var oneMinusYSquared = sub(scalar(1), square(y));
        return mul(oneMinusYSquared, dy);
      }
    };
  }
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Gradient of Tile.
 *
 * dy is made of x-shaped blocks: block k along axis d starts at offset
 * k * x.shape[d], for k < reps[d]. The gradient w.r.t. x is the element-wise
 * sum of all these blocks.
 *
 * The accumulation works for any rank >= 1; ranks other than 1-4 used to
 * throw. Blocks are visited in the same order as the old per-rank nested
 * loops, with the last axis varying fastest, so the summation order is
 * unchanged. Rank-0 inputs still throw.
 */
var tileGradConfig = {
  kernelName: Tile,
  inputsToSave: ["x"],
  gradFunc: function(dy, saved, attrs) {
    var x = saved[0];
    var reps = attrs.reps;
    var derX = function() {
      var rank = x.rank;
      if (rank < 1) {
        throw new Error("Gradient for tile operation is not implemented for rank-" + (rank + " tensors yet."));
      }
      var blockShape = x.shape;
      // Total number of x-shaped blocks in dy (0 if any repetition count is 0).
      var numBlocks = 1;
      for (var axis = 0; axis < rank; ++axis) {
        numBlocks *= reps[axis];
      }
      // Multi-dimensional block counter, advanced like an odometer below.
      var blockIndex = [];
      for (var axis2 = 0; axis2 < rank; ++axis2) {
        blockIndex.push(0);
      }
      var xGrad = zerosLike(x);
      for (var n = 0; n < numBlocks; ++n) {
        var begin = [];
        for (var d = 0; d < rank; ++d) {
          begin.push(blockIndex[d] * blockShape[d]);
        }
        // Pass a fresh size array per call, as the original per-rank code did.
        xGrad = add$1(xGrad, slice(dy, begin, blockShape.slice()));
        // Step to the next block: the last axis varies fastest, carrying into earlier axes.
        for (var c = rank - 1; c >= 0; --c) {
          blockIndex[c] += 1;
          if (blockIndex[c] < reps[c]) {
            break;
          }
          blockIndex[c] = 0;
        }
      }
      return xGrad;
    };
    return {x: derX};
  }
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Gradient of Transpose: transposes dy by the permutation that
 * getUndoAxesPermutation derives from the forward perm. The permutation is
 * computed when gradFunc runs.
 */
var transposeGradConfig = {
  kernelName: Transpose,
  gradFunc: function(dy, saved, attrs) {
    var inversePerm = getUndoAxesPermutation(attrs.perm);
    var dx = function() {
      return transpose(dy, inversePerm);
    };
    return {x: dx};
  }
};
/**
* @license
* Copyright 2020 Google Inc. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Gradient of Unpack: stacks the incoming gradient(s) `dy` along the unpack
 * axis. The result is returned under the `value` input key.
 */
var unpackGradConfig = {
  kernelName: Unpack,
  gradFunc: function(dy, saved, attrs) {
    var stackAxis = attrs.axis;
    var dValue = function() {
      return stack(dy, stackAxis);
    };
    return {value: dValue};
  }
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Gradient of UnsortedSegmentSum: gathers dy at each saved segment id via
 * gatherDropNegatives, which yields zeros where an id is negative.
 */
var unsortedSegmentSumGradConfig = {
  kernelName: UnsortedSegmentSum,
  inputsToSave: ["segmentIds"],
  gradFunc: function(dy, saved) {
    var ids = saved[0];
    return {
      x: function() {
        return gatherDropNegatives(dy, ids);
      }
    };
  }
};
/**
 * Gathers `x` at `indices` and zeroes the rows whose index is negative.
 *
 * Negative indices are clamped to 0 (via maximum) so gather can run, and
 * their results are then replaced by zeros via `where`.
 */
function gatherDropNegatives(x, indices) {
  var safeIndices = maximum(indices, zerosLike(indices));
  var gathered = gather(x, safeIndices);
  var keepMask = greaterEqual(indices, scalar(0, "int32"));
  // Computed once, before the loop, because expandDims raises keepMask.rank.
  var extraDims = gathered.rank - keepMask.rank;
  for (var axis = 1; axis <= extraDims; ++axis) {
    keepMask = expandDims(keepMask, axis);
  }
  // AND with ones of gathered.shape to broadcast the mask to the full shape.
  keepMask = logicalAnd(keepMask, ones$1(gathered.shape, "bool"));
  return where(keepMask, gathered, zerosLike(gathered));
}
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Gradient of ZerosLike: the gradient w.r.t. x is zerosLike(dy).
 */
var zerosLikeGradConfig = {
  kernelName: ZerosLike,
  gradFunc: function(dy) {
    var dx = function() {
      return zerosLike(dy);
    };
    return {x: dx};
  }
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * All kernel gradient configs. Each entry is passed to registerGradient()
 * by the loop that follows this list.
 *
 * Each config is listed exactly once. maxGradConfig, padV2GradConfig,
 * spaceToBatchNDGradConfig and splitVGradConfig used to appear twice, which
 * passed the same config to registerGradient() a second time for no effect.
 */
var gradConfigs = [
  absGradConfig,
  acosGradConfig,
  acoshGradConfig,
  addGradConfig,
  addNGradConfig,
  argMaxGradConfig,
  argMinGradConfig,
  asinGradConfig,
  asinhGradConfig,
  atan2GradConfig,
  atanGradConfig,
  atanhGradConfig,
  avgPool3DGradConfig,
  avgPoolGradConfig,
  batchMatMulGradConfig,
  batchToSpaceNDGradConfig,
  broadcastToGradConfig,
  castGradConfig,
  ceilGradConfig,
  clipByValueGradConfig,
  concatGradConfig,
  conv2DBackpropInputGradConfig,
  conv2DGradConfig,
  conv3DGradConfig,
  cosGradConfig,
  coshGradConfig,
  cumsumGradConfig,
  depthwiseConv2dNativeGradConfig,
  dilation2dGradConfig,
  divGradConfig,
  eluGradConfig,
  erfGradConfig,
  expGradConfig,
  expm1GradConfig,
  floorDivGradConfig,
  floorGradConfig,
  fusedBatchNormGradConfig,
  gatherGradConfig,
  greaterEqualGradConfig,
  identityGradConfig,
  isFiniteGradConfig,
  isInfGradConfig,
  isNanGradConfig,
  log1pGradConfig,
  logGradConfig,
  logSoftmaxGradConfig,
  lrnGradConfig,
  maxGradConfig,
  maximumGradConfig,
  maxPool3DGradConfig,
  maxPoolGradConfig,
  minGradConfig,
  minimumGradConfig,
  mirrorPadGradConfig,
  modGradConfig,
  multiplyGradConfig,
  negateGradConfig,
  oneHotGradConfig,
  onesLikeGradConfig,
  padV2GradConfig,
  powGradConfig,
  preluGradConfig,
  reciprocalGradConfig,
  relu6GradConfig,
  reluGradConfig,
  reshapeGradConfig,
  resizeBilinearGradConfig,
  resizeNearestNeighborGradConfig,
  reverseGradConfig,
  roundGradConfig,
  rsqrtGradConfig,
  selectV2PoolGradConfig,
  seluGradConfig,
  sigmoidGradConfig,
  signGradConfig,
  sinGradConfig,
  sinhGradConfig,
  sliceGradConfig,
  softmaxGradConfig,
  softplusGradConfig,
  spaceToBatchNDGradConfig,
  splitVGradConfig,
  sqrtGradConfig,
  squaredDifferenceGradConfig,
  squareGradConfig,
  stepGradConfig,
  subGradConfig,
  sumGradConfig,
  tanGradConfig,
  tanhGradConfig,
  tileGradConfig,
  transposeGradConfig,
  unpackGradConfig,
  unsortedSegmentSumGradConfig,
  zerosLikeGradConfig
];
// Register every gradient config from gradConfigs with the gradient registry, in list order.
gradConfigs.forEach(function(gradientConfig) {
  registerGradient(gradientConfig);
});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Method form of the `abs` op: calls throwIfDisposed(), then returns abs(this).
 */
Tensor.prototype.abs = function() {
  var self = this;
  self.throwIfDisposed();
  return abs(self);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Method form of the `acos` op: calls throwIfDisposed(), then returns acos(this).
 */
Tensor.prototype.acos = function() {
  var self = this;
  self.throwIfDisposed();
  return acos(self);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Method form of the `acosh` op: calls throwIfDisposed(), then returns acosh(this).
 */
Tensor.prototype.acosh = function() {
  var self = this;
  self.throwIfDisposed();
  return acosh(self);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Method form of the `addStrict` op: calls throwIfDisposed(), then returns
 * addStrict(this, x).
 */
Tensor.prototype.addStrict = function(x) {
  var self = this;
  self.throwIfDisposed();
  return addStrict(self, x);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Method form of addition: calls throwIfDisposed(), then returns add$1(this, b).
 */
Tensor.prototype.add = function(b) {
  var self = this;
  self.throwIfDisposed();
  return add$1(self, b);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Method form of the `all` reduction: calls throwIfDisposed(), then returns
 * all(this, axis, keepDims).
 */
Tensor.prototype.all = function(axis, keepDims) {
  var self = this;
  self.throwIfDisposed();
  return all(self, axis, keepDims);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Method form of the `any` reduction: calls throwIfDisposed(), then returns
 * any(this, axis, keepDims).
 */
Tensor.prototype.any = function(axis, keepDims) {
  var self = this;
  self.throwIfDisposed();
  return any(self, axis, keepDims);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Method form of the `argMax` op: calls throwIfDisposed(), then returns
 * argMax(this, axis).
 */
Tensor.prototype.argMax = function(axis) {
  var self = this;
  self.throwIfDisposed();
  return argMax(self, axis);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Method form of the `argMin` op: calls throwIfDisposed(), then returns
 * argMin(this, axis).
 */
Tensor.prototype.argMin = function(axis) {
  var self = this;
  self.throwIfDisposed();
  return argMin(self, axis);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Reshapes this tensor to shape [] after calling throwIfDisposed() and
 * asserting that `this.size === 1`.
 */
Tensor.prototype.asScalar = function() {
  var self = this;
  self.throwIfDisposed();
  var hasSingleElement = self.size === 1;
  assert(hasSingleElement, function() {
    return "The array must have only 1 element.";
  });
  return reshape(self, []);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Calls throwIfDisposed(), then returns cast(this, dtype).
 */
Tensor.prototype.asType = function(dtype) {
  var self = this;
  self.throwIfDisposed();
  return cast(self, dtype);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Calls throwIfDisposed(), then reshapes this tensor to the 1-D shape
 * [this.size].
 */
Tensor.prototype.as1D = function() {
  var self = this;
  self.throwIfDisposed();
  var flatShape = [self.size];
  return reshape(self, flatShape);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Calls throwIfDisposed(), then reshapes this tensor to [rows, columns].
 */
Tensor.prototype.as2D = function(rows, columns) {
  var self = this;
  self.throwIfDisposed();
  var targetShape = [rows, columns];
  return reshape(self, targetShape);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Calls throwIfDisposed(), then reshapes this tensor to [rows, columns, depth].
 */
Tensor.prototype.as3D = function(rows, columns, depth) {
  var self = this;
  self.throwIfDisposed();
  var targetShape = [rows, columns, depth];
  return reshape(self, targetShape);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Calls throwIfDisposed(), then reshapes this tensor to
 * [rows, columns, depth, depth2].
 */
Tensor.prototype.as4D = function(rows, columns, depth, depth2) {
  var self = this;
  self.throwIfDisposed();
  var targetShape = [rows, columns, depth, depth2];
  return reshape(self, targetShape);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Calls throwIfDisposed(), then reshapes this tensor to
 * [rows, columns, depth, depth2, depth3].
 */
Tensor.prototype.as5D = function(rows, columns, depth, depth2, depth3) {
  var self = this;
  self.throwIfDisposed();
  var targetShape = [rows, columns, depth, depth2, depth3];
  return reshape(self, targetShape);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Method form of the `asin` op: calls throwIfDisposed(), then returns asin(this).
 */
Tensor.prototype.asin = function() {
  var self = this;
  self.throwIfDisposed();
  return asin(self);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Method form of the `asinh` op: calls throwIfDisposed(), then returns asinh(this).
 */
Tensor.prototype.asinh = function() {
  var self = this;
  self.throwIfDisposed();
  return asinh(self);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Chained-API form of the module-level `atan` op.
 * Calls `throwIfDisposed()` on this tensor, then returns `atan(this)`.
 */
Tensor.prototype.atan = function() {
  const input = this;
  input.throwIfDisposed();
  return atan(input);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Chained-API form of the binary module-level `atan2` op.
 * Calls `throwIfDisposed()` on this tensor, then returns
 * `atan2(this, b)` with this tensor as the first operand.
 */
Tensor.prototype.atan2 = function(b) {
  const lhs = this;
  lhs.throwIfDisposed();
  return atan2(lhs, b);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Chained-API form of the module-level `atanh` op.
 * Calls `throwIfDisposed()` on this tensor, then returns `atanh(this)`.
 */
Tensor.prototype.atanh = function() {
  const input = this;
  input.throwIfDisposed();
  return atanh(input);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Chained-API form of the module-level `avgPool` op.
 * Calls `throwIfDisposed()` on this tensor, then forwards this tensor
 * plus the four arguments, positionally and unchanged, to `avgPool`.
 */
Tensor.prototype.avgPool = function(filterSize, strides, pad2, dimRoundingMode) {
  const input = this;
  input.throwIfDisposed();
  return avgPool(input, filterSize, strides, pad2, dimRoundingMode);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Chained-API form of the module-level `batchToSpaceND` op.
 * Calls `throwIfDisposed()` on this tensor, then forwards it together
 * with `blockShape` and `crops` to `batchToSpaceND`.
 */
Tensor.prototype.batchToSpaceND = function(blockShape, crops) {
  const input = this;
  input.throwIfDisposed();
  return batchToSpaceND(input, blockShape, crops);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Chained-API form of the module-level `batchNorm` op.
 * Calls `throwIfDisposed()` on this tensor, then forwards this tensor
 * plus `mean2`, `variance`, `offset`, `scale` and `varianceEpsilon`
 * (in that order) to `batchNorm`.
 */
Tensor.prototype.batchNorm = function(mean2, variance, offset, scale, varianceEpsilon) {
  const input = this;
  input.throwIfDisposed();
  return batchNorm(input, mean2, variance, offset, scale, varianceEpsilon);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Chained-API form of the module-level `broadcastTo` op.
 * Calls `throwIfDisposed()` on this tensor, then returns
 * `broadcastTo(this, shape)`.
 */
Tensor.prototype.broadcastTo = function(shape) {
  const input = this;
  input.throwIfDisposed();
  return broadcastTo(input, shape);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Chained-API form of the module-level `cast` op.
 * Calls `throwIfDisposed()` on this tensor, then returns
 * `cast(this, dtype)`.
 */
Tensor.prototype.cast = function(dtype) {
  const input = this;
  input.throwIfDisposed();
  return cast(input, dtype);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Chained-API form of the module-level `ceil` op.
 * Calls `throwIfDisposed()` on this tensor, then returns `ceil(this)`.
 */
Tensor.prototype.ceil = function() {
  const input = this;
  input.throwIfDisposed();
  return ceil(input);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Chained-API form of the module-level `clipByValue` op.
 * Calls `throwIfDisposed()` on this tensor, then forwards it with
 * `min2` and `max2` (in that order) to `clipByValue`.
 */
Tensor.prototype.clipByValue = function(min2, max2) {
  const input = this;
  input.throwIfDisposed();
  return clipByValue(input, min2, max2);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Chained-API form of the module-level `concat` op.
 *
 * Calls `throwIfDisposed()` on this tensor, then builds the list
 * `[this, ...]` with `Array.prototype.concat`. If `x` is a single `Tensor`
 * it is wrapped in a one-element array first. Any other value (presumably
 * an array of tensors) goes to `Array.prototype.concat` as-is. The list and
 * `axis` are then passed to `concat`.
 *
 * @param {Tensor|Tensor[]} x - Tensor(s) to place after `this`.
 * @param {number} [axis] - Forwarded unchanged to `concat`.
 */
Tensor.prototype.concat = function(x, axis) {
  this.throwIfDisposed();
  // Avoid reassigning the parameter; normalise into a new local instead.
  const trailing = x instanceof Tensor ? [x] : x;
  const tensors = [this].concat(trailing);
  return concat(tensors, axis);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Chained-API form of the module-level `conv1d` op.
 * Calls `throwIfDisposed()` on this tensor, then forwards this tensor
 * plus all six arguments, positionally and unchanged, to `conv1d`.
 */
Tensor.prototype.conv1d = function(filter, stride, pad2, dataFormat, dilation, dimRoundingMode) {
  const input = this;
  input.throwIfDisposed();
  return conv1d(input, filter, stride, pad2, dataFormat, dilation, dimRoundingMode);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Chained-API form of the module-level `conv2dTranspose` op.
 * Calls `throwIfDisposed()` on this tensor, then forwards this tensor
 * plus all five arguments, positionally and unchanged, to `conv2dTranspose`.
 */
Tensor.prototype.conv2dTranspose = function(filter, outputShape, strides, pad2, dimRoundingMode) {
  const input = this;
  input.throwIfDisposed();
  return conv2dTranspose(input, filter, outputShape, strides, pad2, dimRoundingMode);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Chained-API form of the module-level `conv2d` op.
 * Calls `throwIfDisposed()` on this tensor, then forwards this tensor
 * plus all six arguments, positionally and unchanged, to `conv2d`.
 */
Tensor.prototype.conv2d = function(filter, strides, pad2, dataFormat, dilations, dimRoundingMode) {
  const input = this;
  input.throwIfDisposed();
  return conv2d(input, filter, strides, pad2, dataFormat, dilations, dimRoundingMode);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Chained-API form of the module-level `cos` op.
 * Calls `throwIfDisposed()` on this tensor, then returns `cos(this)`.
 */
Tensor.prototype.cos = function() {
  const input = this;
  input.throwIfDisposed();
  return cos(input);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Chained-API form of the module-level `cosh` op.
 * Calls `throwIfDisposed()` on this tensor, then returns `cosh(this)`.
 */
Tensor.prototype.cosh = function() {
  const input = this;
  input.throwIfDisposed();
  return cosh(input);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Chained-API form of the module-level `cumsum` op.
 * Calls `throwIfDisposed()` on this tensor, then forwards it with
 * `axis`, `exclusive` and `reverse2` (in that order) to `cumsum`.
 */
Tensor.prototype.cumsum = function(axis, exclusive, reverse2) {
  const input = this;
  input.throwIfDisposed();
  return cumsum(input, axis, exclusive, reverse2);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Chained-API form of the module-level `depthToSpace` op.
 * Calls `throwIfDisposed()` on this tensor, then forwards it with
 * `blockSize` and `dataFormat` to `depthToSpace`.
 */
Tensor.prototype.depthToSpace = function(blockSize, dataFormat) {
  const input = this;
  input.throwIfDisposed();
  return depthToSpace(input, blockSize, dataFormat);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Deprecated chained-API alias (capital `D`) for `depthwiseConv2d`.
 *
 * Calls `deprecationWarn` with a fixed message first, before the dispose
 * check. It then calls `throwIfDisposed()` on this tensor and forwards this
 * tensor plus all six arguments, unchanged, to the module-level
 * `depthwiseConv2d` op.
 */
Tensor.prototype.depthwiseConv2D = function(filter, strides, pad2, dataFormat, dilations, dimRoundingMode) {
  const notice = "depthwiseConv2D is deprecated, use depthwiseConv2d instead";
  deprecationWarn(notice);
  const input = this;
  input.throwIfDisposed();
  return depthwiseConv2d(input, filter, strides, pad2, dataFormat, dilations, dimRoundingMode);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Chained-API form of the module-level `depthwiseConv2d` op.
 * Calls `throwIfDisposed()` on this tensor, then forwards this tensor
 * plus all six arguments, positionally and unchanged, to `depthwiseConv2d`.
 */
Tensor.prototype.depthwiseConv2d = function(filter, strides, pad2, dataFormat, dilations, dimRoundingMode) {
  const input = this;
  input.throwIfDisposed();
  return depthwiseConv2d(input, filter, strides, pad2, dataFormat, dilations, dimRoundingMode);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Chained-API form of the module-level `dilation2d` op.
 * Calls `throwIfDisposed()` on this tensor, then forwards this tensor
 * plus all five arguments to `dilation2d`. Note that `dilations` comes
 * before `dataFormat` here, matching the op's parameter order.
 */
Tensor.prototype.dilation2d = function(filter, strides, pad2, dilations, dataFormat) {
  const input = this;
  input.throwIfDisposed();
  return dilation2d(input, filter, strides, pad2, dilations, dataFormat);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Chained-API form of the binary module-level `divNoNan` op.
 * Calls `throwIfDisposed()` on this tensor, then returns
 * `divNoNan(this, b)` with this tensor as the first operand.
 */
Tensor.prototype.divNoNan = function(b) {
  const lhs = this;
  lhs.throwIfDisposed();
  return divNoNan(lhs, b);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Chained-API form of the binary module-level `divStrict` op.
 * Calls `throwIfDisposed()` on this tensor, then returns
 * `divStrict(this, x)` with this tensor as the first operand.
 */
Tensor.prototype.divStrict = function(x) {
  const lhs = this;
  lhs.throwIfDisposed();
  return divStrict(lhs, x);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Chained-API form of the binary module-level `div` op.
 * Calls `throwIfDisposed()` on this tensor, then returns
 * `div(this, b)` with this tensor as the first operand.
 */
Tensor.prototype.div = function(b) {
  const lhs = this;
  lhs.throwIfDisposed();
  return div(lhs, b);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Chained-API form of the binary module-level `dot` op.
 * Calls `throwIfDisposed()` on this tensor, then returns
 * `dot(this, b)` with this tensor as the first operand.
 */
Tensor.prototype.dot = function(b) {
  const lhs = this;
  lhs.throwIfDisposed();
  return dot(lhs, b);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Chained-API form of the module-level `elu` op.
 * Calls `throwIfDisposed()` on this tensor, then returns `elu(this)`.
 */
Tensor.prototype.elu = function() {
  const input = this;
  input.throwIfDisposed();
  return elu(input);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Chained-API form of the binary module-level `equalStrict` op.
 * Calls `throwIfDisposed()` on this tensor, then returns
 * `equalStrict(this, x)` with this tensor as the first operand.
 */
Tensor.prototype.equalStrict = function(x) {
  const lhs = this;
  lhs.throwIfDisposed();
  return equalStrict(lhs, x);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Chained-API form of the binary module-level `equal` op.
 * Calls `throwIfDisposed()` on this tensor, then returns
 * `equal(this, b)` with this tensor as the first operand.
 */
Tensor.prototype.equal = function(b) {
  const lhs = this;
  lhs.throwIfDisposed();
  return equal(lhs, b);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Chained-API form of the module-level `erf` op.
 * Calls `throwIfDisposed()` on this tensor, then returns `erf(this)`.
 */
Tensor.prototype.erf = function() {
  const input = this;
  input.throwIfDisposed();
  return erf(input);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Chained-API form of the module-level `exp` op.
 * Calls `throwIfDisposed()` on this tensor, then returns `exp(this)`.
 */
Tensor.prototype.exp = function() {
  const input = this;
  input.throwIfDisposed();
  return exp(input);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Chained-API form of the module-level `expandDims` op.
 * Calls `throwIfDisposed()` on this tensor, then returns
 * `expandDims(this, axis)`.
 */
Tensor.prototype.expandDims = function(axis) {
  const input = this;
  input.throwIfDisposed();
  return expandDims(input, axis);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Chained-API form of the module-level `expm1` op.
 * Calls `throwIfDisposed()` on this tensor, then returns `expm1(this)`.
 */
Tensor.prototype.expm1 = function() {
  const input = this;
  input.throwIfDisposed();
  return expm1(input);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Chained-API form of the module-level `fft` op.
 * Calls `throwIfDisposed()` on this tensor, then returns `fft(this)`.
 */
Tensor.prototype.fft = function() {
  const input = this;
  input.throwIfDisposed();
  return fft(input);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Chained-API helper that reshapes this tensor to the one-entry shape
 * `[this.size]`.
 * Calls `throwIfDisposed()` first, reads `this.size` only afterwards, then
 * delegates to the module-level `reshape` op.
 */
Tensor.prototype.flatten = function() {
  this.throwIfDisposed();
  const flatShape = [this.size];
  return reshape(this, flatShape);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Chained-API form of the module-level `floor` op.
 * Calls `throwIfDisposed()` on this tensor, then returns `floor(this)`.
 */
Tensor.prototype.floor = function() {
  const input = this;
  input.throwIfDisposed();
  return floor(input);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Chained-API form of the binary module-level `floorDiv` op.
 * Calls `throwIfDisposed()` on this tensor, then returns
 * `floorDiv(this, b)` with this tensor as the first operand.
 */
Tensor.prototype.floorDiv = function(b) {
  const lhs = this;
  lhs.throwIfDisposed();
  return floorDiv(lhs, b);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Chained-API form of the module-level `gather` op.
 * Calls `throwIfDisposed()` on this tensor, then forwards it with
 * `indices` and `axis` to `gather`.
 */
Tensor.prototype.gather = function(indices, axis) {
  const source = this;
  source.throwIfDisposed();
  return gather(source, indices, axis);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Chained-API form of the binary module-level `greaterEqualStrict` op.
 * Calls `throwIfDisposed()` on this tensor, then returns
 * `greaterEqualStrict(this, x)` with this tensor as the first operand.
 */
Tensor.prototype.greaterEqualStrict = function(x) {
  const lhs = this;
  lhs.throwIfDisposed();
  return greaterEqualStrict(lhs, x);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Chained-API form of the binary module-level `greaterEqual` op.
 * Calls `throwIfDisposed()` on this tensor, then returns
 * `greaterEqual(this, b)` with this tensor as the first operand.
 */
Tensor.prototype.greaterEqual = function(b) {
  const lhs = this;
  lhs.throwIfDisposed();
  return greaterEqual(lhs, b);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Chained form of `greaterStrict`: runs `this.throwIfDisposed()` and then
 * returns `greaterStrict(this, x)`.
 * @param {*} x - Right-hand operand, forwarded unchanged.
 */
Tensor.prototype.greaterStrict = function(x) {
  const receiver = this;
  receiver.throwIfDisposed();
  return greaterStrict(receiver, x);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Chained form of `greater`: runs `this.throwIfDisposed()` and then returns
 * `greater(this, b)`.
 * @param {*} b - Right-hand operand, forwarded unchanged.
 */
Tensor.prototype.greater = function(b) {
  const receiver = this;
  receiver.throwIfDisposed();
  return greater(receiver, b);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Chained form of `ifft`: runs `this.throwIfDisposed()` and then returns
 * `ifft(this)`.
 */
Tensor.prototype.ifft = function() {
  const receiver = this;
  receiver.throwIfDisposed();
  return ifft(receiver);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Chained form of `irfft`: runs `this.throwIfDisposed()` and then returns
 * `irfft(this)`.
 */
Tensor.prototype.irfft = function() {
  const receiver = this;
  receiver.throwIfDisposed();
  return irfft(receiver);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Chained form of the `isFinite` op (bundled as `isFinite$1` to avoid the
 * global `isFinite`): runs `this.throwIfDisposed()` and then returns
 * `isFinite$1(this)`.
 */
Tensor.prototype.isFinite = function() {
  const receiver = this;
  receiver.throwIfDisposed();
  return isFinite$1(receiver);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Chained form of `isInf`: runs `this.throwIfDisposed()` and then returns
 * `isInf(this)`.
 */
Tensor.prototype.isInf = function() {
  const receiver = this;
  receiver.throwIfDisposed();
  return isInf(receiver);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Chained form of the `isNaN` op (bundled as `isNaN$1` to avoid the global
 * `isNaN`): runs `this.throwIfDisposed()` and then returns `isNaN$1(this)`.
 */
Tensor.prototype.isNaN = function() {
  const receiver = this;
  receiver.throwIfDisposed();
  return isNaN$1(receiver);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Chained form of `leakyRelu`: runs `this.throwIfDisposed()` and then returns
 * `leakyRelu(this, alpha)`.
 * @param {*} alpha - Forwarded unchanged (defaulting, if any, is the op's job).
 */
Tensor.prototype.leakyRelu = function(alpha) {
  const receiver = this;
  receiver.throwIfDisposed();
  return leakyRelu(receiver, alpha);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Chained form of `lessEqualStrict`: runs `this.throwIfDisposed()` and then
 * returns `lessEqualStrict(this, x)`.
 * @param {*} x - Right-hand operand, forwarded unchanged.
 */
Tensor.prototype.lessEqualStrict = function(x) {
  const receiver = this;
  receiver.throwIfDisposed();
  return lessEqualStrict(receiver, x);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Chained form of `lessEqual`: runs `this.throwIfDisposed()` and then returns
 * `lessEqual(this, b)`.
 * @param {*} b - Right-hand operand, forwarded unchanged.
 */
Tensor.prototype.lessEqual = function(b) {
  const receiver = this;
  receiver.throwIfDisposed();
  return lessEqual(receiver, b);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Chained form of `lessStrict`: runs `this.throwIfDisposed()` and then
 * returns `lessStrict(this, x)`.
 * @param {*} x - Right-hand operand, forwarded unchanged.
 */
Tensor.prototype.lessStrict = function(x) {
  const receiver = this;
  receiver.throwIfDisposed();
  return lessStrict(receiver, x);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Chained form of `less`: runs `this.throwIfDisposed()` and then returns
 * `less(this, b)`.
 * @param {*} b - Right-hand operand, forwarded unchanged.
 */
Tensor.prototype.less = function(b) {
  const receiver = this;
  receiver.throwIfDisposed();
  return less(receiver, b);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Chained form of `localResponseNormalization`: runs `this.throwIfDisposed()`
 * and then returns `localResponseNormalization(this, depthRadius, bias, alpha, beta)`.
 * All four parameters are forwarded positionally and unchanged.
 */
Tensor.prototype.localResponseNormalization = function(depthRadius, bias, alpha, beta) {
  const receiver = this;
  receiver.throwIfDisposed();
  return localResponseNormalization(receiver, depthRadius, bias, alpha, beta);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Chained form of `logSigmoid`: runs `this.throwIfDisposed()` and then
 * returns `logSigmoid(this)`.
 */
Tensor.prototype.logSigmoid = function() {
  const receiver = this;
  receiver.throwIfDisposed();
  return logSigmoid(receiver);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Chained form of `logSoftmax`: runs `this.throwIfDisposed()` and then
 * returns `logSoftmax(this, axis)`.
 * @param {*} axis - Forwarded unchanged (may be `undefined`).
 */
Tensor.prototype.logSoftmax = function(axis) {
  const receiver = this;
  receiver.throwIfDisposed();
  return logSoftmax(receiver, axis);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Chained form of `logSumExp`: runs `this.throwIfDisposed()` and then
 * returns `logSumExp(this, axis, keepDims)`.
 * @param {*} axis - Forwarded unchanged.
 * @param {*} keepDims - Forwarded unchanged.
 */
Tensor.prototype.logSumExp = function(axis, keepDims) {
  const receiver = this;
  receiver.throwIfDisposed();
  return logSumExp(receiver, axis, keepDims);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Chained form of `log`: runs `this.throwIfDisposed()` and then returns
 * `log(this)`.
 */
Tensor.prototype.log = function() {
  const receiver = this;
  receiver.throwIfDisposed();
  return log(receiver);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Chained form of `log1p`: runs `this.throwIfDisposed()` and then returns
 * `log1p(this)`.
 */
Tensor.prototype.log1p = function() {
  const receiver = this;
  receiver.throwIfDisposed();
  return log1p(receiver);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Chained form of `logicalAnd`: runs `this.throwIfDisposed()` and then
 * returns `logicalAnd(this, b)`.
 * @param {*} b - Right-hand operand, forwarded unchanged.
 */
Tensor.prototype.logicalAnd = function(b) {
  const receiver = this;
  receiver.throwIfDisposed();
  return logicalAnd(receiver, b);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Chained form of `logicalNot`: runs `this.throwIfDisposed()` and then
 * returns `logicalNot(this)`.
 */
Tensor.prototype.logicalNot = function() {
  const receiver = this;
  receiver.throwIfDisposed();
  return logicalNot(receiver);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Chained form of `logicalOr`: runs `this.throwIfDisposed()` and then
 * returns `logicalOr(this, b)`.
 * @param {*} b - Right-hand operand, forwarded unchanged.
 */
Tensor.prototype.logicalOr = function(b) {
  const receiver = this;
  receiver.throwIfDisposed();
  return logicalOr(receiver, b);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Chained form of `logicalXor`: runs `this.throwIfDisposed()` and then
 * returns `logicalXor(this, b)`.
 * @param {*} b - Right-hand operand, forwarded unchanged.
 */
Tensor.prototype.logicalXor = function(b) {
  const receiver = this;
  receiver.throwIfDisposed();
  return logicalXor(receiver, b);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Chained form of `matMul`: runs `this.throwIfDisposed()` and then returns
 * `matMul(this, b, transposeA, transposeB)`.
 * @param {*} b - Right-hand operand, forwarded unchanged.
 * @param {*} transposeA - Forwarded unchanged (defaulting is the op's job).
 * @param {*} transposeB - Forwarded unchanged (defaulting is the op's job).
 */
Tensor.prototype.matMul = function(b, transposeA, transposeB) {
  const receiver = this;
  receiver.throwIfDisposed();
  return matMul(receiver, b, transposeA, transposeB);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Chained form of `maxPool`: runs `this.throwIfDisposed()` and then returns
 * `maxPool(this, filterSize, strides, pad2, dimRoundingMode)`.
 * `pad2` is the bundler's rename of the padding parameter; all four
 * parameters are forwarded positionally and unchanged.
 */
Tensor.prototype.maxPool = function(filterSize, strides, pad2, dimRoundingMode) {
  const receiver = this;
  receiver.throwIfDisposed();
  return maxPool(receiver, filterSize, strides, pad2, dimRoundingMode);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Chained form of `max`: runs `this.throwIfDisposed()` and then returns
 * `max(this, axis, keepDims)`.
 * @param {*} axis - Forwarded unchanged.
 * @param {*} keepDims - Forwarded unchanged.
 */
Tensor.prototype.max = function(axis, keepDims) {
  const receiver = this;
  receiver.throwIfDisposed();
  return max(receiver, axis, keepDims);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Chained form of `maximumStrict`: runs `this.throwIfDisposed()` and then
 * returns `maximumStrict(this, x)`.
 * @param {*} x - Right-hand operand, forwarded unchanged.
 */
Tensor.prototype.maximumStrict = function(x) {
  const receiver = this;
  receiver.throwIfDisposed();
  return maximumStrict(receiver, x);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Chained form of `maximum`: runs `this.throwIfDisposed()` and then returns
 * `maximum(this, b)`.
 * @param {*} b - Right-hand operand, forwarded unchanged.
 */
Tensor.prototype.maximum = function(b) {
  const receiver = this;
  receiver.throwIfDisposed();
  return maximum(receiver, b);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Chained form of `mean`: runs `this.throwIfDisposed()` and then returns
 * `mean(this, axis, keepDims)`.
 * @param {*} axis - Forwarded unchanged.
 * @param {*} keepDims - Forwarded unchanged.
 */
Tensor.prototype.mean = function(axis, keepDims) {
  const receiver = this;
  receiver.throwIfDisposed();
  return mean(receiver, axis, keepDims);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Chained form of `min`: runs `this.throwIfDisposed()` and then returns
 * `min(this, axis, keepDims)`.
 * @param {*} axis - Forwarded unchanged.
 * @param {*} keepDims - Forwarded unchanged.
 */
Tensor.prototype.min = function(axis, keepDims) {
  const receiver = this;
  receiver.throwIfDisposed();
  return min(receiver, axis, keepDims);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Chained form of `minimumStrict`: runs `this.throwIfDisposed()` and then
 * returns `minimumStrict(this, x)`.
 * @param {*} x - Right-hand operand, forwarded unchanged.
 */
Tensor.prototype.minimumStrict = function(x) {
  const receiver = this;
  receiver.throwIfDisposed();
  return minimumStrict(receiver, x);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Chained form of `minimum`: runs `this.throwIfDisposed()` and then returns
 * `minimum(this, b)`.
 * @param {*} b - Right-hand operand, forwarded unchanged.
 */
Tensor.prototype.minimum = function(b) {
  const receiver = this;
  receiver.throwIfDisposed();
  return minimum(receiver, b);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Chained form of `mirrorPad`: runs `this.throwIfDisposed()` and then returns
 * `mirrorPad(this, paddings, mode)`.
 * @param {*} paddings - Forwarded unchanged.
 * @param {*} mode - Forwarded unchanged.
 */
Tensor.prototype.mirrorPad = function(paddings, mode) {
  const receiver = this;
  receiver.throwIfDisposed();
  return mirrorPad(receiver, paddings, mode);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Chained form of `modStrict`: runs `this.throwIfDisposed()` and then
 * returns `modStrict(this, x)`.
 * @param {*} x - Right-hand operand, forwarded unchanged.
 */
Tensor.prototype.modStrict = function(x) {
  const receiver = this;
  receiver.throwIfDisposed();
  return modStrict(receiver, x);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Chained form of `mod`: runs `this.throwIfDisposed()` and then returns
 * `mod(this, b)`.
 * @param {*} b - Right-hand operand, forwarded unchanged.
 */
Tensor.prototype.mod = function(b) {
  const receiver = this;
  receiver.throwIfDisposed();
  return mod(receiver, b);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Chained form of `mulStrict`: runs `this.throwIfDisposed()` and then
 * returns `mulStrict(this, x)`.
 * @param {*} x - Right-hand operand, forwarded unchanged.
 */
Tensor.prototype.mulStrict = function(x) {
  const receiver = this;
  receiver.throwIfDisposed();
  return mulStrict(receiver, x);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Chained form of `mul`: runs `this.throwIfDisposed()` and then returns
 * `mul(this, b)`.
 * @param {*} b - Right-hand operand, forwarded unchanged.
 */
Tensor.prototype.mul = function(b) {
  const receiver = this;
  receiver.throwIfDisposed();
  return mul(receiver, b);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Chained form of `neg`: runs `this.throwIfDisposed()` and then returns
 * `neg(this)`.
 */
Tensor.prototype.neg = function() {
  const receiver = this;
  receiver.throwIfDisposed();
  return neg(receiver);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Chained form of `norm`: runs `this.throwIfDisposed()` and then returns
 * `norm(this, ord, axis, keepDims)`.
 * @param {*} ord - Forwarded unchanged.
 * @param {*} axis - Forwarded unchanged.
 * @param {*} keepDims - Forwarded unchanged.
 */
Tensor.prototype.norm = function(ord, axis, keepDims) {
  const receiver = this;
  receiver.throwIfDisposed();
  return norm(receiver, ord, axis, keepDims);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Chained form of `notEqualStrict`: runs `this.throwIfDisposed()` and then
 * returns `notEqualStrict(this, x)`.
 * @param {*} x - Right-hand operand, forwarded unchanged.
 */
Tensor.prototype.notEqualStrict = function(x) {
  const receiver = this;
  receiver.throwIfDisposed();
  return notEqualStrict(receiver, x);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Chained form of `notEqual`: runs `this.throwIfDisposed()` and then returns
 * `notEqual(this, b)`.
 * @param {*} b - Right-hand operand, forwarded unchanged.
 */
Tensor.prototype.notEqual = function(b) {
  const receiver = this;
  receiver.throwIfDisposed();
  return notEqual(receiver, b);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Chained form of `oneHot`: runs `this.throwIfDisposed()` and then returns
 * `oneHot(this, depth, onValue, offValue)`.
 * Only `undefined` triggers the fallbacks below; `null` and other values are
 * forwarded as given. Explicit default parameters are intentionally avoided so
 * the method's `length` stays 3.
 * @param {*} depth - Forwarded unchanged.
 * @param {*} [onValue] - Falls back to 1 when `undefined`.
 * @param {*} [offValue] - Falls back to 0 when `undefined`.
 */
Tensor.prototype.oneHot = function(depth, onValue, offValue) {
  const hotValue = onValue === undefined ? 1 : onValue;
  const coldValue = offValue === undefined ? 0 : offValue;
  const receiver = this;
  receiver.throwIfDisposed();
  return oneHot(receiver, depth, hotValue, coldValue);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Chained form of `onesLike`: calls `throwIfDisposed()` on this tensor,
 * then delegates to `onesLike`.
 * @returns The value of `onesLike(this)`.
 */
Tensor.prototype.onesLike = function() {
  const input = this;
  input.throwIfDisposed();
  return onesLike(input);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Chained form of `pad`: calls `throwIfDisposed()` on this tensor,
 * then delegates to `pad`.
 * @param paddings - Forwarded unchanged to `pad`.
 * @param constantValue - Forwarded unchanged to `pad`.
 * @returns The value of `pad(this, paddings, constantValue)`.
 */
Tensor.prototype.pad = function(paddings, constantValue) {
  const input = this;
  input.throwIfDisposed();
  return pad(input, paddings, constantValue);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Chained form of `pool`: calls `throwIfDisposed()` on this tensor,
 * then forwards all arguments to `pool` in the same order.
 * @returns The value of
 *   `pool(this, windowShape, poolingType, padding, dilationRate, strides)`.
 */
Tensor.prototype.pool = function(windowShape, poolingType, padding, dilationRate, strides) {
  const input = this;
  input.throwIfDisposed();
  return pool(input, windowShape, poolingType, padding, dilationRate, strides);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Chained form of `powStrict`: calls `throwIfDisposed()` on this tensor,
 * then uses it as the base.
 * @param exp2 - Exponent operand forwarded to `powStrict`.
 * @returns The value of `powStrict(this, exp2)`.
 */
Tensor.prototype.powStrict = function(exp2) {
  const input = this;
  input.throwIfDisposed();
  return powStrict(input, exp2);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Chained form of `pow`: calls `throwIfDisposed()` on this tensor,
 * then uses it as the base.
 * @param exp2 - Exponent operand forwarded to `pow`.
 * @returns The value of `pow(this, exp2)`.
 */
Tensor.prototype.pow = function(exp2) {
  const input = this;
  input.throwIfDisposed();
  return pow(input, exp2);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Chained form of `prelu`: calls `throwIfDisposed()` on this tensor,
 * then delegates to `prelu`.
 * @param alpha - Forwarded unchanged to `prelu`.
 * @returns The value of `prelu(this, alpha)`.
 */
Tensor.prototype.prelu = function(alpha) {
  const input = this;
  input.throwIfDisposed();
  return prelu(input, alpha);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Chained form of `prod`: calls `throwIfDisposed()` on this tensor,
 * then delegates to `prod`.
 * @param axis - Forwarded unchanged to `prod`.
 * @param keepDims - Forwarded unchanged to `prod`.
 * @returns The value of `prod(this, axis, keepDims)`.
 */
Tensor.prototype.prod = function(axis, keepDims) {
  const input = this;
  input.throwIfDisposed();
  return prod(input, axis, keepDims);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Chained form of `reciprocal`: calls `throwIfDisposed()` on this tensor,
 * then delegates to `reciprocal`.
 * @returns The value of `reciprocal(this)`.
 */
Tensor.prototype.reciprocal = function() {
  const input = this;
  input.throwIfDisposed();
  return reciprocal(input);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Chained form of `relu`: calls `throwIfDisposed()` on this tensor,
 * then delegates to `relu`.
 * @returns The value of `relu(this)`.
 */
Tensor.prototype.relu = function() {
  const input = this;
  input.throwIfDisposed();
  return relu(input);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Chained form of `relu6`: calls `throwIfDisposed()` on this tensor,
 * then delegates to `relu6`.
 * @returns The value of `relu6(this)`.
 */
Tensor.prototype.relu6 = function() {
  const input = this;
  input.throwIfDisposed();
  return relu6(input);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Reshapes this tensor to the `shape` of another object via `reshape`.
 * The disposal check runs before `x.shape` is read, matching the original order.
 * @param x - Any object exposing a `shape` property (typically another tensor).
 * @returns The value of `reshape(this, x.shape)`.
 */
Tensor.prototype.reshapeAs = function(x) {
  this.throwIfDisposed();
  const targetShape = x.shape;
  return reshape(this, targetShape);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Chained form of `reshape`: calls `throwIfDisposed()` on this tensor,
 * then delegates to `reshape`.
 * @param shape - Forwarded unchanged to `reshape`.
 * @returns The value of `reshape(this, shape)`.
 */
Tensor.prototype.reshape = function(shape) {
  const input = this;
  input.throwIfDisposed();
  return reshape(input, shape);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Chained form of `resizeBilinear`: calls `throwIfDisposed()` on this tensor,
 * then delegates to `resizeBilinear`.
 * @param newShape2D - Forwarded unchanged to `resizeBilinear`.
 * @param alignCorners - Forwarded unchanged to `resizeBilinear`.
 * @returns The value of `resizeBilinear(this, newShape2D, alignCorners)`.
 */
Tensor.prototype.resizeBilinear = function(newShape2D, alignCorners) {
  const input = this;
  input.throwIfDisposed();
  return resizeBilinear(input, newShape2D, alignCorners);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Chained form of `resizeNearestNeighbor`: calls `throwIfDisposed()` on this
 * tensor, then delegates to `resizeNearestNeighbor`.
 * @param newShape2D - Forwarded unchanged to `resizeNearestNeighbor`.
 * @param alignCorners - Forwarded unchanged to `resizeNearestNeighbor`.
 * @returns The value of `resizeNearestNeighbor(this, newShape2D, alignCorners)`.
 */
Tensor.prototype.resizeNearestNeighbor = function(newShape2D, alignCorners) {
  const input = this;
  input.throwIfDisposed();
  return resizeNearestNeighbor(input, newShape2D, alignCorners);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Chained form of `reverse`: calls `throwIfDisposed()` on this tensor,
 * then delegates to `reverse`.
 * @param axis - Forwarded unchanged to `reverse`.
 * @returns The value of `reverse(this, axis)`.
 */
Tensor.prototype.reverse = function(axis) {
  const input = this;
  input.throwIfDisposed();
  return reverse(input, axis);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Chained form of `rfft`: calls `throwIfDisposed()` on this tensor,
 * then delegates to `rfft`.
 * @returns The value of `rfft(this)`.
 */
Tensor.prototype.rfft = function() {
  const input = this;
  input.throwIfDisposed();
  return rfft(input);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Chained form of `round`: calls `throwIfDisposed()` on this tensor,
 * then delegates to `round`.
 * @returns The value of `round(this)`.
 */
Tensor.prototype.round = function() {
  const input = this;
  input.throwIfDisposed();
  return round(input);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Chained form of `rsqrt`: calls `throwIfDisposed()` on this tensor,
 * then delegates to `rsqrt`.
 * @returns The value of `rsqrt(this)`.
 */
Tensor.prototype.rsqrt = function() {
  const input = this;
  input.throwIfDisposed();
  return rsqrt(input);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Chained form of `selu`: calls `throwIfDisposed()` on this tensor,
 * then delegates to `selu`.
 * @returns The value of `selu(this)`.
 */
Tensor.prototype.selu = function() {
  const input = this;
  input.throwIfDisposed();
  return selu(input);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Chained form of `separableConv2d`: calls `throwIfDisposed()` on this tensor,
 * then forwards all arguments to `separableConv2d` in the same order.
 * @returns The value of `separableConv2d(this, depthwiseFilter, pointwiseFilter,
 *   strides, pad2, dilation, dataFormat)`.
 */
Tensor.prototype.separableConv2d = function(depthwiseFilter, pointwiseFilter, strides, pad2, dilation, dataFormat) {
  const input = this;
  input.throwIfDisposed();
  return separableConv2d(input, depthwiseFilter, pointwiseFilter, strides, pad2, dilation, dataFormat);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Chained form of `sigmoid`: calls `throwIfDisposed()` on this tensor,
 * then delegates to `sigmoid`.
 * @returns The value of `sigmoid(this)`.
 */
Tensor.prototype.sigmoid = function() {
  const input = this;
  input.throwIfDisposed();
  return sigmoid(input);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Chained form of `sign`: calls `throwIfDisposed()` on this tensor,
 * then delegates to `sign`.
 * @returns The value of `sign(this)`.
 */
Tensor.prototype.sign = function() {
  const input = this;
  input.throwIfDisposed();
  return sign(input);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Chained form of `sin`: calls `throwIfDisposed()` on this tensor,
 * then delegates to `sin`.
 * @returns The value of `sin(this)`.
 */
Tensor.prototype.sin = function() {
  const input = this;
  input.throwIfDisposed();
  return sin(input);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Chained form of `sinh`: calls `throwIfDisposed()` on this tensor,
 * then delegates to `sinh`.
 * @returns The value of `sinh(this)`.
 */
Tensor.prototype.sinh = function() {
  const input = this;
  input.throwIfDisposed();
  return sinh(input);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Chained form of `slice`: calls `throwIfDisposed()` on this tensor,
 * then delegates to `slice`.
 * @param begin - Forwarded unchanged to `slice`.
 * @param size - Forwarded unchanged to `slice`.
 * @returns The value of `slice(this, begin, size)`.
 */
Tensor.prototype.slice = function(begin, size) {
  const input = this;
  input.throwIfDisposed();
  return slice(input, begin, size);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Chained form of `softmax`: calls `throwIfDisposed()` on this tensor,
 * then delegates to `softmax`.
 * @param dim - Forwarded unchanged to `softmax`.
 * @returns The value of `softmax(this, dim)`.
 */
Tensor.prototype.softmax = function(dim) {
  const input = this;
  input.throwIfDisposed();
  return softmax(input, dim);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Chained form of `softplus`: calls `throwIfDisposed()` on this tensor,
 * then delegates to `softplus`.
 * @returns The value of `softplus(this)`.
 */
Tensor.prototype.softplus = function() {
  const input = this;
  input.throwIfDisposed();
  return softplus(input);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Chained form of `spaceToBatchND`: calls `throwIfDisposed()` on this tensor,
 * then delegates to `spaceToBatchND`.
 * @param blockShape - Forwarded unchanged to `spaceToBatchND`.
 * @param paddings - Forwarded unchanged to `spaceToBatchND`.
 * @returns The value of `spaceToBatchND(this, blockShape, paddings)`.
 */
Tensor.prototype.spaceToBatchND = function(blockShape, paddings) {
  const input = this;
  input.throwIfDisposed();
  return spaceToBatchND(input, blockShape, paddings);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Chained form of `split`: calls `throwIfDisposed()` on this tensor,
 * then delegates to `split`.
 * @param numOrSizeSplits - Forwarded unchanged to `split`.
 * @param axis - Forwarded unchanged to `split`.
 * @returns The value of `split(this, numOrSizeSplits, axis)`.
 */
Tensor.prototype.split = function(numOrSizeSplits, axis) {
  const input = this;
  input.throwIfDisposed();
  return split(input, numOrSizeSplits, axis);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Chained form of `sqrt`: calls `throwIfDisposed()` on this tensor,
 * then delegates to `sqrt`.
 * @returns The value of `sqrt(this)`.
 */
Tensor.prototype.sqrt = function() {
  const input = this;
  input.throwIfDisposed();
  return sqrt(input);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Chained form of `square`: calls `throwIfDisposed()` on this tensor,
 * then delegates to `square`.
 * @returns The value of `square(this)`.
 */
Tensor.prototype.square = function() {
  const input = this;
  input.throwIfDisposed();
  return square(input);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Chained form of `squaredDifference`: calls `throwIfDisposed()` on this tensor,
 * then passes it as the first operand.
 * @param b - Second operand forwarded to `squaredDifference`.
 * @returns The value of `squaredDifference(this, b)`.
 */
Tensor.prototype.squaredDifference = function(b) {
  const input = this;
  input.throwIfDisposed();
  return squaredDifference(input, b);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Chained form of `squaredDifferenceStrict`: calls `throwIfDisposed()` on this
 * tensor, then passes it as the first operand.
 * @param x - Second operand forwarded to `squaredDifferenceStrict`.
 * @returns The value of `squaredDifferenceStrict(this, x)`.
 */
Tensor.prototype.squaredDifferenceStrict = function(x) {
  const input = this;
  input.throwIfDisposed();
  return squaredDifferenceStrict(input, x);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Chained form of `squeeze`: calls `throwIfDisposed()` on this tensor,
 * then delegates to `squeeze`.
 * @param axis - Forwarded unchanged to `squeeze`.
 * @returns The value of `squeeze(this, axis)`.
 */
Tensor.prototype.squeeze = function(axis) {
  const input = this;
  input.throwIfDisposed();
  return squeeze(input, axis);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Stacks this tensor together with `x` via the free `stack` op.
 * When `x` is a single `Tensor`, the pair `[this, x]` is stacked. Otherwise
 * `x` is appended using `Array.prototype.concat`: an array `x` is flattened
 * one level, and any other value is appended as-is.
 * @param x - A Tensor, or (presumably) an array of tensors.
 * @param axis - Forwarded unchanged to `stack`.
 * @returns The value of `stack([this, ...], axis)`.
 */
Tensor.prototype.stack = function(x, axis) {
  this.throwIfDisposed();
  let members;
  if (x instanceof Tensor) {
    members = [this, x];
  } else {
    // concat (not spread) so a non-iterable x is still appended as one element.
    members = [this].concat(x);
  }
  return stack(members, axis);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Method form of `step(this, alpha)`.
 * Calls `this.throwIfDisposed()` before delegating.
 */
Tensor.prototype.step = function(alpha) {
  this.throwIfDisposed();
  const stepped = step(this, alpha);
  return stepped;
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Method form of `stridedSlice(this, ...)`; every argument is forwarded
 * unchanged and in the same order.
 * Calls `this.throwIfDisposed()` before delegating.
 */
Tensor.prototype.stridedSlice = function(begin, end, strides, beginMask, endMask, ellipsisMask, newAxisMask, shrinkAxisMask) {
  this.throwIfDisposed();
  const sliced = stridedSlice(
    this,
    begin,
    end,
    strides,
    beginMask,
    endMask,
    ellipsisMask,
    newAxisMask,
    shrinkAxisMask
  );
  return sliced;
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Method form of `subStrict(this, x)`.
 * Calls `this.throwIfDisposed()` before delegating.
 */
Tensor.prototype.subStrict = function(x) {
  this.throwIfDisposed();
  const difference = subStrict(this, x);
  return difference;
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Method form of `sub(this, b)`.
 * Calls `this.throwIfDisposed()` before delegating.
 */
Tensor.prototype.sub = function(b) {
  this.throwIfDisposed();
  const difference = sub(this, b);
  return difference;
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Method form of the `sum` op (bundled as `sum$1`): `sum$1(this, axis, keepDims)`.
 * Calls `this.throwIfDisposed()` before delegating.
 */
Tensor.prototype.sum = function(axis, keepDims) {
  this.throwIfDisposed();
  const total = sum$1(this, axis, keepDims);
  return total;
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Method form of `tan(this)`.
 * Calls `this.throwIfDisposed()` before delegating.
 */
Tensor.prototype.tan = function() {
  this.throwIfDisposed();
  const result = tan(this);
  return result;
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Method form of the `tanh` op (bundled as `tanh$1`): `tanh$1(this)`.
 * Calls `this.throwIfDisposed()` before delegating.
 */
Tensor.prototype.tanh = function() {
  this.throwIfDisposed();
  const result = tanh$1(this);
  return result;
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Method form of `tile(this, reps)`.
 * Calls `this.throwIfDisposed()` before delegating.
 */
Tensor.prototype.tile = function(reps) {
  this.throwIfDisposed();
  const tiled = tile(this, reps);
  return tiled;
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Returns `cast(this, "bool")`.
 * Calls `this.throwIfDisposed()` before delegating.
 */
Tensor.prototype.toBool = function() {
  this.throwIfDisposed();
  const targetDtype = "bool";
  return cast(this, targetDtype);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Returns `cast(this, "float32")`.
 * Calls `this.throwIfDisposed()` before delegating.
 */
Tensor.prototype.toFloat = function() {
  this.throwIfDisposed();
  const targetDtype = "float32";
  return cast(this, targetDtype);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Returns `cast(this, "int32")`.
 * Calls `this.throwIfDisposed()` before delegating.
 */
Tensor.prototype.toInt = function() {
  this.throwIfDisposed();
  const targetDtype = "int32";
  return cast(this, targetDtype);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Method form of `topk(this, k, sorted)`.
 * Calls `this.throwIfDisposed()` before delegating.
 */
Tensor.prototype.topk = function(k, sorted) {
  this.throwIfDisposed();
  const top = topk(this, k, sorted);
  return top;
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Method form of `transpose(this, perm)`.
 * Calls `this.throwIfDisposed()` before delegating.
 */
Tensor.prototype.transpose = function(perm) {
  this.throwIfDisposed();
  const transposed = transpose(this, perm);
  return transposed;
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Method form of `unique(this, axis)`.
 * Calls `this.throwIfDisposed()` before delegating.
 */
Tensor.prototype.unique = function(axis) {
  this.throwIfDisposed();
  const uniqueResult = unique(this, axis);
  return uniqueResult;
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Method form of `unsortedSegmentSum(this, segmentIds, numSegments)`.
 * Calls `this.throwIfDisposed()` before delegating.
 */
Tensor.prototype.unsortedSegmentSum = function(segmentIds, numSegments) {
  this.throwIfDisposed();
  const segmentSums = unsortedSegmentSum(this, segmentIds, numSegments);
  return segmentSums;
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Method form of `unstack(this, axis)`.
 * Calls `this.throwIfDisposed()` before delegating.
 */
Tensor.prototype.unstack = function(axis) {
  this.throwIfDisposed();
  const pieces = unstack(this, axis);
  return pieces;
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Method form of `where`. Note the argument order: this tensor is passed as
 * the second operand, i.e. `where(condition, this, x)`.
 * Calls `this.throwIfDisposed()` before delegating.
 */
Tensor.prototype.where = function(condition, x) {
  this.throwIfDisposed();
  const selected = where(condition, this, x);
  return selected;
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Method form of `zerosLike(this)`.
 * Calls `this.throwIfDisposed()` before delegating.
 */
Tensor.prototype.zerosLike = function() {
  this.throwIfDisposed();
  const zeroed = zerosLike(this);
  return zeroed;
};
// Public surface of the bundled @tensorflow/tfjs-core module: each binding
// defined earlier in this __commonJS factory is attached to `exports`.
// Several entries export a bundler-renamed local under its public name
// (e.g. add$1 -> add, sum$1 -> sum, tanh$1 -> tanh, ones$1 -> ones,
// isNaN$1 -> isNaN, isFinite$1 -> isFinite, fused_ops -> fused,
// gather_nd_util -> gather_util, scatter_nd_util -> scatter_util,
// version -> version_core).
exports.Abs = Abs;
exports.Acos = Acos;
exports.Acosh = Acosh;
exports.AdadeltaOptimizer = AdadeltaOptimizer;
exports.AdagradOptimizer = AdagradOptimizer;
exports.AdamOptimizer = AdamOptimizer;
exports.AdamaxOptimizer = AdamaxOptimizer;
exports.Add = Add;
exports.AddN = AddN;
exports.All = All;
exports.Any = Any;
exports.ArgMax = ArgMax;
exports.ArgMin = ArgMin;
exports.Asin = Asin;
exports.Asinh = Asinh;
exports.Atan = Atan;
exports.Atan2 = Atan2;
exports.Atanh = Atanh;
exports.AvgPool = AvgPool;
exports.AvgPool3D = AvgPool3D;
exports.AvgPool3DBackprop = AvgPool3DBackprop;
exports.AvgPoolBackprop = AvgPoolBackprop;
exports.BatchMatMul = BatchMatMul;
exports.BatchToSpaceND = BatchToSpaceND;
exports.BroadcastTo = BroadcastTo;
exports.Cast = Cast;
exports.Ceil = Ceil;
exports.ClipByValue = ClipByValue;
exports.Complex = Complex;
exports.Concat = Concat;
exports.Conv2D = Conv2D;
exports.Conv2DBackpropFilter = Conv2DBackpropFilter;
exports.Conv2DBackpropInput = Conv2DBackpropInput;
exports.Conv3D = Conv3D;
exports.Conv3DBackpropFilterV2 = Conv3DBackpropFilterV2;
exports.Conv3DBackpropInputV2 = Conv3DBackpropInputV2;
exports.Cos = Cos;
exports.Cosh = Cosh;
exports.CropAndResize = CropAndResize;
exports.Cumsum = Cumsum;
exports.DataStorage = DataStorage;
exports.DepthToSpace = DepthToSpace;
exports.DepthwiseConv2dNative = DepthwiseConv2dNative;
exports.DepthwiseConv2dNativeBackpropFilter = DepthwiseConv2dNativeBackpropFilter;
exports.DepthwiseConv2dNativeBackpropInput = DepthwiseConv2dNativeBackpropInput;
exports.Diag = Diag;
exports.Dilation2D = Dilation2D;
exports.Dilation2DBackpropFilter = Dilation2DBackpropFilter;
exports.Dilation2DBackpropInput = Dilation2DBackpropInput;
exports.Div = Div;
exports.Elu = Elu;
exports.EluGrad = EluGrad;
exports.Environment = Environment;
exports.Equal = Equal;
exports.Erf = Erf;
exports.Exp = Exp;
exports.Expm1 = Expm1;
exports.FFT = FFT;
exports.Fill = Fill;
exports.FlipLeftRight = FlipLeftRight;
exports.Floor = Floor;
exports.FloorDiv = FloorDiv;
exports.FromPixels = FromPixels;
exports.FusedBatchNorm = FusedBatchNorm;
exports.FusedConv2D = FusedConv2D;
exports.FusedDepthwiseConv2D = FusedDepthwiseConv2D;
exports.GatherNd = GatherNd;
exports.GatherV2 = GatherV2;
exports.Greater = Greater;
exports.GreaterEqual = GreaterEqual;
exports.IFFT = IFFT;
exports.Identity = Identity;
exports.Imag = Imag;
exports.IsFinite = IsFinite;
exports.IsInf = IsInf;
exports.IsNan = IsNan;
exports.KernelBackend = KernelBackend;
exports.LRN = LRN;
exports.LRNBackprop = LRNBackprop;
exports.Less = Less;
exports.LessEqual = LessEqual;
exports.LinSpace = LinSpace;
exports.Log = Log;
exports.Log1p = Log1p;
exports.LogSoftmax = LogSoftmax;
exports.LogicalAnd = LogicalAnd;
exports.LogicalNot = LogicalNot;
exports.LogicalOr = LogicalOr;
exports.Max = Max;
exports.MaxPool = MaxPool;
exports.MaxPool3D = MaxPool3D;
exports.MaxPool3DBackprop = MaxPool3DBackprop;
exports.MaxPoolBackprop = MaxPoolBackprop;
exports.MaxPoolWithArgmax = MaxPoolWithArgmax;
exports.Maximum = Maximum;
exports.Mean = Mean;
exports.Min = Min;
exports.Minimum = Minimum;
exports.MirrorPad = MirrorPad;
exports.Mod = Mod;
exports.MomentumOptimizer = MomentumOptimizer;
exports.Multiply = Multiply;
exports.Negate = Negate;
exports.NonMaxSuppressionV3 = NonMaxSuppressionV3;
exports.NonMaxSuppressionV4 = NonMaxSuppressionV4;
exports.NonMaxSuppressionV5 = NonMaxSuppressionV5;
exports.NotEqual = NotEqual;
exports.OP_SCOPE_SUFFIX = OP_SCOPE_SUFFIX;
exports.OneHot = OneHot;
exports.OnesLike = OnesLike;
exports.Optimizer = Optimizer;
exports.PadV2 = PadV2;
exports.Pool = Pool;
exports.Pow = Pow;
exports.Prelu = Prelu;
exports.Prod = Prod;
exports.RMSPropOptimizer = RMSPropOptimizer;
exports.Range = Range;
exports.Real = Real;
exports.Reciprocal = Reciprocal;
exports.Relu = Relu;
exports.Relu6 = Relu6;
exports.Reshape = Reshape;
exports.ResizeBilinear = ResizeBilinear;
exports.ResizeBilinearGrad = ResizeBilinearGrad;
exports.ResizeNearestNeighbor = ResizeNearestNeighbor;
exports.ResizeNearestNeighborGrad = ResizeNearestNeighborGrad;
exports.Reverse = Reverse;
exports.RotateWithOffset = RotateWithOffset;
exports.Round = Round;
exports.Rsqrt = Rsqrt;
exports.SGDOptimizer = SGDOptimizer;
exports.ScatterNd = ScatterNd;
exports.SelectV2 = SelectV2;
exports.Selu = Selu;
exports.Sigmoid = Sigmoid;
exports.Sign = Sign;
exports.Sin = Sin;
exports.Sinh = Sinh;
exports.Slice = Slice;
exports.Softmax = Softmax;
exports.Softplus = Softplus;
exports.SpaceToBatchND = SpaceToBatchND;
exports.SparseToDense = SparseToDense;
exports.SplitV = SplitV;
exports.Sqrt = Sqrt;
exports.Square = Square;
exports.SquaredDifference = SquaredDifference;
exports.Step = Step;
exports.StridedSlice = StridedSlice;
exports.Sub = Sub;
exports.Sum = Sum;
exports.Tan = Tan;
exports.Tanh = Tanh;
exports.Tensor = Tensor;
exports.TensorBuffer = TensorBuffer;
exports.Tile = Tile;
exports.TopK = TopK;
exports.Transpose = Transpose;
exports.Unique = Unique;
exports.Unpack = Unpack;
exports.UnsortedSegmentSum = UnsortedSegmentSum;
exports.Variable = Variable;
exports.ZerosLike = ZerosLike;
exports._FusedMatMul = _FusedMatMul;
exports.abs = abs;
exports.acos = acos;
exports.acosh = acosh;
exports.add = add$1;
exports.addN = addN;
exports.addStrict = addStrict;
exports.all = all;
exports.any = any;
exports.argMax = argMax;
exports.argMin = argMin;
exports.asin = asin;
exports.asinh = asinh;
exports.atan = atan;
exports.atan2 = atan2;
exports.atanh = atanh;
exports.avgPool = avgPool;
exports.avgPool3d = avgPool3d;
exports.backend = backend;
exports.backend_util = backend_util;
exports.basicLSTMCell = basicLSTMCell;
exports.batchNorm = batchNorm;
exports.batchNorm2d = batchNorm2d;
exports.batchNorm3d = batchNorm3d;
exports.batchNorm4d = batchNorm4d;
exports.batchToSpaceND = batchToSpaceND;
exports.booleanMaskAsync = booleanMaskAsync;
exports.broadcastTo = broadcastTo;
exports.browser = browser;
exports.buffer = buffer;
exports.cast = cast;
exports.ceil = ceil;
exports.clipByValue = clipByValue;
exports.clone = clone;
exports.complex = complex;
exports.concat = concat;
exports.concat1d = concat1d;
exports.concat2d = concat2d;
exports.concat3d = concat3d;
exports.concat4d = concat4d;
exports.conv1d = conv1d;
exports.conv2d = conv2d;
exports.conv2dTranspose = conv2dTranspose;
exports.conv3d = conv3d;
exports.conv3dTranspose = conv3dTranspose;
exports.copyRegisteredKernels = copyRegisteredKernels;
exports.cos = cos;
exports.cosh = cosh;
exports.cosineWindow = cosineWindow;
exports.cumsum = cumsum;
exports.customGrad = customGrad;
exports.deprecationWarn = deprecationWarn;
exports.depthToSpace = depthToSpace;
exports.depthwiseConv2d = depthwiseConv2d;
exports.device_util = device_util;
exports.diag = diag;
exports.dilation2d = dilation2d;
exports.disableDeprecationWarnings = disableDeprecationWarnings;
exports.dispose = dispose;
exports.disposeVariables = disposeVariables;
exports.div = div;
exports.divNoNan = divNoNan;
exports.divStrict = divStrict;
exports.dot = dot;
exports.dropout = dropout;
exports.elu = elu;
exports.enableDebugMode = enableDebugMode;
exports.enableProdMode = enableProdMode;
exports.enclosingPowerOfTwo = enclosingPowerOfTwo;
exports.engine = engine;
exports.env = env;
exports.equal = equal;
exports.equalStrict = equalStrict;
exports.erf = erf;
exports.exp = exp;
exports.expandDims = expandDims;
exports.expm1 = expm1;
exports.eye = eye;
exports.fft = fft;
exports.fill = fill;
exports.findBackend = findBackend;
exports.findBackendFactory = findBackendFactory;
exports.floor = floor;
exports.floorDiv = floorDiv;
exports.fused = fused_ops;
exports.gather = gather;
exports.gatherND = gatherND;
exports.gather_util = gather_nd_util;
exports.getBackend = getBackend;
exports.getGradient = getGradient;
exports.getKernel = getKernel;
exports.getKernelsForBackend = getKernelsForBackend;
exports.grad = grad;
exports.grads = grads;
exports.greater = greater;
exports.greaterEqual = greaterEqual;
exports.greaterEqualStrict = greaterEqualStrict;
exports.greaterStrict = greaterStrict;
exports.ifft = ifft;
exports.imag = imag;
exports.image = image;
exports.inTopKAsync = inTopKAsync;
exports.io = io;
exports.irfft = irfft;
exports.isFinite = isFinite$1;
exports.isInf = isInf;
exports.isNaN = isNaN$1;
exports.keep = keep;
exports.kernel_impls = kernel_impls;
exports.leakyRelu = leakyRelu;
exports.less = less;
exports.lessEqual = lessEqual;
exports.lessEqualStrict = lessEqualStrict;
exports.lessStrict = lessStrict;
exports.linalg = linalg;
exports.linspace = linspace;
exports.localResponseNormalization = localResponseNormalization;
exports.log = log;
exports.log1p = log1p;
exports.logSigmoid = logSigmoid;
exports.logSoftmax = logSoftmax;
exports.logSumExp = logSumExp;
exports.logicalAnd = logicalAnd;
exports.logicalNot = logicalNot;
exports.logicalOr = logicalOr;
exports.logicalXor = logicalXor;
exports.losses = losses;
exports.matMul = matMul;
exports.math = math;
exports.max = max;
exports.maxPool = maxPool;
exports.maxPool3d = maxPool3d;
exports.maxPoolWithArgmax = maxPoolWithArgmax;
exports.maximum = maximum;
exports.maximumStrict = maximumStrict;
exports.mean = mean;
exports.memory = memory;
exports.min = min;
exports.minimum = minimum;
exports.minimumStrict = minimumStrict;
exports.mirrorPad = mirrorPad;
exports.mod = mod;
exports.modStrict = modStrict;
exports.moments = moments;
exports.movingAverage = movingAverage;
exports.mul = mul;
exports.mulStrict = mulStrict;
exports.multiRNNCell = multiRNNCell;
exports.multinomial = multinomial;
exports.neg = neg;
exports.nextFrame = nextFrame;
exports.norm = norm;
exports.notEqual = notEqual;
exports.notEqualStrict = notEqualStrict;
exports.oneHot = oneHot;
exports.ones = ones$1;
exports.onesLike = onesLike;
exports.op = op;
exports.outerProduct = outerProduct;
exports.pad = pad;
exports.pad1d = pad1d;
exports.pad2d = pad2d;
exports.pad3d = pad3d;
exports.pad4d = pad4d;
exports.pool = pool;
exports.pow = pow;
exports.powStrict = powStrict;
exports.prelu = prelu;
exports.print = print;
exports.prod = prod;
exports.profile = profile;
exports.rand = rand;
exports.randomGamma = randomGamma;
exports.randomNormal = randomNormal;
exports.randomUniform = randomUniform;
exports.range = range;
exports.ready = ready;
exports.real = real;
exports.reciprocal = reciprocal;
exports.registerBackend = registerBackend;
exports.registerGradient = registerGradient;
exports.registerKernel = registerKernel;
exports.relu = relu;
exports.relu6 = relu6;
exports.removeBackend = removeBackend;
exports.reshape = reshape;
exports.reverse = reverse;
exports.reverse1d = reverse1d;
exports.reverse2d = reverse2d;
exports.reverse3d = reverse3d;
exports.reverse4d = reverse4d;
exports.rfft = rfft;
exports.round = round;
exports.rsqrt = rsqrt;
exports.scalar = scalar;
exports.scatterND = scatterND;
exports.scatter_util = scatter_nd_util;
exports.selu = selu;
exports.separableConv2d = separableConv2d;
exports.serialization = serialization;
exports.setBackend = setBackend;
exports.setPlatform = setPlatform;
exports.setdiff1dAsync = setdiff1dAsync;
exports.sigmoid = sigmoid;
exports.sign = sign;
exports.signal = signal;
exports.sin = sin;
exports.sinh = sinh;
exports.slice = slice;
exports.slice1d = slice1d;
exports.slice2d = slice2d;
exports.slice3d = slice3d;
exports.slice4d = slice4d;
exports.slice_util = slice_util;
exports.softmax = softmax;
exports.softplus = softplus;
exports.spaceToBatchND = spaceToBatchND;
exports.sparseToDense = sparseToDense;
exports.spectral = spectral;
exports.split = split;
exports.sqrt = sqrt;
exports.square = square;
exports.squaredDifference = squaredDifference;
exports.squaredDifferenceStrict = squaredDifferenceStrict;
exports.squeeze = squeeze;
exports.stack = stack;
exports.step = step;
exports.stridedSlice = stridedSlice;
exports.sub = sub;
exports.subStrict = subStrict;
exports.sum = sum$1;
exports.sumOutType = sumOutType;
exports.tan = tan;
exports.tanh = tanh$1;
exports.tensor = tensor;
exports.tensor1d = tensor1d;
exports.tensor2d = tensor2d;
exports.tensor3d = tensor3d;
exports.tensor4d = tensor4d;
exports.tensor5d = tensor5d;
exports.tensor6d = tensor6d;
exports.tensor_util = tensor_util;
exports.test_util = test_util;
exports.tidy = tidy;
exports.tile = tile;
exports.time = time;
exports.topk = topk;
exports.train = train;
exports.transpose = transpose;
exports.truncatedNormal = truncatedNormal;
exports.unique = unique;
exports.unregisterGradient = unregisterGradient;
exports.unregisterKernel = unregisterKernel;
exports.unsortedSegmentSum = unsortedSegmentSum;
exports.unstack = unstack;
exports.upcastType = upcastType;
exports.util = util;
exports.valueAndGrad = valueAndGrad;
exports.valueAndGrads = valueAndGrads;
exports.variable = variable;
exports.variableGrads = variableGrads;
exports.version_core = version;
exports.where = where;
exports.whereAsync = whereAsync;
exports.zeros = zeros;
exports.zerosLike = zerosLike;
});
// node_modules/@tensorflow/tfjs-layers/dist/tf-layers.node.js
var require_tf_layers_node = __commonJS((exports) => {
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
"use strict";
// Conventional CommonJS marker that these exports originate from an ES module.
Object.defineProperty(exports, "__esModule", {value: true});
// tfjs-layers is built on the tfjs-core module bundled earlier in this file.
var tfc = require_tf_core_node();
/*! *****************************************************************************
Copyright (c) Microsoft Corporation. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use
this file except in compliance with the License. You may obtain a copy of the
License at http://www.apache.org/licenses/LICENSE-2.0
THIS CODE IS PROVIDED ON AN *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
MERCHANTABLITY OR NON-INFRINGEMENT.
See the Apache Version 2.0 License for specific language governing permissions
and limitations under the License.
***************************************************************************** */
/**
 * Copies/links static members of base constructor `b` onto derived `d`.
 * On first call, picks the best available strategy, replaces itself with it,
 * and returns that strategy's result:
 *   1. Object.setPrototypeOf (when present),
 *   2. assigning `__proto__` (when object literals honour `__proto__`),
 *   3. copying own enumerable properties of `b` onto `d`.
 */
var extendStatics = function(d, b) {
  if (Object.setPrototypeOf) {
    extendStatics = Object.setPrototypeOf;
  } else if ({__proto__: []} instanceof Array) {
    extendStatics = function(derived, base) {
      derived.__proto__ = base;
    };
  } else {
    extendStatics = function(derived, base) {
      for (var key in base) {
        if (base.hasOwnProperty(key)) {
          derived[key] = base[key];
        }
      }
    };
  }
  return extendStatics(d, b);
};
/**
 * ES5 class inheritance helper: links statics via `extendStatics`, then makes
 * `d.prototype` inherit from `b.prototype` and carry an own `constructor`
 * property pointing at `d`. When `b` is null, `d.prototype` becomes
 * `Object.create(null)`.
 */
function __extends(d, b) {
  extendStatics(d, b);
  if (b === null) {
    d.prototype = Object.create(b);
    return;
  }
  // Intermediate constructor so the new prototype object gets an own
  // `constructor` property without running `b`'s constructor.
  function Surrogate() {
    this.constructor = d;
  }
  Surrogate.prototype = b.prototype;
  d.prototype = new Surrogate();
}
/**
 * Object.assign shim. On first call it replaces itself with the native
 * Object.assign, or with a fallback that copies own enumerable string-keyed
 * properties of each source (left to right) onto the first argument. It then
 * forwards the current call.
 */
var __assign = function() {
  var fallback = function __assign2(t) {
    for (let index = 1; index < arguments.length; index++) {
      const source = arguments[index];
      for (const key in source) {
        if (Object.prototype.hasOwnProperty.call(source, key)) {
          t[key] = source[key];
        }
      }
    }
    return t;
  };
  __assign = Object.assign || fallback;
  return __assign.apply(this, arguments);
};
/**
 * Object-rest helper (`const {a, ...rest} = s`): copies the own *enumerable*
 * properties of `s` whose keys are not listed in `e` into a new plain object.
 *
 * @param {*} s - source value; null/undefined yields an empty object
 * @param {Array<string|symbol>} e - keys to exclude
 * @returns {Object} new object with the remaining own enumerable properties
 */
function __rest(s, e) {
  const t = {};
  // for-in only visits enumerable string keys; hasOwnProperty drops inherited ones.
  for (const key in s) {
    if (Object.prototype.hasOwnProperty.call(s, key) && e.indexOf(key) < 0) {
      t[key] = s[key];
    }
  }
  if (s != null && typeof Object.getOwnPropertySymbols === "function") {
    const symbols = Object.getOwnPropertySymbols(s);
    for (let i = 0; i < symbols.length; i++) {
      const sym = symbols[i];
      // getOwnPropertySymbols also returns non-enumerable symbols; object rest
      // must skip them, as the for-in loop above does for string keys.
      if (e.indexOf(sym) < 0 && Object.prototype.propertyIsEnumerable.call(s, sym)) {
        t[sym] = s[sym];
      }
    }
  }
  return t;
}
/**
 * Drives a generator-based coroutine (down-levelled async function) to
 * completion. Each yielded value is wrapped in a new `P` and awaited; its
 * outcome is fed back via the iterator's `next`/`throw`.
 *
 * @param {*} thisArg - receiver for the generator function
 * @param {Array|undefined} _arguments - arguments for it (defaults to [])
 * @param {Function|undefined} P - promise constructor; set to Promise when falsy
 * @param {GeneratorFunction} generator - coroutine body
 * @returns {Promise} settles with the generator's return value, or rejects
 *   with the first error the generator throws
 */
function __awaiter(thisArg, _arguments, P, generator) {
  if (!P) {
    P = Promise;
  }
  return new P(function(resolve, reject) {
    const iterator = generator.apply(thisArg, _arguments || []);
    // Resume the coroutine with a settled value; a synchronous throw from it
    // rejects the outer promise.
    const resume = (method, value) => {
      try {
        advance(iterator[method](value));
      } catch (e) {
        reject(e);
      }
    };
    const onFulfilled = (value) => resume("next", value);
    const onRejected = (reason) => resume("throw", reason);
    function advance(result) {
      if (result.done) {
        resolve(result.value);
        return;
      }
      const pending = new P(function(settle) {
        settle(result.value);
      });
      pending.then(onFulfilled, onRejected);
    }
    advance(iterator.next());
  });
}
/**
 * tslib generator runtime for down-levelled generators. `body` is a compiled
 * state machine: each `body.call(thisArg, _)` returns an instruction tuple
 * `[opcode, value]` that `step` interprets. Opcodes, as handled below:
 *   0 next, 1 throw, 2 return (produced by g.next / g.throw / g.return),
 *   3 jump to label, 4 yield a value, 5 delegate to an inner iterator,
 *   6 exception caught from body, 7 leave a try region and resume the
 *   instruction deferred for its finally block.
 *
 * Shared state `_`:
 *   label - resume position inside `body`
 *   sent  - returns the payload of the last instruction; rethrows it when that
 *           instruction had bit 0 set (a throw)
 *   trys  - stack of active try regions, read below as
 *           [tryLabel, catchLabel, finallyLabel, endLabel] (entries are
 *           presumably pushed by `body`; not visible here)
 *   ops   - instructions deferred while a finally block runs
 * Locals: f = re-entrancy flag, y = inner iterator being delegated to,
 * t = scratch (last instruction / try entry / inner result), g = returned iterator.
 *
 * @param {*} thisArg - `this` for every invocation of `body`
 * @param {Function} body - compiled state-machine function
 * @returns {Object} iterator with next/throw/return (iterable when Symbol exists)
 */
function __generator(thisArg, body) {
var _ = {label: 0, sent: function() {
if (t[0] & 1)
throw t[1];
return t[1];
}, trys: [], ops: []}, f, y, t, g;
// Build the iterator object; make it its own iterator when Symbol is available.
return g = {next: verb(0), throw: verb(1), return: verb(2)}, typeof Symbol === "function" && (g[Symbol.iterator] = function() {
return this;
}), g;
// Creates an iterator method that feeds [n, v] into the state machine.
function verb(n) {
return function(v) {
return step([n, v]);
};
}
// Runs the state machine until it yields, returns, or throws.
function step(op) {
// Re-entrancy guard: f is set while an instruction is being processed.
if (f)
throw new TypeError("Generator is already executing.");
// `_` is cleared to 0 once the generator has completed.
while (_)
try {
// While delegating (opcode 5), forward the request to y's return/throw/next;
// if y has no throw method, y.return is called and t becomes 0. An
// unfinished inner result is handed straight back to the caller.
if (f = 1, y && (t = op[0] & 2 ? y["return"] : op[0] ? y["throw"] || ((t = y["return"]) && t.call(y), 0) : y.next) && !(t = t.call(y, op[1])).done)
return t;
// Delegation is over: drop y; if the inner iterator finished, continue the
// outer body with its final value (return requests stay returns, others become next).
if (y = 0, t)
op = [op[0] & 2, t.value];
switch (op[0]) {
case 0:
case 1:
// next/throw: keep the instruction so `_.sent()` can read (or rethrow) it.
t = op;
break;
case 4:
// yield: advance past the yield point and hand the value to the caller.
_.label++;
return {value: op[1], done: false};
case 5:
// Start delegating to iterator op[1]; loop again to call its next().
_.label++;
y = op[1];
op = [0];
continue;
case 7:
// End of a finally block: resume the deferred instruction, leave the region.
op = _.ops.pop();
_.trys.pop();
continue;
default:
// return (2) or error (6) with no enclosing try region: generator completes.
if (!(t = _.trys, t = t.length > 0 && t[t.length - 1]) && (op[0] === 6 || op[0] === 2)) {
_ = 0;
continue;
}
// Label jump (3) with no try region, or to a target inside the current region.
if (op[0] === 3 && (!t || op[1] > t[0] && op[1] < t[3])) {
_.label = op[1];
break;
}
// Error (6) raised before the catch label: jump to the catch block.
if (op[0] === 6 && _.label < t[1]) {
_.label = t[1];
t = op;
break;
}
// A finally block is still ahead: run it and defer this instruction.
if (t && _.label < t[2]) {
_.label = t[2];
_.ops.push(op);
break;
}
// Past the finally block: discard its deferred entry, pop the region, retry.
if (t[2])
_.ops.pop();
_.trys.pop();
continue;
}
op = body.call(thisArg, _);
} catch (e) {
// An exception thrown by body becomes an error instruction; delegation ends.
op = [6, e];
y = 0;
} finally {
f = t = 0;
}
// Loop exited because the generator completed: opcodes with bit 0 or 2 set
// (e.g. 1 throw, 6 error) rethrow; otherwise report completion.
if (op[0] & 5)
throw op[1];
return {value: op[0] ? op[1] : void 0, done: true};
}
}
/**
* @license
* Copyright 2018 Google LLC
*
* Use of this source code is governed by an MIT-style
* license that can be found in the LICENSE file or at
* https://opensource.org/licenses/MIT.
* =============================================================================
*/
var _epsilon;
/**
 * Returns the fuzz factor used to guard divisions (see the constraints below).
 * The value is read from the active tfc backend on first use and then cached
 * in `_epsilon`, so later backend changes are not picked up.
 * @returns {number}
 */
function epsilon() {
  if (_epsilon != null) {
    return _epsilon;
  }
  _epsilon = tfc.backend().epsilon();
  return _epsilon;
}
/**
 * Default image data format for this module: channels stored in the last
 * dimension.
 * @returns {string} always "channelsLast".
 */
function imageDataFormat() {
  var defaultFormat = "channelsLast";
  return defaultFormat;
}
/**
* @license
* Copyright 2018 Google LLC
*
* Use of this source code is governed by an MIT-style
* license that can be found in the LICENSE file or at
* https://opensource.org/licenses/MIT.
* =============================================================================
*/
/*
 * Error subclasses used by the layers utilities. Calling `Error` as a plain
 * function returns a fresh Error object instead of initializing `this`, so
 * each constructor re-points that object's prototype at the subclass to keep
 * `instanceof` checks working under ES5-style inheritance.
 */
var AttributeError = (function(_super) {
  __extends(AttributeError2, _super);
  function AttributeError2(message) {
    var self = _super.call(this, message) || this;
    Object.setPrototypeOf(self, AttributeError2.prototype);
    return self;
  }
  return AttributeError2;
})(Error);
var RuntimeError = (function(_super) {
  __extends(RuntimeError2, _super);
  function RuntimeError2(message) {
    var self = _super.call(this, message) || this;
    Object.setPrototypeOf(self, RuntimeError2.prototype);
    return self;
  }
  return RuntimeError2;
})(Error);
var ValueError = (function(_super) {
  __extends(ValueError2, _super);
  function ValueError2(message) {
    var self = _super.call(this, message) || this;
    Object.setPrototypeOf(self, ValueError2.prototype);
    return self;
  }
  return ValueError2;
})(Error);
var NotImplementedError = (function(_super) {
  __extends(NotImplementedError2, _super);
  function NotImplementedError2(message) {
    var self = _super.call(this, message) || this;
    Object.setPrototypeOf(self, NotImplementedError2.prototype);
    return self;
  }
  return NotImplementedError2;
})(Error);
var AssertionError = (function(_super) {
  __extends(AssertionError2, _super);
  function AssertionError2(message) {
    var self = _super.call(this, message) || this;
    Object.setPrototypeOf(self, AssertionError2.prototype);
    return self;
  }
  return AssertionError2;
})(Error);
var IndexError = (function(_super) {
  __extends(IndexError2, _super);
  function IndexError2(message) {
    var self = _super.call(this, message) || this;
    Object.setPrototypeOf(self, IndexError2.prototype);
    return self;
  }
  return IndexError2;
})(Error);
/**
* @license
* Copyright 2018 Google LLC
*
* Use of this source code is governed by an MIT-style
* license that can be found in the LICENSE file or at
* https://opensource.org/licenses/MIT.
* =============================================================================
*/
/**
 * Python-style list repetition (`value * numValues`).
 * - Array `value`: its elements are concatenated `numValues` times.
 * - Anything else: a new array of length `numValues` filled with `value`.
 * @returns {Array}
 */
function pyListRepeat(value, numValues) {
  if (!Array.isArray(value)) {
    return new Array(numValues).fill(value);
  }
  let repeated = [];
  for (let copy = 0; copy < numValues; copy += 1) {
    repeated = repeated.concat(value);
  }
  return repeated;
}
/**
 * Throws an AssertionError carrying `message` when `val` is falsy.
 */
function assert(val, message) {
  if (val) {
    return;
  }
  throw new AssertionError(message);
}
/**
 * Counts elements of `array` that are strictly equal (===) to `refernce`.
 * @returns {number}
 */
function count(array, refernce) {
  let matches = 0;
  for (let idx = 0; idx < array.length; idx += 1) {
    if (array[idx] === refernce) {
      matches += 1;
    }
  }
  return matches;
}
/** Unwraps a one-element array to its sole element; otherwise returns `xs` unchanged. */
function singletonOrArray(xs) {
  return xs.length === 1 ? xs[0] : xs;
}
/** Returns `x` itself when it is an array, otherwise wraps it as `[x]`. */
function toList(x) {
  return Array.isArray(x) ? x : [x];
}
/**
 * Converts a camelCase/PascalCase name to snake_case.
 * If the result begins with "_", it is prefixed with "private" (e.g. "_foo" -> "private_foo").
 * @returns {string}
 */
function toSnakeCase(name) {
  const snake = name
    .replace(/(.)([A-Z][a-z0-9]+)/g, "$1_$2")
    .replace(/([a-z])([A-Z])/g, "$1_$2")
    .toLowerCase();
  return snake[0] === "_" ? "private" + snake : snake;
}
/**
 * Converts snake_case to camelCase. Each run of underscores is dropped and the
 * following word character is upper-cased. Identifiers of length <= 1, or
 * with no underscore, are returned unchanged.
 * @returns {string}
 */
function toCamelCase(identifier) {
  if (identifier.length <= 1 || !identifier.includes("_")) {
    return identifier;
  }
  return identifier.replace(/[_]+(\w|$)/g, (_whole, nextChar) => nextChar.toUpperCase());
}
// Registry of custom objects, temporarily extended by deserializeKerasObject.
var _GLOBAL_CUSTOM_OBJECTS = {};
/**
 * Serializes a Keras-style object into `{className, config}`.
 * @param instance object exposing getClassName() and getConfig(), or null/undefined.
 * @returns the serialized dict, or null when `instance` is null/undefined.
 */
function serializeKerasObject(instance) {
  if (instance == null) {
    return null;
  }
  return {
    className: instance.getClassName(),
    config: instance.getConfig()
  };
}
/**
 * Walks a config object or array in place. Every property value shaped like
 * `{type: "ndarray", value: <number>}` is replaced by the bare number.
 * Array elements are only recursed into, never replaced. Primitives and null
 * are ignored.
 */
function convertNDArrayScalarsInConfig(config) {
  if (config == null || typeof config !== "object") {
    return;
  }
  if (Array.isArray(config)) {
    config.forEach((item) => convertNDArrayScalarsInConfig(item));
    return;
  }
  for (const key of Object.keys(config)) {
    const value = config[key];
    if (value == null || typeof value !== "object") {
      continue;
    }
    const isNDArrayScalar = !Array.isArray(value) && value["type"] === "ndarray" && typeof value["value"] === "number";
    if (isNDArrayScalar) {
      config[key] = value["value"];
    } else {
      convertNDArrayScalarsInConfig(value);
    }
  }
}
/**
 * Resolves a Keras-style identifier.
 *
 * - String identifier: looked up, in order, in `customObjects`, the global
 *   `_GLOBAL_CUSTOM_OBJECTS`, then `moduleObjects`. The entry is returned as-is.
 * - `{className, config}` object: `className` is resolved with the same lookup
 *   order to a `[cls, fromConfig]` pair. The object is then built either with
 *   `fromConfig(cls, config, customObjects, fastWeightInit)` or with
 *   `new cls(config)`. During construction, `customObjects` are merged into
 *   `_GLOBAL_CUSTOM_OBJECTS`; the previous registry is always restored
 *   afterwards, even when construction throws.
 *
 * @param identifier string name or `{className, config}` object.
 * @param moduleObjects default lookup table (default `{}`).
 * @param customObjects caller-supplied overrides (default `{}`).
 * @param printableModuleName noun used in error messages (default "object").
 * @param fastWeightInit forwarded to `fromConfig` (default false).
 * @throws ValueError for unknown names or malformed configs.
 */
function deserializeKerasObject(identifier, moduleObjects, customObjects, printableModuleName, fastWeightInit) {
  if (moduleObjects === void 0) {
    moduleObjects = {};
  }
  if (customObjects === void 0) {
    customObjects = {};
  }
  if (printableModuleName === void 0) {
    printableModuleName = "object";
  }
  if (fastWeightInit === void 0) {
    fastWeightInit = false;
  }
  // Shared "unknown object" error for both the string and config paths.
  function unknownObjectError(name) {
    return new ValueError("Unknown " + printableModuleName + ": " + name + ". This may be due to one of the following reasons:\n" + ("1. The " + printableModuleName + " is defined in Python, in which ") + "case it needs to be ported to TensorFlow.js or your JavaScript code.\n" + ("2. The custom " + printableModuleName + " is defined in JavaScript, ") + "but is not registered properly with tf.serialization.registerClass().");
  }
  // Copies own enumerable keys of `source` onto `target`.
  function copyInto(target, source) {
    for (const key of Object.keys(source)) {
      target[key] = source[key];
    }
    return target;
  }
  if (typeof identifier === "string") {
    if (identifier in customObjects) {
      return customObjects[identifier];
    }
    if (identifier in _GLOBAL_CUSTOM_OBJECTS) {
      return _GLOBAL_CUSTOM_OBJECTS[identifier];
    }
    var fn = moduleObjects[identifier];
    if (fn == null) {
      throw unknownObjectError(identifier);
    }
    return fn;
  }
  var config = identifier;
  if (config["className"] == null || config["config"] == null) {
    throw new ValueError(printableModuleName + ": Improper config format: " + (JSON.stringify(config) + ".\n") + "'className' and 'config' must set.");
  }
  var className = config["className"];
  var cls = void 0, fromConfig = void 0, pair;
  if (className in customObjects) {
    pair = customObjects[className];
    cls = pair[0];
    fromConfig = pair[1];
  } else if (className in _GLOBAL_CUSTOM_OBJECTS) {
    // Look up by the actual class name (previously the literal key "className"
    // was used, so globally registered classes could never be resolved).
    pair = _GLOBAL_CUSTOM_OBJECTS[className];
    cls = pair[0];
    fromConfig = pair[1];
  } else if (className in moduleObjects) {
    pair = moduleObjects[className];
    cls = pair[0];
    fromConfig = pair[1];
  }
  if (cls == null) {
    throw unknownObjectError(className);
  }
  var nestedConfig = config["config"];
  if (fromConfig != null) {
    // Expose every known custom object to fromConfig via the nested config.
    nestedConfig["customObjects"] = copyInto(copyInto({}, _GLOBAL_CUSTOM_OBJECTS), customObjects);
  }
  var backupCustomObjects = __assign({}, _GLOBAL_CUSTOM_OBJECTS);
  copyInto(_GLOBAL_CUSTOM_OBJECTS, customObjects);
  try {
    if (fromConfig != null) {
      convertNDArrayScalarsInConfig(nestedConfig);
      return fromConfig(cls, nestedConfig, customObjects, fastWeightInit);
    }
    return new cls(nestedConfig);
  } finally {
    // Restore the registry even if construction threw, so temporary custom
    // objects do not leak into later deserializations.
    _GLOBAL_CUSTOM_OBJECTS = __assign({}, backupCustomObjects);
  }
}
/** Ascending numeric comparator for Array.prototype.sort: -1, 1, or 0. */
function numberCompare(a, b) {
  if (a < b) {
    return -1;
  }
  if (a > b) {
    return 1;
  }
  return 0;
}
/** Descending numeric comparator: the negation of numberCompare. */
function reverseNumberCompare(a, b) {
  return -numberCompare(a, b);
}
/**
 * Returns the distinct elements of `xs` in first-seen order, using strict
 * equality via indexOf. null/undefined input is returned unchanged.
 */
function unique(xs) {
  if (xs == null) {
    return xs;
  }
  var distinct = [];
  for (var idx = 0; idx < xs.length; ++idx) {
    var candidate = xs[idx];
    if (distinct.indexOf(candidate) < 0) {
      distinct.push(candidate);
    }
  }
  return distinct;
}
/**
 * Returns true when `obj` has no own enumerable string-keyed properties.
 * @throws ValueError when `obj` is null or undefined.
 */
function isObjectEmpty(obj) {
  if (obj == null) {
    throw new ValueError("Invalid value in obj: " + JSON.stringify(obj));
  }
  for (var key in obj) {
    // Borrow hasOwnProperty so that null-prototype dictionaries
    // (Object.create(null)) and objects with an own "hasOwnProperty" key
    // are handled instead of throwing a TypeError.
    if (Object.prototype.hasOwnProperty.call(obj, key)) {
      return false;
    }
  }
  return true;
}
/**
 * Validates that `value` is null/undefined or one of `values`.
 * @throws ValueError naming `label` when the value is not allowed.
 */
function checkStringTypeUnionValue(values, label, value) {
  var isAllowed = value == null || values.indexOf(value) >= 0;
  if (!isAllowed) {
    throw new ValueError(value + " is not a valid " + label + ". Valid values are " + values + " or null/undefined.");
  }
}
/**
 * Returns true when `x` is an array whose length lies in
 * [minLength, maxLength] and whose every element has `typeof === expectedType`.
 * @param minLength default 0; must be >= 0.
 * @param maxLength default Infinity; must be >= minLength.
 */
function checkArrayTypeAndLength(x, expectedType, minLength, maxLength) {
  if (minLength === void 0) {
    minLength = 0;
  }
  if (maxLength === void 0) {
    maxLength = Infinity;
  }
  assert(minLength >= 0);
  assert(maxLength >= minLength);
  if (!Array.isArray(x)) {
    return false;
  }
  if (x.length < minLength || x.length > maxLength) {
    return false;
  }
  return x.every((element) => typeof element === expectedType);
}
/**
 * Asserts (via tfc.util.assert) that `value` is a positive integer, or a
 * non-empty array whose elements (recursively) are positive integers.
 * `name` is used in the failure message.
 */
function assertPositiveInteger(value, name) {
  if (!Array.isArray(value)) {
    tfc.util.assert(Number.isInteger(value) && value > 0, () => "Expected " + name + " to be a positive integer, but got " + (formatAsFriendlyString(value) + "."));
    return;
  }
  tfc.util.assert(value.length > 0, () => name + " is unexpectedly an empty array.");
  value.forEach((element, index) => assertPositiveInteger(element, "element " + (index + 1) + " of " + name));
}
/**
 * Human-readable rendering for error messages: "null" for null, arrays as
 * "[a,b]" (recursively), strings wrapped in double quotes, and everything
 * else via string concatenation.
 * @returns {string}
 */
function formatAsFriendlyString(value) {
  if (value === null) {
    return "null";
  }
  if (Array.isArray(value)) {
    var parts = value.map(function(element) {
      return formatAsFriendlyString(element);
    });
    return "[" + parts.join(",") + "]";
  }
  if (typeof value === "string") {
    return '"' + value + '"';
  }
  return "" + value;
}
/**
 * Wraps `f` so it runs at most once per `waitMs` milliseconds (by tfc.util.now()).
 * Despite the name this behaves like a leading-edge throttle:
 * - Calls made within `waitMs` of the last invocation return the cached
 *   previous result without calling `f`.
 * - The timer starts when debounce() is created, so the very first calls
 *   may return undefined.
 */
function debounce(f, waitMs) {
  let lastTime = tfc.util.now();
  let lastResult;
  return function(...args) {
    const now = tfc.util.now();
    if (now - lastTime < waitMs) {
      return lastResult;
    }
    lastTime = now;
    lastResult = f(...args);
    return lastResult;
  };
}
/**
 * Maps an activation name to the fused-kernel activation name, or null when
 * the activation has no fused equivalent ("relu", "linear" and "elu" do).
 */
function mapActivationToFusedKernel(activationName) {
  switch (activationName) {
    case "relu":
    case "linear":
    case "elu":
      return activationName;
    default:
      return null;
  }
}
/**
* @license
* Copyright 2018 Google LLC
*
* Use of this source code is governed by an MIT-style
* license that can be found in the LICENSE file or at
* https://opensource.org/licenses/MIT.
* =============================================================================
*/
/**
 * L2 norm of `w` along `axis`, keeping the reduced dimensions
 * (sqrt(sum(w * w, axis, keepDims=true))).
 */
function calcL2Norms(w, axis) {
  return tfc.tidy(() => {
    const squared = tfc.mul(w, w);
    const sumOfSquares = tfc.sum(squared, axis, true);
    return tfc.sqrt(sumOfSquares);
  });
}
/**
 * Base class for weight constraints (subclasses below define `apply(w)`).
 * The base configuration is empty.
 */
var Constraint = (function(_super) {
  __extends(Constraint2, _super);
  function Constraint2() {
    return (_super !== null && _super.apply(this, arguments)) || this;
  }
  Constraint2.prototype.getConfig = function() {
    var emptyConfig = {};
    return emptyConfig;
  };
  return Constraint2;
})(tfc.serialization.Serializable);
/**
 * Constraint rescaling `w` so its L2 norms along `axis` do not exceed
 * `maxValue` (defaults: maxValue 2, axis 0).
 */
var MaxNorm = (function(_super) {
  __extends(MaxNorm2, _super);
  function MaxNorm2(args) {
    var self = _super.call(this) || this;
    self.defaultMaxValue = 2;
    self.defaultAxis = 0;
    self.maxValue = args.maxValue == null ? self.defaultMaxValue : args.maxValue;
    self.axis = args.axis == null ? self.defaultAxis : args.axis;
    return self;
  }
  MaxNorm2.prototype.apply = function(w) {
    return tfc.tidy(() => {
      const norms = calcL2Norms(w, this.axis);
      const desired = tfc.clipByValue(norms, 0, this.maxValue);
      // epsilon() guards against division by a zero norm.
      const scale = tfc.div(desired, tfc.add(epsilon(), norms));
      return tfc.mul(w, scale);
    });
  };
  MaxNorm2.prototype.getConfig = function() {
    return {maxValue: this.maxValue, axis: this.axis};
  };
  MaxNorm2.className = "MaxNorm";
  return MaxNorm2;
})(Constraint);
tfc.serialization.registerClass(MaxNorm);
/**
 * Constraint dividing `w` by its L2 norms along `axis` (plus epsilon), so the
 * norms become approximately 1 (default axis 0).
 */
var UnitNorm = (function(_super) {
  __extends(UnitNorm2, _super);
  function UnitNorm2(args) {
    var self = _super.call(this) || this;
    self.defaultAxis = 0;
    self.axis = args.axis == null ? self.defaultAxis : args.axis;
    return self;
  }
  UnitNorm2.prototype.apply = function(w) {
    return tfc.tidy(() => {
      const guardedNorms = tfc.add(epsilon(), calcL2Norms(w, this.axis));
      return tfc.div(w, guardedNorms);
    });
  };
  UnitNorm2.prototype.getConfig = function() {
    return {axis: this.axis};
  };
  UnitNorm2.className = "UnitNorm";
  return UnitNorm2;
})(Constraint);
tfc.serialization.registerClass(UnitNorm);
/** Constraint clamping negative weights to zero (element-wise relu). */
var NonNeg = (function(_super) {
  __extends(NonNeg2, _super);
  function NonNeg2() {
    return (_super !== null && _super.apply(this, arguments)) || this;
  }
  NonNeg2.prototype.apply = function(w) {
    var clamped = tfc.relu(w);
    return clamped;
  };
  NonNeg2.className = "NonNeg";
  return NonNeg2;
})(Constraint);
tfc.serialization.registerClass(NonNeg);
/**
 * Constraint pulling the L2 norms of `w` along `axis` into [minValue, maxValue].
 * `rate` blends the clipped norm with the original one: 1 applies full
 * clipping, 0 leaves the norms (almost) unchanged.
 * Defaults: minValue 0, maxValue 1, rate 1, axis 0.
 */
var MinMaxNorm = (function(_super) {
  __extends(MinMaxNorm2, _super);
  function MinMaxNorm2(args) {
    var self = _super.call(this) || this;
    self.defaultMinValue = 0;
    self.defaultMaxValue = 1;
    self.defaultRate = 1;
    self.defaultAxis = 0;
    self.minValue = args.minValue == null ? self.defaultMinValue : args.minValue;
    self.maxValue = args.maxValue == null ? self.defaultMaxValue : args.maxValue;
    self.rate = args.rate == null ? self.defaultRate : args.rate;
    self.axis = args.axis == null ? self.defaultAxis : args.axis;
    return self;
  }
  MinMaxNorm2.prototype.apply = function(w) {
    return tfc.tidy(() => {
      const norms = calcL2Norms(w, this.axis);
      const clipped = tfc.clipByValue(norms, this.minValue, this.maxValue);
      const desired = tfc.add(tfc.mul(this.rate, clipped), tfc.mul(1 - this.rate, norms));
      return tfc.mul(w, tfc.div(desired, tfc.add(epsilon(), norms)));
    });
  };
  MinMaxNorm2.prototype.getConfig = function() {
    return {
      minValue: this.minValue,
      maxValue: this.maxValue,
      rate: this.rate,
      axis: this.axis
    };
  };
  MinMaxNorm2.className = "MinMaxNorm";
  return MinMaxNorm2;
})(Constraint);
tfc.serialization.registerClass(MinMaxNorm);
// Maps the camelCase shorthands accepted by getConstraint to registered class names.
var CONSTRAINT_IDENTIFIER_REGISTRY_SYMBOL_MAP = {maxNorm: "MaxNorm", minMaxNorm: "MinMaxNorm", nonNeg: "NonNeg", unitNorm: "UnitNorm"};
/** Serializes a constraint to `{className, config}` (null for null/undefined). */
function serializeConstraint(constraint) {
  var serialized = serializeKerasObject(constraint);
  return serialized;
}
/**
 * Builds a constraint from a `{className, config}` object. Names are resolved
 * against the tfc serialization registry and the optional `customObjects`.
 */
function deserializeConstraint(config, customObjects) {
  var registry = tfc.serialization.SerializationMap.getMap().classNameMap;
  var overrides = customObjects === void 0 ? {} : customObjects;
  return deserializeKerasObject(config, registry, overrides, "constraint");
}
/**
 * Normalizes a constraint identifier:
 * - null/undefined -> null;
 * - a string (camelCase shorthand or class name) -> a new instance with an empty config;
 * - an existing Constraint -> returned as-is;
 * - anything else -> deserialized as a config object.
 */
function getConstraint(identifier) {
  if (identifier == null) {
    return null;
  }
  if (typeof identifier !== "string") {
    return identifier instanceof Constraint ? identifier : deserializeConstraint(identifier);
  }
  var className = identifier in CONSTRAINT_IDENTIFIER_REGISTRY_SYMBOL_MAP ? CONSTRAINT_IDENTIFIER_REGISTRY_SYMBOL_MAP[identifier] : identifier;
  return deserializeConstraint({className: className, config: {}});
}
/**
* @license
* Copyright 2018 Google LLC
*
* Use of this source code is governed by an MIT-style
* license that can be found in the LICENSE file or at
* https://opensource.org/licenses/MIT.
* =============================================================================
*/
/** Factory for a MaxNorm constraint. */
function maxNorm(args) {
  var constraint = new MaxNorm(args);
  return constraint;
}
/** Factory for a UnitNorm constraint. */
function unitNorm(args) {
  var constraint = new UnitNorm(args);
  return constraint;
}
/** Factory for a NonNeg constraint. */
function nonNeg() {
  var constraint = new NonNeg();
  return constraint;
}
/** Factory for a MinMaxNorm constraint. */
function minMaxNorm(config) {
  var constraint = new MinMaxNorm(config);
  return constraint;
}
// Namespace object with a null prototype exposing the constraint factories.
var exports_constraints = {__proto__: null, maxNorm: maxNorm, unitNorm: unitNorm, nonNeg: nonNeg, minMaxNorm: minMaxNorm};
/**
* @license
* Copyright 2018 Google LLC
*
* Use of this source code is governed by an MIT-style
* license that can be found in the LICENSE file or at
* https://opensource.org/licenses/MIT.
* =============================================================================
*/
// Allowed values for the string-union options validated by the check* helpers.
var VALID_DATA_FORMAT_VALUES = ["channelsFirst", "channelsLast"];
var VALID_PADDING_MODE_VALUES = ["valid", "same", "causal"];
var VALID_POOL_MODE_VALUES = ["max", "avg"];
var VALID_BIDIRECTIONAL_MERGE_MODES = ["sum", "mul", "concat", "ave"];
/**
* @license
* Copyright 2018 Google LLC
*
* Use of this source code is governed by an MIT-style
* license that can be found in the LICENSE file or at
* https://opensource.org/licenses/MIT.
* =============================================================================
*/
// Per-name usage counters consumed by getUniqueTensorName.
var nameMap = new Map();
/** Throws ValueError unless `value` is null/undefined or a valid DataFormat. */
function checkDataFormat(value) {
  var allowed = VALID_DATA_FORMAT_VALUES;
  checkStringTypeUnionValue(allowed, "DataFormat", value);
}
/** Throws ValueError unless `value` is null/undefined or a valid PaddingMode. */
function checkPaddingMode(value) {
  var allowed = VALID_PADDING_MODE_VALUES;
  checkStringTypeUnionValue(allowed, "PaddingMode", value);
}
/** Throws ValueError unless `value` is null/undefined or a valid PoolMode. */
function checkPoolMode(value) {
  var allowed = VALID_POOL_MODE_VALUES;
  checkStringTypeUnionValue(allowed, "PoolMode", value);
}
// Stack of active name scopes, joined with the divider to build prefixes.
var _nameScopeStack = [];
var _nameScopeDivider = "/";
/**
 * Runs `fn` with `name` pushed onto the name-scope stack and returns its
 * result. The scope is popped whether `fn` returns or throws.
 */
function nameScope(name, fn) {
  _nameScopeStack.push(name);
  try {
    return fn();
  } finally {
    _nameScopeStack.pop();
  }
}
/**
 * Prefix for the current scope stack, e.g. "outer/inner/"; returns "" when no
 * scope is active.
 */
function currentNameScopePrefix() {
  return _nameScopeStack.length === 0 ? "" : _nameScopeStack.join(_nameScopeDivider) + _nameScopeDivider;
}
/**
 * Prefixes `tensorName` with the current name-scope prefix.
 * @throws Error when `tensorName` is not a valid tensor name.
 */
function getScopedTensorName(tensorName) {
  if (isValidTensorName(tensorName)) {
    return currentNameScopePrefix() + tensorName;
  }
  throw new Error("Not a valid tensor name: '" + tensorName + "'");
}
/**
 * Returns `scopedName` on its first use, then "scopedName_1", "scopedName_2", ...
 * on later uses, tracking counts in `nameMap`. Each generated suffixed name is
 * also recorded as used once.
 * @throws Error when `scopedName` is not a valid tensor name.
 */
function getUniqueTensorName(scopedName) {
  if (!isValidTensorName(scopedName)) {
    throw new Error("Not a valid tensor name: '" + scopedName + "'");
  }
  const index = nameMap.has(scopedName) ? nameMap.get(scopedName) : 0;
  nameMap.set(scopedName, index + 1);
  if (index > 0) {
    const uniqueName = scopedName + "_" + index;
    nameMap.set(uniqueName, 1);
    return uniqueName;
  }
  return scopedName;
}
// Valid names start with an alphanumeric, followed by alphanumerics, "-", ".", "_" or "/".
var tensorNameRegex = /^[A-Za-z0-9][-A-Za-z0-9\._\/]*$/;
/** True when `name` matches tensorNameRegex. */
function isValidTensorName(name) {
  return name.match(tensorNameRegex) !== null;
}
/**
* @license
* Copyright 2018 Google LLC
*
* Use of this source code is governed by an MIT-style
* license that can be found in the LICENSE file or at
* https://opensource.org/licenses/MIT.
* =============================================================================
*/
/**
 * True when `x` is strictly equal to its own base-10 parseInt. Non-number
 * inputs (e.g. "3") are therefore never integers.
 */
function isInteger(x) {
  var parsed = parseInt(x.toString(), 10);
  return parsed === x;
}
/**
 * Product of `array[begin..end)`.
 * `begin` defaults to 0 and `end` to array.length; an empty range gives 1.
 */
function arrayProd(array, begin, end) {
  var first = begin == null ? 0 : begin;
  var last = end == null ? array.length : end;
  var product = 1;
  for (var idx = first; idx < last; idx += 1) {
    product *= array[idx];
  }
  return product;
}
/**
 * Converts a JS number array to a Float32Array and wraps it as a 1-D tensor.
 * Non-array inputs are passed to tfc.tensor1d unchanged. The caller owns the
 * returned tensor.
 */
function toArray1D(array) {
  array = Array.isArray(array) ? new Float32Array(array) : array;
  return tfc.tensor1d(array);
}
/**
 * Minimum of `array` as a plain number.
 * Runs inside tfc.tidy so the temporary 1-D tensor and the reduction result
 * are disposed once the value is read back (previously both leaked).
 */
function min(array) {
  return tfc.tidy(function() {
    return tfc.min(toArray1D(array)).dataSync()[0];
  });
}
/**
 * Maximum of `array` as a plain number.
 * Runs inside tfc.tidy so the intermediate tensors are disposed.
 */
function max(array) {
  return tfc.tidy(function() {
    return tfc.max(toArray1D(array)).dataSync()[0];
  });
}
/**
 * Returns [begin, begin + 1, ...] for all values below `end`.
 * @throws ValueError when end < begin.
 */
function range(begin, end) {
  if (end < begin) {
    throw new ValueError("end (" + end + ") < begin (" + begin + ") is forbidden.");
  }
  var values = [];
  for (var v = begin; v < end; v += 1) {
    values.push(v);
  }
  return values;
}
/**
* @license
* Copyright 2018 Google LLC
*
* Use of this source code is governed by an MIT-style
* license that can be found in the LICENSE file or at
* https://opensource.org/licenses/MIT.
* =============================================================================
*/
/** Casts tensor `x` to `dtype` via its asType method. */
function cast(x, dtype) {
  var converted = x.asType(dtype);
  return converted;
}
/**
 * Inserts a size-1 dimension into `x` at `axis` (default -1, i.e. last).
 * Negative axes count from the end; the input shape array is not mutated.
 */
function expandDims(x, axis) {
  if (axis === void 0) {
    axis = -1;
  }
  var outShape = x.shape.slice();
  var insertAt = axis < 0 ? outShape.length + axis + 1 : axis;
  outShape.splice(insertAt, 0, 1);
  return x.reshape(outShape);
}
/**
 * Repeats a rank-2 tensor `n` times along a new middle axis:
 * [a, b] -> [a, n, b].
 * @throws ValueError for inputs whose rank is not 2.
 */
function repeat(x, n) {
  return tfc.tidy(() => {
    const rank = x.shape.length;
    if (rank !== 2) {
      throw new ValueError("repeat() expects a rank-2 tensor, but received a " + ("rank-" + rank + " tensor."));
    }
    return tile(expandDims(x, 1), [1, n, 1]);
  });
}
/** Reshapes `x` into a 1-D tensor holding all of its elements. */
function flatten(x) {
  var totalSize = arrayProd(x.shape);
  return x.reshape([totalSize]);
}
/**
 * Flattens all dimensions except the first: [b, d1, d2, ...] -> [b, d1*d2*...].
 * @throws ValueError when rank(x) <= 1.
 */
function batchFlatten(x) {
  if (x.rank <= 1) {
    throw new ValueError("batchFlatten requires a minimum rank of 2. Got rank: " + x.rank + ".");
  }
  var batchSize = x.shape[0];
  var featureSize = arrayProd(x.shape, 1);
  return x.reshape([batchSize, featureSize]);
}
/**
 * Takes `size` entries along axis 0 starting at `start`, keeping every other
 * dimension whole. Supports ranks 1-6.
 * @throws ValueError for other ranks.
 */
function sliceAlongFirstAxis(array, start, size) {
  return tfc.tidy(() => {
    const shape = array.shape;
    switch (array.rank) {
      case 1:
        return tfc.slice1d(array, start, size);
      case 2:
        return tfc.slice2d(array, [start, 0], [size, shape[1]]);
      case 3:
        return tfc.slice3d(array, [start, 0, 0], [size, shape[1], shape[2]]);
      case 4:
        return tfc.slice4d(array, [start, 0, 0, 0], [size, shape[1], shape[2], shape[3]]);
      case 5:
      case 6: {
        // Generic slice: begin at `start` on axis 0 and 0 elsewhere.
        const begin = [start].concat(new Array(array.rank - 1).fill(0));
        const sizes = [size].concat(shape.slice(1, array.rank));
        return tfc.slice(array, begin, sizes);
      }
      default:
        throw new ValueError("sliceAlongFirstAxis() received an unsupported tensor rank: " + ("" + array.rank));
    }
  });
}
/**
 * Takes `size` entries along the last axis starting at `start`, keeping every
 * other dimension whole. Supports ranks 1-4.
 * @throws ValueError for other ranks.
 */
function sliceAlongLastAxis(array, start, size) {
  return tfc.tidy(() => {
    const rank = array.rank;
    const shape = array.shape;
    if (rank === 1) {
      return tfc.slice1d(array, start, size);
    }
    if (rank === 2) {
      return tfc.slice2d(array, [0, start], [shape[0], size]);
    }
    if (rank === 3) {
      return tfc.slice3d(array, [0, 0, start], [shape[0], shape[1], size]);
    }
    if (rank === 4) {
      return tfc.slice4d(array, [0, 0, 0, start], [shape[0], shape[1], shape[2], size]);
    }
    throw new ValueError("sliceAlongLastAxis() received an unsupported tensor rank: " + ("" + rank));
  });
}
/**
 * Takes `size` entries starting at `start` along `axis` of a rank 1-4 tensor.
 * `axis` is 1-based here: 1 selects the first dimension and `rank` the last.
 * For rank-1 tensors `axis` is ignored.
 * @throws ValueError when the rank is unsupported or the axis is out of range.
 */
function sliceAlongAxis(array, start, size, axis) {
  return tfc.tidy(function() {
    var rank = array.rank;
    var shape = array.shape;
    if (rank === 1) {
      return tfc.slice1d(array, start, size);
    }
    if (rank !== 2 && rank !== 3 && rank !== 4) {
      // Message now names this function (it previously said sliceAlongLastAxis()).
      throw new ValueError("sliceAlongAxis() received an unsupported tensor rank: " + ("" + rank));
    }
    if (axis === 1) {
      return sliceAlongFirstAxis(array, start, size);
    }
    if (axis === rank) {
      return sliceAlongLastAxis(array, start, size);
    }
    if (rank === 3 && axis === 2) {
      return tfc.slice3d(array, [0, start, 0], [shape[0], size, shape[2]]);
    }
    if (rank === 4 && axis === 2) {
      return tfc.slice4d(array, [0, start, 0, 0], [shape[0], size, shape[2], shape[3]]);
    }
    if (rank === 4 && axis === 3) {
      return tfc.slice4d(array, [0, 0, start, 0], [shape[0], shape[1], size, shape[3]]);
    }
    throw new ValueError("The axis is not within the rank of the tensor " + ("" + axis));
  });
}
/**
 * Concatenates `tensors` along `axis` (default -1).
 * A negative axis is first mapped to the rank of the first tensor, and an
 * axis equal to that rank is then passed to tfc.concat as -1.
 */
function concatenate(tensors, axis) {
  if (axis === void 0) {
    axis = -1;
  }
  var firstRank = tensors[0].rank;
  if (axis < 0) {
    axis = firstRank !== 0 ? firstRank : 0;
  }
  if (axis === firstRank) {
    axis = -1;
  }
  return tfc.concat(tensors, axis);
}
/**
 * Concatenates `a` and `b` along axis 0 using the rank-specific tfc helpers.
 * @throws ValueError when rank(a) is outside 1-4.
 */
function concatAlongFirstAxis(a, b) {
  var rank = a.rank;
  var pair = [a, b];
  if (rank === 1) {
    return tfc.concat1d(pair);
  }
  if (rank === 2) {
    return tfc.concat2d(pair, 0);
  }
  if (rank === 3) {
    return tfc.concat3d(pair, 0);
  }
  if (rank === 4) {
    return tfc.concat4d(pair, 0);
  }
  throw new ValueError("concatAlongFirstAxis() received an unsupported " + ("tensor rank: " + rank));
}
/**
 * Tiles `x` by the repetition counts `n` (a scalar is treated as `[n]`).
 * @throws ValueError when the number of counts differs from rank(x).
 */
function tile(x, n) {
  var reps = Array.isArray(n) ? n : [n];
  if (x.rank !== reps.length) {
    throw new ValueError("The length of input n (" + reps.length + ") does not match " + ("the number of dimensions in input x (" + x.rank + ")"));
  }
  return tfc.tile(x, reps);
}
/**
 * Thin wrapper over tfc.randomNormal; `mean` defaults to 0 and `stddev` to 1.
 */
function randomNormal(shape, mean, stddev, dtype, seed) {
  var effectiveMean = mean === void 0 ? 0 : mean;
  var effectiveStddev = stddev === void 0 ? 1 : stddev;
  return tfc.randomNormal(shape, effectiveMean, effectiveStddev, dtype, seed);
}
/**
 * Matrix product of `a` and `b` with an optional fused bias and activation,
 * generalized to inputs of rank > 2.
 *
 * - rank(a) === rank(b) === 2: a single tfc.fused.matMul call.
 * - otherwise:
 *   - `a` is flattened to 2-D [-1, lastDim(a)].
 *   - `b` is transposed so its second-to-last axis comes first, then
 *     flattened to [secondLastDim(b), -1].
 *   - The 2-D product is reshaped to aShape[:-1] ++ (bShape without its
 *     second-to-last dim).
 *
 * @param a tensor of rank >= 2.
 * @param b tensor of rank >= 2; when rank(b) >= 3 its second-to-last dim must
 *   equal the last dim of `a`.
 * @param activation2 fused activation forwarded to tfc.fused.matMul.
 * @param bias optional bias, passed through reshapeBias before fusing.
 * @throws NotImplementedError for rank < 2 inputs or mismatched inner dims.
 */
function dot(a, b, activation2, bias) {
if (a.rank < 2 || b.rank < 2) {
throw new NotImplementedError("dot requires both inputs to be rank >= 2" + (" but got x shape = " + a.shape + " and y shape = " + b.shape));
}
if (b.rank >= 3) {
var xLastDim = a.shape.slice(-1)[0];
var ySecondLastDim = b.shape.slice(-2)[0];
if (xLastDim !== ySecondLastDim) {
throw new NotImplementedError("If rank y >= 3, then the second last dim" + (" of y must equal the last dim of x but got x shape = " + a.shape + " and ") + (" y shape = " + b.shape));
}
}
if (a.rank === 2 && b.rank === 2) {
var transposeA = false;
var transposeB = false;
return tfc.fused.matMul({
a,
b,
transposeA,
transposeB,
bias: bias ? reshapeBias(a.rank, bias, imageDataFormat()) : null,
activation: activation2
});
} else {
// Collapse every leading dim of `a` into one row dimension.
var aFirstDims = a.shape.slice();
var aLastDim = aFirstDims.pop();
a = a.reshape([-1, aLastDim]);
var bShape = b.shape.slice();
var bLastDim = bShape.pop();
var ySecondLastDim = bShape.pop();
var yOtherDims = bShape.concat([bLastDim]);
// Permutation moving axis (rank - 2) of `b` to the front and shifting the
// axes before it right by one; later axes keep their position.
var perm = Array.from({length: b.rank}, function(_, i) {
if (i === 0) {
return b.rank - 2;
} else if (i <= b.rank - 2) {
return i - 1;
}
return i;
});
b = b.transpose(perm).reshape([ySecondLastDim, -1]);
var outputShape = aFirstDims.concat(yOtherDims);
var transposeA = false;
var transposeB = false;
// NOTE(review): `a` was reshaped to 2-D above, so `a.rank` is 2 here and
// reshapeBias returns the bias unchanged; it is fused into the 2-D product
// before the final reshape. Confirm this is the intended bias layout.
return tfc.fused.matMul({
a,
b,
transposeA,
transposeB,
bias: bias ? reshapeBias(a.rank, bias, imageDataFormat()) : null,
activation: activation2
}).reshape(outputShape);
}
}
/**
 * Gathers slices of `reference` at `indices` along `axis`.
 * An index array becomes an int32 1-D tensor; a tensor is converted with toInt().
 */
function gather(reference, indices, axis) {
  return tfc.tidy(() => {
    const intIndices = Array.isArray(indices) ? tfc.tensor1d(indices, "int32") : indices.toInt();
    return tfc.gather(reference, intIndices, axis);
  });
}
/** Element-wise square of `x`, computed as x * x. */
function square(x) {
  var product = tfc.mul(x, x);
  return product;
}
/**
 * Reshapes `bias` so it broadcasts against an input of rank `xRank` in the
 * given data format.
 * - xRank < 3: `bias` is returned unchanged.
 * - xRank 3-5: a rank-1 bias is placed on the channel axis (axis 1 for
 *   "channelsFirst", last axis for "channelsLast").
 *
 * @throws ValueError when the bias rank is neither 1 nor xRank, or when the
 *   input rank / data format combination is unsupported.
 */
function reshapeBias(xRank, bias, dataFormat) {
  var biasShape = bias.shape;
  // NOTE(review): a rank-xRank bias is accepted here, but the channelsLast
  // branches then build a shape of rank xRank + 1 ([1].concat(biasShape));
  // confirm whether xRank - 1 was intended.
  if (bias.rank !== 1 && bias.rank !== xRank) {
    throw new ValueError("Unexpected bias dimensions: " + bias.rank + ("; expected it to be 1 or " + xRank));
  }
  if (xRank < 3) {
    return bias;
  }
  var isVector = biasShape.length === 1;
  if (xRank === 5) {
    if (dataFormat === "channelsFirst") {
      return isVector ? bias.reshape([1, biasShape[0], 1, 1, 1]) : bias.reshape([1, biasShape[3], biasShape[0], biasShape[1], biasShape[2]]);
    }
    if (dataFormat === "channelsLast") {
      return isVector ? bias.reshape([1, 1, 1, 1, biasShape[0]]) : bias.reshape([1].concat(biasShape));
    }
  } else if (xRank === 4) {
    if (dataFormat === "channelsFirst") {
      return isVector ? bias.reshape([1, biasShape[0], 1, 1]) : bias.reshape([1, biasShape[2], biasShape[0], biasShape[1]]);
    }
    if (dataFormat === "channelsLast") {
      return isVector ? bias.reshape([1, 1, 1, biasShape[0]]) : bias.reshape([1].concat(biasShape));
    }
  } else if (xRank === 3) {
    if (dataFormat === "channelsFirst") {
      return isVector ? bias.reshape([1, biasShape[0], 1]) : bias.reshape([1, biasShape[1], biasShape[0]]);
    }
    if (dataFormat === "channelsLast") {
      return isVector ? bias.reshape([1, 1, biasShape[0]]) : bias.reshape([1].concat(biasShape));
    }
  }
  // Report the unsupported *input* rank (previously the bias rank was shown,
  // which was always 1 or xRank and therefore misleading).
  throw new ValueError("Unsupported input rank by biasAdd: " + xRank);
}
/**
 * Adds `bias` to `x` after reshaping it for `dataFormat` (defaults to
 * imageDataFormat() when null/undefined).
 * @throws ValueError for an invalid data format.
 */
function biasAdd(x, bias, dataFormat) {
  return tfc.tidy(() => {
    const format = dataFormat == null ? imageDataFormat() : dataFormat;
    checkDataFormat(format);
    return x.add(reshapeBias(x.rank, bias, format));
  });
}
/**
 * Exponential linear unit. Only alpha === 1 (the default) is supported.
 * @throws NotImplementedError for any other alpha.
 */
function elu(x, alpha) {
  var effectiveAlpha = alpha === void 0 ? 1 : alpha;
  if (effectiveAlpha !== 1) {
    throw new NotImplementedError("Support for alpha values other than 1 (" + effectiveAlpha + ") is not implemented yet.");
  }
  return tfc.elu(x);
}
/** Softsign activation: x / (|x| + 1). */
function softsign(x) {
  return tfc.tidy(() => {
    const denominator = tfc.abs(x).add(1);
    return tfc.div(x, denominator);
  });
}
/** Applies tfc.dropout with the given rate, noise shape and seed inside a tidy scope. */
function dropout(x, level, noiseShape, seed) {
  return tfc.tidy(() => tfc.dropout(x, level, noiseShape, seed));
}
/** Piecewise-linear sigmoid approximation: clip(0.2 * x + 0.5, 0, 1). */
function hardSigmoid(x) {
  return tfc.tidy(() => {
    const scaled = tfc.mul(0.2, x);
    const shifted = tfc.add(0.5, scaled);
    return tfc.clipByValue(shifted, 0, 1);
  });
}
/**
 * Calls `x()` when `training` is truthy, otherwise `alt()`.
 * `training` defaults to false.
 */
function inTrainPhase(x, alt, training) {
  if (training === void 0) {
    training = false;
  }
  if (training) {
    return x();
  }
  return alt();
}
/**
* @license
* Copyright 2018 Google LLC
*
* Use of this source code is governed by an MIT-style
* license that can be found in the LICENSE file or at
* https://opensource.org/licenses/MIT.
* =============================================================================
*/
// Allowed values for the initializer fan-mode and distribution options.
var VALID_FAN_MODE_VALUES = ["fanIn", "fanOut", "fanAvg"];
var VALID_DISTRIBUTION_VALUES = ["normal", "uniform", "truncatedNormal"];
/**
* @license
* Copyright 2018 Google LLC
*
* Use of this source code is governed by an MIT-style
* license that can be found in the LICENSE file or at
* https://opensource.org/licenses/MIT.
* =============================================================================
*/
/** Throws ValueError unless `value` is null/undefined or a valid FanMode. */
function checkFanMode(value) {
  var allowed = VALID_FAN_MODE_VALUES;
  checkStringTypeUnionValue(allowed, "FanMode", value);
}
/** Throws ValueError unless `value` is null/undefined or a valid Distribution. */
function checkDistribution(value) {
  var allowed = VALID_DISTRIBUTION_VALUES;
  checkStringTypeUnionValue(allowed, "Distribution", value);
}
/**
 * Base class for weight initializers (subclasses below define
 * `apply(shape, dtype)`). By default it carries an empty config and reports
 * that fromConfig does not use custom objects.
 */
var Initializer = (function(_super) {
  __extends(Initializer2, _super);
  function Initializer2() {
    return (_super !== null && _super.apply(this, arguments)) || this;
  }
  Initializer2.prototype.fromConfigUsesCustomObjects = function() {
    var usesCustomObjects = false;
    return usesCustomObjects;
  };
  Initializer2.prototype.getConfig = function() {
    var emptyConfig = {};
    return emptyConfig;
  };
  return Initializer2;
})(tfc.serialization.Serializable);
/** Initializer producing an all-zeros tensor of the requested shape and dtype. */
var Zeros = (function(_super) {
  __extends(Zeros2, _super);
  function Zeros2() {
    return (_super !== null && _super.apply(this, arguments)) || this;
  }
  Zeros2.prototype.apply = function(shape, dtype) {
    var filled = tfc.zeros(shape, dtype);
    return filled;
  };
  Zeros2.className = "Zeros";
  return Zeros2;
})(Initializer);
tfc.serialization.registerClass(Zeros);
/** Initializer producing an all-ones tensor of the requested shape and dtype. */
var Ones = (function(_super) {
  __extends(Ones2, _super);
  function Ones2() {
    return (_super !== null && _super.apply(this, arguments)) || this;
  }
  Ones2.prototype.apply = function(shape, dtype) {
    var filled = tfc.ones(shape, dtype);
    return filled;
  };
  Ones2.className = "Ones";
  return Ones2;
})(Initializer);
tfc.serialization.registerClass(Ones);
/**
 * Initializer filling tensors with a single constant `value`.
 * @param args object with a required `value` field.
 * @throws ValueError when `args` is not a non-null object or lacks `value`.
 */
var Constant = (function(_super) {
  __extends(Constant2, _super);
  function Constant2(args) {
    var _this = _super.call(this) || this;
    // typeof null === "object", so reject null explicitly; otherwise the
    // `args.value` access below would throw a TypeError instead of ValueError.
    if (args === null || typeof args !== "object") {
      throw new ValueError("Expected argument of type ConstantConfig but got " + args);
    }
    if (args.value === void 0) {
      throw new ValueError("config must have value set but got " + args);
    }
    _this.value = args.value;
    return _this;
  }
  Constant2.prototype.apply = function(shape, dtype) {
    var _this = this;
    return tfc.tidy(function() {
      return tfc.mul(tfc.scalar(_this.value), tfc.ones(shape, dtype));
    });
  };
  Constant2.prototype.getConfig = function() {
    return {
      value: this.value
    };
  };
  Constant2.className = "Constant";
  return Constant2;
})(Initializer);
tfc.serialization.registerClass(Constant);
/**
 * Initializer drawing values uniformly between `minval` and `maxval`
 * (defaults -0.05 and 0.05) via tfc.randomUniform, honoring `seed` when set.
 */
var RandomUniform = (function(_super) {
  __extends(RandomUniform2, _super);
  function RandomUniform2(args) {
    var _this = _super.call(this) || this;
    _this.DEFAULT_MINVAL = -0.05;
    _this.DEFAULT_MAXVAL = 0.05;
    // Use `!= null` instead of `||` so that an explicit 0 bound
    // (e.g. {minval: 0, maxval: 1}) is kept rather than replaced by the default.
    _this.minval = args.minval != null ? args.minval : _this.DEFAULT_MINVAL;
    _this.maxval = args.maxval != null ? args.maxval : _this.DEFAULT_MAXVAL;
    _this.seed = args.seed;
    return _this;
  }
  RandomUniform2.prototype.apply = function(shape, dtype) {
    // Forward the configured seed, as RandomNormal and TruncatedNormal do;
    // it was previously stored and serialized but never used.
    return tfc.randomUniform(shape, this.minval, this.maxval, dtype, this.seed);
  };
  RandomUniform2.prototype.getConfig = function() {
    return {minval: this.minval, maxval: this.maxval, seed: this.seed};
  };
  RandomUniform2.className = "RandomUniform";
  return RandomUniform2;
})(Initializer);
tfc.serialization.registerClass(RandomUniform);
/**
 * Initializer drawing from a normal distribution with `mean` (default 0) and
 * `stddev` (default 0.05). Only "float32" and "int32" dtypes are supported;
 * dtype defaults to "float32".
 */
var RandomNormal = (function(_super) {
  __extends(RandomNormal2, _super);
  function RandomNormal2(args) {
    var _this = _super.call(this) || this;
    _this.DEFAULT_MEAN = 0;
    _this.DEFAULT_STDDEV = 0.05;
    // `!= null` keeps explicit zero values (e.g. stddev: 0) instead of
    // silently substituting the defaults as `||` did.
    _this.mean = args.mean != null ? args.mean : _this.DEFAULT_MEAN;
    _this.stddev = args.stddev != null ? args.stddev : _this.DEFAULT_STDDEV;
    _this.seed = args.seed;
    return _this;
  }
  RandomNormal2.prototype.apply = function(shape, dtype) {
    dtype = dtype || "float32";
    if (dtype !== "float32" && dtype !== "int32") {
      throw new NotImplementedError("randomNormal does not support dType " + dtype + ".");
    }
    return randomNormal(shape, this.mean, this.stddev, dtype, this.seed);
  };
  RandomNormal2.prototype.getConfig = function() {
    return {mean: this.mean, stddev: this.stddev, seed: this.seed};
  };
  RandomNormal2.className = "RandomNormal";
  return RandomNormal2;
})(Initializer);
tfc.serialization.registerClass(RandomNormal);
/**
 * Initializer drawing from tfc.truncatedNormal with `mean` (default 0) and
 * `stddev` (default 0.05). Only "float32" and "int32" dtypes are supported;
 * dtype defaults to "float32".
 */
var TruncatedNormal = (function(_super) {
  __extends(TruncatedNormal2, _super);
  function TruncatedNormal2(args) {
    var _this = _super.call(this) || this;
    _this.DEFAULT_MEAN = 0;
    _this.DEFAULT_STDDEV = 0.05;
    // `!= null` keeps explicit zero values instead of replacing them with the
    // defaults (consistent with the constraint constructors in this file).
    _this.mean = args.mean != null ? args.mean : _this.DEFAULT_MEAN;
    _this.stddev = args.stddev != null ? args.stddev : _this.DEFAULT_STDDEV;
    _this.seed = args.seed;
    return _this;
  }
  TruncatedNormal2.prototype.apply = function(shape, dtype) {
    dtype = dtype || "float32";
    if (dtype !== "float32" && dtype !== "int32") {
      throw new NotImplementedError("truncatedNormal does not support dType " + dtype + ".");
    }
    return tfc.truncatedNormal(shape, this.mean, this.stddev, dtype, this.seed);
  };
  TruncatedNormal2.prototype.getConfig = function() {
    return {mean: this.mean, stddev: this.stddev, seed: this.seed};
  };
  TruncatedNormal2.className = "TruncatedNormal";
  return TruncatedNormal2;
})(Initializer);
tfc.serialization.registerClass(TruncatedNormal);
/**
 * Initializer producing `gain * I` for a square 2D shape.
 * `gain` defaults to 1 when null/undefined. The requested dtype is not used.
 */
var Identity = function(_super) {
  __extends(Identity2, _super);
  function Identity2(args) {
    const self = _super.call(this) || this;
    const requestedGain = args.gain;
    self.gain = requestedGain == null ? 1 : requestedGain;
    return self;
  }
  /**
   * @param shape must be [n, n].
   * @throws ValueError for any non-square or non-2D shape.
   */
  Identity2.prototype.apply = function(shape, dtype) {
    const gain = this.gain;
    return tfc.tidy(() => {
      const isSquareMatrix = shape.length === 2 && shape[0] === shape[1];
      if (!isSquareMatrix) {
        throw new ValueError("Identity matrix initializer can only be used for 2D square matrices.");
      }
      return tfc.mul(gain, tfc.eye(shape[0]));
    });
  };
  Identity2.prototype.getConfig = function() {
    return {gain: this.gain};
  };
  Identity2.className = "Identity";
  return Identity2;
}(Initializer);
tfc.serialization.registerClass(Identity);
/**
 * Computes `[fanIn, fanOut]` for a weight shape.
 *
 * - rank 2: `[shape[0], shape[1]]`.
 * - rank 3/4/5: the two channel dims are scaled by a receptive-field size from arrayProd.
 *   For "channelsFirst" the channels are shape[1] (in) and shape[0] (out), with
 *   arrayProd(shape, 2). For "channelsLast" they are the last two dims, with
 *   arrayProd(shape, 0, rank - 2).
 * - any other rank: both fans are sqrt(arrayProd(shape)).
 *
 * @param {number[]} shape weight shape.
 * @param {string} [dataFormat="channelsLast"] validated with checkDataFormat.
 * @returns {number[]} `[fanIn, fanOut]`.
 */
function computeFans(shape, dataFormat) {
  if (dataFormat === undefined) {
    dataFormat = "channelsLast";
  }
  checkDataFormat(dataFormat);
  const rank = shape.length;
  if (rank === 2) {
    return [shape[0], shape[1]];
  }
  if (rank < 3 || rank > 5) {
    const side = Math.sqrt(arrayProd(shape));
    return [side, side];
  }
  if (dataFormat === "channelsFirst") {
    const fieldFirst = arrayProd(shape, 2);
    return [shape[1] * fieldFirst, shape[0] * fieldFirst];
  }
  if (dataFormat === "channelsLast") {
    const fieldLast = arrayProd(shape, 0, rank - 2);
    return [shape[rank - 2] * fieldLast, shape[rank - 1] * fieldLast];
  }
  return [undefined, undefined];
}
/**
 * Initializer whose spread adapts to the fan-in/fan-out of the weight shape.
 *
 * Fans come from computeFans(shape) using its default data format. The configured
 * `scale` is divided by max(1, n):
 * - n = fanIn for mode "fanIn";
 * - n = fanOut for mode "fanOut";
 * - n = (fanIn + fanOut) / 2 for any other mode (e.g. "fanAvg").
 *
 * Then:
 * - distribution "normal": tfc.truncatedNormal with mean 0 and stddev sqrt(scale);
 * - otherwise: tfc.randomUniform between -limit and limit, with limit = sqrt(3 * scale).
 *
 * Config: `scale` (default 1, negative values rejected), `mode` (default "fanIn",
 * checked by checkFanMode), `distribution` (default "normal", checked by
 * checkDistribution), `seed` (optional, used by both distributions).
 */
var VarianceScaling = function(_super) {
  __extends(VarianceScaling2, _super);
  function VarianceScaling2(args) {
    var _this = _super.call(this) || this;
    if (args.scale < 0) {
      throw new ValueError("scale must be a positive float. Got: " + args.scale);
    }
    _this.scale = args.scale == null ? 1 : args.scale;
    _this.mode = args.mode == null ? "fanIn" : args.mode;
    checkFanMode(_this.mode);
    _this.distribution = args.distribution == null ? "normal" : args.distribution;
    checkDistribution(_this.distribution);
    _this.seed = args.seed;
    return _this;
  }
  /**
   * @param shape weight shape.
   * @param dtype for "normal" only "float32" (default) or "int32" are accepted.
   * @throws NotImplementedError for an unsupported dtype with the normal distribution.
   */
  VarianceScaling2.prototype.apply = function(shape, dtype) {
    var fans = computeFans(shape);
    var fanIn = fans[0];
    var fanOut = fans[1];
    var scale = this.scale;
    if (this.mode === "fanIn") {
      scale /= Math.max(1, fanIn);
    } else if (this.mode === "fanOut") {
      scale /= Math.max(1, fanOut);
    } else {
      scale /= Math.max(1, (fanIn + fanOut) / 2);
    }
    if (this.distribution === "normal") {
      var stddev = Math.sqrt(scale);
      dtype = dtype || "float32";
      if (dtype !== "float32" && dtype !== "int32") {
        throw new NotImplementedError(this.getClassName() + " does not support dType " + dtype + ".");
      }
      return tfc.truncatedNormal(shape, 0, stddev, dtype, this.seed);
    }
    var limit = Math.sqrt(3 * scale);
    // Forward the seed, as the normal branch already does. Without it, seeded
    // uniform presets (GlorotUniform, HeUniform, LeCunUniform) were not reproducible.
    return tfc.randomUniform(shape, -limit, limit, dtype, this.seed);
  };
  VarianceScaling2.prototype.getConfig = function() {
    return {
      scale: this.scale,
      mode: this.mode,
      distribution: this.distribution,
      seed: this.seed
    };
  };
  VarianceScaling2.className = "VarianceScaling";
  return VarianceScaling2;
}(Initializer);
tfc.serialization.registerClass(VarianceScaling);
/**
 * VarianceScaling preset: scale 1, mode "fanAvg", uniform distribution.
 * `args` may be omitted; only `args.seed` is read.
 */
var GlorotUniform = function(_super) {
  __extends(GlorotUniform2, _super);
  function GlorotUniform2(args) {
    const seed = args == null ? null : args.seed;
    const presetConfig = {scale: 1, mode: "fanAvg", distribution: "uniform", seed: seed};
    return _super.call(this, presetConfig) || this;
  }
  // Reports VarianceScaling's class name rather than its own.
  GlorotUniform2.prototype.getClassName = function() {
    return VarianceScaling.className;
  };
  GlorotUniform2.className = "GlorotUniform";
  return GlorotUniform2;
}(VarianceScaling);
tfc.serialization.registerClass(GlorotUniform);
/**
 * VarianceScaling preset: scale 1, mode "fanAvg", (truncated) normal distribution.
 * `args` may be omitted; only `args.seed` is read.
 */
var GlorotNormal = function(_super) {
  __extends(GlorotNormal2, _super);
  function GlorotNormal2(args) {
    const seed = args == null ? null : args.seed;
    const presetConfig = {scale: 1, mode: "fanAvg", distribution: "normal", seed: seed};
    return _super.call(this, presetConfig) || this;
  }
  // Reports VarianceScaling's class name rather than its own.
  GlorotNormal2.prototype.getClassName = function() {
    return VarianceScaling.className;
  };
  GlorotNormal2.className = "GlorotNormal";
  return GlorotNormal2;
}(VarianceScaling);
tfc.serialization.registerClass(GlorotNormal);
/**
 * VarianceScaling preset: scale 2, mode "fanIn", (truncated) normal distribution.
 * `args` may be omitted; only `args.seed` is read.
 */
var HeNormal = function(_super) {
  __extends(HeNormal2, _super);
  function HeNormal2(args) {
    const seed = args == null ? null : args.seed;
    const presetConfig = {scale: 2, mode: "fanIn", distribution: "normal", seed: seed};
    return _super.call(this, presetConfig) || this;
  }
  // Reports VarianceScaling's class name rather than its own.
  HeNormal2.prototype.getClassName = function() {
    return VarianceScaling.className;
  };
  HeNormal2.className = "HeNormal";
  return HeNormal2;
}(VarianceScaling);
tfc.serialization.registerClass(HeNormal);
/**
 * VarianceScaling preset: scale 2, mode "fanIn", uniform distribution.
 * `args` may be omitted; only `args.seed` is read.
 */
var HeUniform = function(_super) {
  __extends(HeUniform2, _super);
  function HeUniform2(args) {
    const seed = args == null ? null : args.seed;
    const presetConfig = {scale: 2, mode: "fanIn", distribution: "uniform", seed: seed};
    return _super.call(this, presetConfig) || this;
  }
  // Reports VarianceScaling's class name rather than its own.
  HeUniform2.prototype.getClassName = function() {
    return VarianceScaling.className;
  };
  HeUniform2.className = "HeUniform";
  return HeUniform2;
}(VarianceScaling);
tfc.serialization.registerClass(HeUniform);
/**
 * VarianceScaling preset: scale 1, mode "fanIn", (truncated) normal distribution.
 * `args` may be omitted; only `args.seed` is read.
 */
var LeCunNormal = function(_super) {
  __extends(LeCunNormal2, _super);
  function LeCunNormal2(args) {
    const seed = args == null ? null : args.seed;
    const presetConfig = {scale: 1, mode: "fanIn", distribution: "normal", seed: seed};
    return _super.call(this, presetConfig) || this;
  }
  // Reports VarianceScaling's class name rather than its own.
  LeCunNormal2.prototype.getClassName = function() {
    return VarianceScaling.className;
  };
  LeCunNormal2.className = "LeCunNormal";
  return LeCunNormal2;
}(VarianceScaling);
tfc.serialization.registerClass(LeCunNormal);
/**
 * VarianceScaling preset: scale 1, mode "fanIn", uniform distribution.
 * `args` may be omitted; only `args.seed` is read.
 */
var LeCunUniform = function(_super) {
  __extends(LeCunUniform2, _super);
  function LeCunUniform2(args) {
    return _super.call(this, {
      scale: 1,
      mode: "fanIn",
      distribution: "uniform",
      seed: args == null ? null : args.seed
    }) || this;
  }
  // Reports VarianceScaling's class name rather than its own.
  LeCunUniform2.prototype.getClassName = function() {
    return VarianceScaling.className;
  };
  // This was "LeCunNormal" (a copy-paste slip). tfc.serialization.registerClass keys
  // its registry by the static className, so the old value overwrote LeCunNormal's
  // entry and left "LeCunUniform" unregistered.
  LeCunUniform2.className = "LeCunUniform";
  return LeCunUniform2;
}(VarianceScaling);
tfc.serialization.registerClass(LeCunUniform);
/**
 * Initializer that orthogonalizes a standard-normal sample with
 * tfc.linalg.gramSchmidt and scales it by `gain` (default 1).
 * A non-null `seed` is rejected because seeding is not implemented.
 */
var Orthogonal = function(_super) {
  __extends(Orthogonal2, _super);
  function Orthogonal2(args) {
    const self = _super.call(this) || this;
    self.DEFAULT_GAIN = 1;
    self.gain = args.gain != null ? args.gain : self.DEFAULT_GAIN;
    self.seed = args.seed;
    if (self.seed != null) {
      throw new NotImplementedError("Random seed is not implemented for Orthogonal Initializer yet.");
    }
    return self;
  }
  /**
   * @param shape at least 2D. Only shape[0] and shape[1] drive the size warning and
   *   the orientation choice.
   *   NOTE(review): for rank > 2 the full shape is handed to gramSchmidt; confirm that
   *   is supported.
   * @param dtype unused; the sample is always drawn as "float32".
   * @throws NotImplementedError for shapes of rank < 2.
   */
  Orthogonal2.prototype.apply = function(shape, dtype) {
    const gain = this.gain;
    return tfc.tidy(() => {
      if (shape.length < 2) {
        throw new NotImplementedError("Shape must be at least 2D.");
      }
      const rows = shape[0];
      const cols = shape[1];
      const elementCount = rows * cols;
      if (elementCount > 2e3) {
        console.warn("Orthogonal initializer is being called on a matrix with more than 2000 (" + elementCount + ") elements: Slowness may result.");
      }
      // For tall matrices, orthogonalize the wide (transposed) orientation, then transpose back.
      const isTall = rows > cols;
      const sample = randomNormal(isTall ? [cols, rows] : shape, 0, 1, "float32");
      const q = tfc.linalg.gramSchmidt(sample);
      return tfc.mul(gain, isTall ? q.transpose() : q);
    });
  };
  Orthogonal2.prototype.getConfig = function() {
    return {
      gain: this.gain,
      seed: this.seed
    };
  };
  Orthogonal2.className = "Orthogonal";
  return Orthogonal2;
}(Initializer);
tfc.serialization.registerClass(Orthogonal);
// Maps the camelCase string identifiers accepted by getInitializer (e.g. "glorotUniform")
// to the class names used in the serialization registry (e.g. "GlorotUniform").
// Strings not listed here are passed through to getInitializer unchanged.
var INITIALIZER_IDENTIFIER_REGISTRY_SYMBOL_MAP = {
constant: "Constant",
glorotNormal: "GlorotNormal",
glorotUniform: "GlorotUniform",
heNormal: "HeNormal",
heUniform: "HeUniform",
identity: "Identity",
leCunNormal: "LeCunNormal",
leCunUniform: "LeCunUniform",
ones: "Ones",
orthogonal: "Orthogonal",
randomNormal: "RandomNormal",
randomUniform: "RandomUniform",
truncatedNormal: "TruncatedNormal",
varianceScaling: "VarianceScaling",
zeros: "Zeros"
};
/**
 * Rebuilds an initializer from a serialized `{className, config}` object.
 * Registered classes are looked up in tfc's SerializationMap.
 *
 * @param config serialized initializer.
 * @param [customObjects={}] extra name -> class entries forwarded to deserializeKerasObject.
 */
function deserializeInitializer(config, customObjects) {
  const extraObjects = customObjects === undefined ? {} : customObjects;
  const registeredClasses = tfc.serialization.SerializationMap.getMap().classNameMap;
  return deserializeKerasObject(config, registeredClasses, extraObjects, "initializer");
}
/**
 * Serializes an initializer by delegating to serializeKerasObject.
 * @param initializer the Initializer instance to serialize.
 */
function serializeInitializer(initializer) {
return serializeKerasObject(initializer);
}
/**
 * Resolves an initializer identifier to an Initializer instance.
 *
 * @param identifier one of:
 *   - a string: a camelCase alias from INITIALIZER_IDENTIFIER_REGISTRY_SYMBOL_MAP
 *     (e.g. "glorotUniform") or a class name (e.g. "GlorotUniform");
 *   - an Initializer instance, which is returned unchanged;
 *   - anything else, treated as a serialized config for deserializeInitializer.
 * @returns {Initializer}
 */
function getInitializer(identifier) {
  if (typeof identifier === "string") {
    // Own-property lookup only. A plain `in` test also matches inherited keys such as
    // "toString", which would turn a Function into the class name.
    const className = Object.prototype.hasOwnProperty.call(INITIALIZER_IDENTIFIER_REGISTRY_SYMBOL_MAP, identifier) ? INITIALIZER_IDENTIFIER_REGISTRY_SYMBOL_MAP[identifier] : identifier;
    // The VarianceScaling presets are constructed directly; they accept a missing `args`.
    switch (className) {
      case "GlorotNormal":
        return new GlorotNormal();
      case "GlorotUniform":
        return new GlorotUniform();
      case "HeNormal":
        return new HeNormal();
      case "HeUniform":
        return new HeUniform();
      case "LeCunNormal":
        return new LeCunNormal();
      case "LeCunUniform":
        return new LeCunUniform();
      default:
        return deserializeInitializer({className: className, config: {}});
    }
  }
  if (identifier instanceof Initializer) {
    return identifier;
  }
  return deserializeInitializer(identifier);
}
/**
* @license
* Copyright 2018 Google LLC
*
* Use of this source code is governed by an MIT-style
* license that can be found in the LICENSE file or at
* https://opensource.org/licenses/MIT.
* =============================================================================
*/
function zeros() {
return new Zeros();
}
function ones() {
return new Ones();
}
function constant(args) {
return new Constant(args);
}
function randomUniform(args) {
return new RandomUniform(args);
}
function randomNormal$1(args) {
return new RandomNormal(args);
}
function truncatedNormal(args) {
return new TruncatedNormal(args);
}
function identity(args) {
return new Identity(args);
}
function varianceScaling(config) {
return new VarianceScaling(config);
}
function glorotUniform(args) {
return new GlorotUniform(args);
}
function glorotNormal(args) {
return new GlorotNormal(args);
}
function heNormal(args) {
return new HeNormal(args);
}
function heUniform(args) {
return new HeUniform(args);
}
function leCunNormal(args) {
return new LeCunNormal(args);
}
function leCunUniform(args) {
return new LeCunUniform(args);
}
function orthogonal(args) {
return new Orthogonal(args);
}
var exports_initializers = {
__proto__: null,
zeros,
ones,
constant,
randomUniform,
randomNormal: randomNormal$1,
truncatedNormal,
identity,
varianceScaling,
glorotUniform,
glorotNormal,
heNormal,
heUniform,
leCunNormal,
leCunUniform,
orthogonal
};
/**
* @license
* Copyright 2018 Google LLC
*
* Use of this source code is governed by an MIT-style
* license that can be found in the LICENSE file or at
* https://opensource.org/licenses/MIT.
* =============================================================================
*/
var _nextUniqueTensorId = 0;
/**
 * Returns the next value of a module-wide counter (0, 1, 2, ...).
 * Used as an id for tensors and variables.
 */
function getNextUniqueTensorId() {
  return _nextUniqueTensorId++;
}
// Per-prefix counters. The object has a null prototype, so every string (including
// "toString", "constructor" and "__proto__") is an ordinary own key. A plain `{}` with
// an `in` test picked up Object.prototype members and produced garbage ids.
var _uidPrefixes = Object.create(null);
/**
 * Returns `prefix` followed by a per-prefix counter that starts at 1,
 * e.g. getUid("dense") -> "dense1", then "dense2".
 * @param {string} [prefix=""]
 * @returns {string}
 */
function getUid(prefix) {
  if (prefix === undefined) {
    prefix = "";
  }
  if (!(prefix in _uidPrefixes)) {
    _uidPrefixes[prefix] = 0;
  }
  _uidPrefixes[prefix] += 1;
  return prefix + _uidPrefixes[prefix].toString();
}
/**
* @license
* Copyright 2018 Google LLC
*
* Use of this source code is governed by an MIT-style
* license that can be found in the LICENSE file or at
* https://opensource.org/licenses/MIT.
* =============================================================================
*/
/**
 * True when `x` is an array whose first element is itself an array,
 * i.e. a list of shapes rather than a single shape.
 */
function isArrayOfShapes(x) {
  if (!Array.isArray(x)) {
    return false;
  }
  return Array.isArray(x[0]);
}
/**
 * Normalizes a shape or a list of shapes into a list of shapes.
 * An empty input yields a new empty array. A single shape is wrapped in an array.
 * A list of shapes is returned as-is (same reference).
 */
function normalizeShapeList(x) {
  if (x.length === 0) {
    return [];
  }
  return Array.isArray(x[0]) ? x : [x];
}
/**
 * Unwraps a single-element array, or returns a non-array value unchanged.
 * @throws ValueError when given an array whose length is not 1.
 */
function getExactlyOneTensor(xs) {
  if (!Array.isArray(xs)) {
    return xs;
  }
  if (xs.length !== 1) {
    throw new ValueError("Expected Tensor length to be 1; got " + xs.length);
  }
  return xs[0];
}
/**
 * Returns the only shape of a one-element list of shapes. Anything that is not a
 * list of shapes (e.g. a single shape) is returned unchanged.
 * @throws ValueError for a list of shapes whose length is not 1.
 */
function getExactlyOneShape(shapes) {
  const isShapeList = Array.isArray(shapes) && Array.isArray(shapes[0]);
  if (!isShapeList) {
    return shapes;
  }
  if (shapes.length !== 1) {
    throw new ValueError("Expected exactly 1 Shape; got " + shapes.length);
  }
  return shapes[0];
}
/**
* @license
* Copyright 2018 Google LLC
*
* Use of this source code is governed by an MIT-style
* license that can be found in the LICENSE file or at
* https://opensource.org/licenses/MIT.
* =============================================================================
*/
/**
 * Total number of scalar parameters in `weights`. Each weight contributes the
 * product of its shape, and a rank-0 (empty) shape counts as 1.
 * @param weights array of objects with a `shape` array.
 * @returns {number}
 */
function countParamsInWeights(weights) {
  const multiply = (a, b) => a * b;
  return weights.reduce((total, weight) => {
    const dims = weight.shape;
    const size = dims.length === 0 ? 1 : dims.reduce(multiply);
    return total + size;
  }, 0);
}
/**
* @license
* Copyright 2018 Google LLC
*
* Use of this source code is governed by an MIT-style
* license that can be found in the LICENSE file or at
* https://opensource.org/licenses/MIT.
* =============================================================================
*/
var DEFAULT_VARIABLE_NAME_PREFIX = "Variable";
/**
 * A layer weight. Wraps a tfc variable together with a scoped, unique name,
 * a trainable flag and an optional constraint that is applied on every write.
 */
var LayerVariable = function() {
  /**
   * @param val initial value; its `shape` is recorded and it is handed to tfc.variable.
   * @param {string} [dtype="float32"] null also maps to "float32".
   * @param {string} [name="Variable"] null also maps to the default. The name is passed
   *   through getScopedTensorName and then getUniqueTensorName.
   * @param {boolean} [trainable=true] only `undefined` selects the default; null is kept.
   * @param [constraint=null] object with an `apply(tensor)` method.
   */
  function LayerVariable2(val, dtype, name, trainable, constraint) {
    this.dtype = dtype == null ? "float32" : dtype;
    this.shape = val.shape;
    this.id = getNextUniqueTensorId();
    const baseName = name == null ? DEFAULT_VARIABLE_NAME_PREFIX : name;
    this.originalName = getScopedTensorName(baseName);
    this.name = getUniqueTensorName(this.originalName);
    this.trainable_ = trainable === undefined ? true : trainable;
    this.constraint = constraint === undefined ? null : constraint;
    this.val = tfc.variable(val, this.trainable_, this.name, this.dtype);
  }
  /** Returns the underlying tfc variable; throws if it has been disposed. */
  LayerVariable2.prototype.read = function() {
    this.assertNotDisposed();
    return this.val;
  };
  /**
   * Assigns `newVal` (its shape must match), then re-assigns the constrained value when a
   * constraint is set. Writing the variable onto itself (same id) is skipped.
   * @returns {LayerVariable} this, for chaining.
   */
  LayerVariable2.prototype.write = function(newVal) {
    this.assertNotDisposed();
    checkShapesMatch(this.val, newVal);
    if (this.val.id === newVal.id) {
      return this;
    }
    this.val.assign(newVal);
    if (this.constraint != null) {
      // NOTE(review): the tensor returned by constraint.apply is not disposed here.
      // Callers outside a tfc.tidy scope may leak it; confirm.
      this.val.assign(this.constraint.apply(this.val));
    }
    return this;
  };
  /** Disposes the underlying variable; throws if it was already disposed. */
  LayerVariable2.prototype.dispose = function() {
    this.assertNotDisposed();
    this.val.dispose();
  };
  /** Throws an Error when the underlying variable reports `isDisposed`. */
  LayerVariable2.prototype.assertNotDisposed = function() {
    if (!this.val.isDisposed) {
      return;
    }
    throw new Error("LayersVariable " + this.name + " is already disposed.");
  };
  // Keeps the wrapper's flag and the tfc variable's `trainable` in sync.
  Object.defineProperty(LayerVariable2.prototype, "trainable", {
    get() {
      return this.trainable_;
    },
    set(isTrainable) {
      this.trainable_ = isTrainable;
      this.val.trainable = isTrainable;
    },
    enumerable: true,
    configurable: true
  });
  return LayerVariable2;
}();
/**
 * Throws an Error unless `x.shape` and `y.shape` have the same string form.
 */
function checkShapesMatch(x, y) {
  const leftKey = x.shape.toString();
  const rightKey = y.shape.toString();
  if (leftKey === rightKey) {
    return;
  }
  throw new Error("Shape mismatch: " + JSON.stringify(x.shape) + " vs. " + JSON.stringify(y.shape));
}
/** Returns `read()` of every variable in `xs`, in order. */
function batchGetValue(xs) {
  return xs.map((variable) => variable.read());
}
/**
 * Writes each `[variable, value]` pair in order via `variable.write(value)`.
 */
function batchSetValue(variablesAndValues) {
  variablesAndValues.forEach((pair) => {
    pair[0].write(pair[1]);
  });
}
/**
* @license
* Copyright 2018 Google LLC
*
* Use of this source code is governed by an MIT-style
* license that can be found in the LICENSE file or at
* https://opensource.org/licenses/MIT.
* =============================================================================
*/
/**
 * Describes constraints on a layer input: dtype, shape, ndim, min/max ndim and
 * per-axis sizes. When `shape` is given, `ndim` is derived from its length;
 * otherwise `args.ndim` is used. `axes` defaults to `{}`.
 */
var InputSpec = function() {
  function InputSpec2(args) {
    const shape = args.shape;
    this.dtype = args.dtype;
    this.shape = shape;
    this.ndim = shape != null ? shape.length : args.ndim;
    this.maxNDim = args.maxNDim;
    this.minNDim = args.minNDim;
    this.axes = args.axes || {};
  }
  return InputSpec2;
}();
// Placeholder tensor used while wiring a layer graph: it records dtype/shape and
// provenance (source layer, its inputs, call kwargs) but holds no data.
var SymbolicTensor = function() {
/**
 * @param dtype dtype of the tensor.
 * @param shape shape array; `rank` is set to its length.
 * @param sourceLayer layer whose call produced this tensor.
 * @param inputs input tensors of that call.
 * @param callArgs kwargs of that call.
 * @param name optional; when given, it is scoped via getScopedTensorName and then made
 *   unique via getUniqueTensorName. Otherwise `name`/`originalName` stay unset.
 * @param outputTensorIndex index among the producing layer's outputs (may be undefined).
 */
function SymbolicTensor2(dtype, shape, sourceLayer, inputs, callArgs, name, outputTensorIndex) {
this.dtype = dtype;
this.shape = shape;
this.sourceLayer = sourceLayer;
this.inputs = inputs;
this.callArgs = callArgs;
this.outputTensorIndex = outputTensorIndex;
// Ids come from the same counter as LayerVariable ids (getNextUniqueTensorId).
this.id = getNextUniqueTensorId();
if (name != null) {
this.originalName = getScopedTensorName(name);
this.name = getUniqueTensorName(this.originalName);
}
this.rank = shape.length;
}
return SymbolicTensor2;
}();
var _nextNodeID = 0;
/**
 * One edge in the layer graph: records which inbound layers (and which of their nodes
 * and tensors) fed one call of `outboundLayer`.
 * Construction registers the node in each non-null inbound layer's `outboundNodes`
 * and in `outboundLayer.inboundNodes`.
 */
var Node = function() {
  /**
   * @param args node description (outboundLayer, inboundLayers, nodeIndices, tensorIndices,
   *   input/output tensors, masks and shapes).
   * @param callArgs kwargs of the call that created this node.
   */
  function Node2(args, callArgs) {
    this.callArgs = callArgs;
    this.id = _nextNodeID++;
    this.outboundLayer = args.outboundLayer;
    this.inboundLayers = args.inboundLayers;
    this.nodeIndices = args.nodeIndices;
    this.tensorIndices = args.tensorIndices;
    this.inputTensors = args.inputTensors;
    this.outputTensors = args.outputTensors;
    this.inputMasks = args.inputMasks;
    this.outputMasks = args.outputMasks;
    this.inputShapes = args.inputShapes;
    this.outputShapes = args.outputShapes;
    // Null entries in inboundLayers are skipped.
    args.inboundLayers.forEach((layer) => {
      if (layer != null) {
        layer.outboundNodes.push(this);
      }
    });
    args.outboundLayer.inboundNodes.push(this);
  }
  /**
   * Serializable summary: the outbound layer name, inbound layer names (null for
   * missing layers), nodeIndices and tensorIndices.
   */
  Node2.prototype.getConfig = function() {
    const inboundNames = Array.from(this.inboundLayers, (layer) => (layer != null ? layer.name : null));
    const outbound = this.outboundLayer;
    return {
      outboundLayer: outbound ? outbound.name : null,
      inboundLayers: inboundNames,
      nodeIndices: this.nodeIndices,
      tensorIndices: this.tensorIndices
    };
  };
  return Node2;
}();
var _nextLayerID = 0;
var Layer = function(_super) {
__extends(Layer2, _super);
/**
 * Base layer constructor.
 *
 * @param args optional config (defaults to {}). Fields read here: name, trainable,
 *   inputShape, batchInputShape, batchSize, dtype, inputDType, weights.
 */
function Layer2(args) {
if (args === void 0) {
args = {};
}
var _this = _super.call(this) || this;
_this._callHook = null;
_this._addedWeightNames = [];
_this._stateful = false;
// Ids come from a counter shared by all Layer instances.
_this.id = _nextLayerID++;
_this.activityRegularizer = null;
_this.inputSpec = null;
_this.supportsMasking = false;
_this._trainableWeights = [];
_this._nonTrainableWeights = [];
_this._losses = [];
_this._updates = [];
_this._built = false;
_this.inboundNodes = [];
_this.outboundNodes = [];
var name = args.name;
// With no (or an empty) name, generate toSnakeCase(className) + "_" + getUid(className).
// getUid appends a per-prefix counter to the class name.
if (!name) {
var prefix = _this.getClassName();
name = toSnakeCase(prefix) + "_" + getUid(prefix);
}
_this.name = name;
// A null/undefined `trainable` means trainable.
_this.trainable_ = args.trainable == null ? true : args.trainable;
// batchInputShape and dtype are only set when some input shape was supplied.
if (args.inputShape != null || args.batchInputShape != null) {
var batchInputShape = void 0;
// An explicit batchInputShape wins; otherwise prepend batchSize (or null) to inputShape.
if (args.batchInputShape != null) {
batchInputShape = args.batchInputShape;
} else if (args.inputShape != null) {
var batchSize = null;
if (args.batchSize != null) {
batchSize = args.batchSize;
}
batchInputShape = [batchSize].concat(args.inputShape);
}
_this.batchInputShape = batchInputShape;
// dtype falls back to args.inputDType, then to "float32".
var dtype = args.dtype;
if (dtype == null) {
dtype = args.inputDType;
}
if (dtype == null) {
dtype = "float32";
}
_this.dtype = dtype;
}
// Weights to load with setWeights once the layer has been built (done in apply()).
if (args.weights != null) {
_this.initialWeights = args.weights;
} else {
_this.initialWeights = null;
}
// Stays null until apply() sets/increments it.
_this._refCount = null;
// When true, addWeight substitutes a "zeros" initializer for the requested one.
_this.fastWeightInitDuringBuild = false;
return _this;
}
/** Key for inbound node `nodeIndex` of `layer`: `<layer.name>_ib-<nodeIndex>`. */
Layer2.nodeKey = function(layer, nodeIndex) {
  return `${layer.name}_ib-${nodeIndex.toString()}`;
};
/**
 * Returns this layer's inbound node at `nodeIndex`.
 * @param {number} nodeIndex
 * @param {string} attrName attribute being requested; used only in error messages.
 * @throws RuntimeError if the layer has never been called.
 * @throws ValueError if `nodeIndex` is past the last inbound node.
 */
Layer2.prototype.getNodeAtIndex = function(nodeIndex, attrName) {
  const nodes = this.inboundNodes;
  if (nodes.length === 0) {
    throw new RuntimeError("The layer has never been called and thus has no defined " + attrName + ".");
  }
  if (nodeIndex >= nodes.length) {
    throw new ValueError("Asked to get " + attrName + " at node " + nodeIndex + ", but the layer has only " + nodes.length + " inbound nodes.");
  }
  return nodes[nodeIndex];
};
/** Input tensor(s) of inbound node `nodeIndex` (unwrapped via singletonOrArray). */
Layer2.prototype.getInputAt = function(nodeIndex) {
  const node = this.getNodeAtIndex(nodeIndex, "input");
  return singletonOrArray(node.inputTensors);
};
/** Output tensor(s) of inbound node `nodeIndex` (unwrapped via singletonOrArray). */
Layer2.prototype.getOutputAt = function(nodeIndex) {
  const node = this.getNodeAtIndex(nodeIndex, "output");
  return singletonOrArray(node.outputTensors);
};
// `input`: the input tensor(s) of the layer's only inbound node.
// Throws AttributeError when the layer has zero or several inbound nodes.
Object.defineProperty(Layer2.prototype, "input", {
  get() {
    const nodeCount = this.inboundNodes.length;
    if (nodeCount > 1) {
      throw new AttributeError("Layer " + this.name + ' has multiple inbound nodes, hence the notion of "layer input" is ill-defined. Use `getInputAt(nodeIndex)` instead.');
    }
    if (nodeCount === 0) {
      throw new AttributeError("Layer " + this.name + " is not connected, no input to return.");
    }
    return singletonOrArray(this.getNodeAtIndex(0, "input").inputTensors);
  },
  enumerable: true,
  configurable: true
});
// `output`: the output tensor(s) of the layer's only inbound node.
// Throws AttributeError when the layer has zero or several inbound nodes.
Object.defineProperty(Layer2.prototype, "output", {
  get() {
    const nodeCount = this.inboundNodes.length;
    if (nodeCount === 0) {
      throw new AttributeError("Layer " + this.name + " has no inbound nodes.");
    }
    if (nodeCount > 1) {
      throw new AttributeError("Layer " + this.name + ' has multiple inbound nodes, hence the notion of "layer output" is ill-defined. Use `getOutputAt(nodeIndex)` instead.');
    }
    return singletonOrArray(this.getNodeAtIndex(0, "output").outputTensors);
  },
  enumerable: true,
  configurable: true
});
// `losses`: the live array of loss thunks registered via addLoss.
Object.defineProperty(Layer2.prototype, "losses", {
  get() {
    return this._losses;
  },
  enumerable: true,
  configurable: true
});
/** Calls every registered loss thunk (with no arguments) and returns the results in order. */
Layer2.prototype.calculateLosses = function() {
  const thunks = this.losses;
  return thunks.map((thunk) => thunk());
};
// `updates`: the live `_updates` array.
Object.defineProperty(Layer2.prototype, "updates", {
  get() {
    return this._updates;
  },
  enumerable: true,
  configurable: true
});
// `built`: plain accessor over `_built`.
Object.defineProperty(Layer2.prototype, "built", {
  get() {
    return this._built;
  },
  set(isBuilt) {
    this._built = isBuilt;
  },
  enumerable: true,
  configurable: true
});
// `trainable`: layer-level flag. Setting it also propagates the value to every weight
// in `_trainableWeights` before updating the layer's own flag.
Object.defineProperty(Layer2.prototype, "trainable", {
  get() {
    return this.trainable_;
  },
  set(value) {
    this._trainableWeights.forEach((weight) => {
      weight.trainable = value;
    });
    this.trainable_ = value;
  },
  enumerable: true,
  configurable: true
});
// `trainableWeights`: empty when the layer is not trainable; otherwise the entries of
// `_trainableWeights` whose own `trainable` flag is truthy.
Object.defineProperty(Layer2.prototype, "trainableWeights", {
  get() {
    if (!this.trainable_) {
      return [];
    }
    return this._trainableWeights.filter((weight) => weight.trainable);
  },
  set(weights) {
    this._trainableWeights = weights;
  },
  enumerable: true,
  configurable: true
});
// `nonTrainableWeights`: the frozen part of `_trainableWeights` (all of it when the layer
// is not trainable), followed by `_nonTrainableWeights`. Always a new array.
Object.defineProperty(Layer2.prototype, "nonTrainableWeights", {
  get() {
    const frozenPart = this.trainable ? this._trainableWeights.filter((weight) => !weight.trainable) : this._trainableWeights;
    return frozenPart.concat(this._nonTrainableWeights);
  },
  set(weights) {
    this._nonTrainableWeights = weights;
  },
  enumerable: true,
  configurable: true
});
// `weights`: trainableWeights followed by nonTrainableWeights.
Object.defineProperty(Layer2.prototype, "weights", {
  get() {
    const trainablePart = this.trainableWeights;
    const frozenPart = this.nonTrainableWeights;
    return trainablePart.concat(frozenPart);
  },
  enumerable: true,
  configurable: true
});
// `stateful`: read-only view of `_stateful` (false for the base class).
Object.defineProperty(Layer2.prototype, "stateful", {
  get() {
    return this._stateful;
  },
  enumerable: true,
  configurable: true
});
/**
 * Base implementation: a no-op for stateful layers; throws an Error for non-stateful ones.
 */
Layer2.prototype.resetStates = function() {
  if (this.stateful) {
    return;
  }
  throw new Error("Cannot call the resetStates() method of a non-stateful Layer object.");
};
/**
 * Validates `inputs` against `this.inputSpec` and throws ValueError on the first mismatch.
 * Returns immediately when there is no spec.
 *
 * Checks, per input/spec pair (null specs are skipped): exact ndim, maxNDim, minNDim,
 * dtype, per-axis sizes (`spec.axes`), and the full `spec.shape`.
 *
 * @param inputs a tensor or an array of tensors (normalized with toList).
 */
Layer2.prototype.assertInputCompatibility = function(inputs) {
inputs = toList(inputs);
if (this.inputSpec == null || this.inputSpec.length === 0) {
return;
}
var inputSpec = toList(this.inputSpec);
// The number of inputs must equal the number of specs.
if (inputs.length !== inputSpec.length) {
throw new ValueError("Layer " + this.name + " expects " + inputSpec.length + " inputs, " + ("but it received " + inputs.length + " input tensors. ") + ("Input received: " + inputs));
}
for (var inputIndex = 0; inputIndex < inputs.length; inputIndex++) {
var x = inputs[inputIndex];
var spec = inputSpec[inputIndex];
if (spec == null) {
continue;
}
var ndim = x.rank;
if (spec.ndim != null) {
if (ndim !== spec.ndim) {
throw new ValueError("Input " + inputIndex + " is incompatible with layer " + this.name + ": " + ("expected ndim=" + spec.ndim + ", found ndim=" + ndim));
}
}
if (spec.maxNDim != null) {
if (ndim > spec.maxNDim) {
throw new ValueError("Input " + inputIndex + " is incompatible with layer " + this.name + (": expected max_ndim=" + spec.maxNDim + ", found ndim=" + ndim));
}
}
if (spec.minNDim != null) {
if (ndim < spec.minNDim) {
throw new ValueError("Input " + inputIndex + " is incompatible with layer " + this.name + (": expected min_ndim=" + spec.minNDim + ", found ndim=" + ndim + "."));
}
}
if (spec.dtype != null) {
if (x.dtype !== spec.dtype) {
throw new ValueError("Input " + inputIndex + " is incompatible with layer " + this.name + " " + (": expected dtype=" + spec.dtype + ", found dtype=" + x.dtype + "."));
}
}
// spec.axes maps axis index (as an object key) -> required size. Negative axes count
// from the end. An input dimension of null is accepted for any required size.
if (spec.axes) {
var xShape = x.shape;
for (var key in spec.axes) {
var axis = Number(key);
var value = spec.axes[key];
var xShapeAtAxis = axis >= 0 ? xShape[axis] : xShape[xShape.length + axis];
if (value != null && [value, null].indexOf(xShapeAtAxis) === -1) {
throw new ValueError("Input " + inputIndex + " is incompatible with layer " + (this.name + ": expected axis " + axis + " of input shape to ") + ("have value " + value + " but got shape " + xShape + "."));
}
}
}
// Full-shape check: dimensions that are null on either side are not compared.
if (spec.shape != null) {
for (var i = 0; i < spec.shape.length; ++i) {
var specDim = spec.shape[i];
var dim = x.shape[i];
if (specDim != null && dim != null) {
if (specDim !== dim) {
throw new ValueError("Input " + inputIndex + " is incompatible with layer " + (this.name + ": expected shape=" + spec.shape + ", ") + ("found shape=" + x.shape + "."));
}
}
}
}
}
};
/** Base forward pass: returns `inputs` unchanged. Subclasses override this. */
Layer2.prototype.call = function(inputs, kwargs) {
  return inputs;
};
/** Invokes the registered call hook, if any, with the layer as `this`. */
Layer2.prototype.invokeCallHook = function(inputs, kwargs) {
  if (this._callHook == null) {
    return;
  }
  this._callHook(inputs, kwargs);
};
/** Registers `callHook(inputs, kwargs)` for invokeCallHook. */
Layer2.prototype.setCallHook = function(callHook) {
  this._callHook = callHook;
};
/** Removes any registered call hook. */
Layer2.prototype.clearCallHook = function() {
  this._callHook = null;
};
/**
 * Applies the layer to `inputs`, building it first if needed.
 *
 * - Concrete (non-symbolic) inputs: runs `call()` and returns its output(s). Any output
 *   that is one of the input objects is cloned first.
 * - SymbolicTensor inputs: computes the output shape(s), creates SymbolicTensor output(s),
 *   registers an inbound node and increments `_refCount`.
 *
 * Inputs must be all SymbolicTensors or all concrete tensors, otherwise ValueError is
 * thrown. Because both flags start true, an empty input list also throws.
 * Everything after that check runs inside nameScope(this.name, ...).
 *
 * @param inputs a tensor/SymbolicTensor or an array of them.
 * @param [kwargs={}] forwarded to call() and recorded on symbolic outputs.
 */
Layer2.prototype.apply = function(inputs, kwargs) {
var _this = this;
kwargs = kwargs || {};
this.assertNotDisposed();
var inputsList = toList(inputs);
var allAreSymbolic = true;
for (var _i = 0, inputsList_1 = inputsList; _i < inputsList_1.length; _i++) {
var input2 = inputsList_1[_i];
if (!(input2 instanceof SymbolicTensor)) {
allAreSymbolic = false;
break;
}
}
var noneAreSymbolic = true;
for (var _a = 0, inputsList_2 = inputsList; _a < inputsList_2.length; _a++) {
var input2 = inputsList_2[_a];
if (input2 instanceof SymbolicTensor) {
noneAreSymbolic = false;
break;
}
}
// Equal flags mean a mix of symbolic and concrete inputs (or no inputs at all).
if (allAreSymbolic === noneAreSymbolic) {
throw new ValueError("Arguments to apply() must be all SymbolicTensors or all Tensors");
}
return nameScope(this.name, function() {
// First call: validate, build from the input shapes, then load any initial weights.
if (!_this.built) {
_this.assertInputCompatibility(inputs);
var inputShapes = [];
for (var _i2 = 0, _a2 = toList(inputs); _i2 < _a2.length; _i2++) {
var xElem = _a2[_i2];
inputShapes.push(xElem.shape);
}
_this.build(singletonOrArray(inputShapes));
_this.built = true;
if (_this.initialWeights) {
_this.setWeights(_this.initialWeights);
}
// An eager first build starts the reference count at 1.
if (_this._refCount === null && noneAreSymbolic) {
_this._refCount = 1;
}
}
_this.assertInputCompatibility(inputs);
if (noneAreSymbolic) {
// Eager path.
var output = _this.call(inputs, kwargs);
var outputList = toList(output);
var outputListCopy = [];
for (var _b = 0, outputList_1 = outputList; _b < outputList_1.length; _b++) {
var x = outputList_1[_b];
// Never hand back an input object itself as an output; return a clone instead.
if (inputsList.indexOf(x) !== -1) {
x = x.clone();
}
outputListCopy.push(x);
}
output = singletonOrArray(outputListCopy);
// Checked only after call() has already run.
if (_this.activityRegularizer != null) {
throw new NotImplementedError("Layer invocation in the presence of activity regularizer(s) is not supported yet.");
}
return output;
} else {
// Symbolic path: no computation, only shape inference and graph bookkeeping.
var inputShape = collectInputShape(inputs);
var outputShape = _this.computeOutputShape(inputShape);
var output = void 0;
var outputDType_1 = guessOutputDType(inputs);
_this.warnOnIncompatibleInputShape(Array.isArray(inputs) ? inputShape[0] : inputShape);
// A list of shapes yields one SymbolicTensor per shape (tagged with its output index).
if (outputShape != null && outputShape.length > 0 && Array.isArray(outputShape[0])) {
output = outputShape.map(function(shape, index) {
return new SymbolicTensor(outputDType_1, shape, _this, toList(inputs), kwargs, _this.name, index);
});
} else {
output = new SymbolicTensor(outputDType_1, outputShape, _this, toList(inputs), kwargs, _this.name);
}
_this.addInboundNode(inputs, output, null, null, inputShape, outputShape, kwargs);
// If _refCount is still null, `null++` in JS yields 1.
_this._refCount++;
if (_this.activityRegularizer != null) {
throw new NotImplementedError("Layer invocation in the presence of activity regularizer(s) is not supported yet.");
}
return output;
}
});
};
/**
 * Logs a console warning when `inputShape` disagrees with `this.batchInputShape`:
 * either a different rank, or a dimension where both sides are non-null and differ.
 * Does nothing when the layer has no batchInputShape.
 */
Layer2.prototype.warnOnIncompatibleInputShape = function(inputShape) {
  const expected = this.batchInputShape;
  if (expected == null) {
    return;
  }
  if (inputShape.length !== expected.length) {
    console.warn("The rank of the input tensor provided (shape: " + JSON.stringify(inputShape) + ") does not match that of the batchInputShape (" + JSON.stringify(expected) + ") of the layer " + this.name);
    return;
  }
  const hasDimMismatch = expected.some((dimension, i) => dimension != null && inputShape[i] != null && inputShape[i] !== dimension);
  if (hasDimMismatch) {
    console.warn("The shape of the input tensor (" + JSON.stringify(inputShape) + ") does not match the expectation of layer " + this.name + ": " + JSON.stringify(expected));
  }
};
// `outputShape`: the output shape(s) shared by all inbound nodes. A one-element list of
// shapes is unwrapped to that single shape. Throws AttributeError if the layer was never
// called or if its inbound nodes disagree (compared via JSON.stringify).
Object.defineProperty(Layer2.prototype, "outputShape", {
  get() {
    const nodes = this.inboundNodes;
    if (nodes == null || nodes.length === 0) {
      throw new AttributeError("The layer " + this.name + " has never been called and thus has no defined output shape.");
    }
    const distinctShapes = new Set();
    for (const node of nodes) {
      distinctShapes.add(JSON.stringify(node.outputShapes));
    }
    if (distinctShapes.size !== 1) {
      throw new AttributeError("The layer " + this.name + ' has multiple inbound nodes with different output shapes. Hence the notion of "output shape" is ill-defined for the layer.');
    }
    const shapes = nodes[0].outputShapes;
    const wrapsSingleShape = Array.isArray(shapes) && Array.isArray(shapes[0]) && shapes.length === 1;
    return wrapsSingleShape ? shapes[0] : shapes;
  },
  enumerable: true,
  configurable: true
});
/**
 * Total number of scalar parameters across all weights.
 * @throws RuntimeError if the layer is not built yet.
 */
Layer2.prototype.countParams = function() {
  if (!this.built) {
    throw new RuntimeError("You tried to call countParams() on " + this.name + ", but the layer is not built yet. Build it first by calling build(batchInputShape).");
  }
  const allWeights = this.weights;
  return countParamsInWeights(allWeights);
};
/** Base build: creates no weights and only marks the layer as built. */
Layer2.prototype.build = function(inputShape) {
  this.built = true;
};
/**
 * Current weight values (via each variable's read()).
 * @param {boolean} [trainableOnly=false] when truthy, only trainableWeights are read.
 */
Layer2.prototype.getWeights = function(trainableOnly) {
  const selected = trainableOnly ? this.trainableWeights : this.weights;
  return batchGetValue(selected);
};
/**
 * Overwrites all of this layer's weights (in `this.weights` order) with `weights`.
 * Runs inside tfc.tidy.
 *
 * @param weights array of tensors, one per layer weight, with matching shapes.
 * @throws ValueError if the count differs from `this.weights.length` or any shape differs.
 */
Layer2.prototype.setWeights = function(weights) {
var _this = this;
tfc.tidy(function() {
var params = _this.weights;
if (params.length !== weights.length) {
throw new ValueError('You called setWeights(weights) on layer "' + _this.name + '" ' + ("with a weight list of length " + weights.length + ", ") + ("but the layer was expecting " + params.length + " weights. ") + ("Provided weights: " + weights + "..."));
}
if (params.length === 0) {
return;
}
// All shapes are validated before any variable is written, so a mismatch leaves
// every weight untouched.
var weightValueTuples = [];
var paramValues = batchGetValue(params);
for (var i = 0; i < paramValues.length; ++i) {
var pv = paramValues[i];
var p = params[i];
var w = weights[i];
if (!tfc.util.arraysEqual(pv.shape, w.shape)) {
throw new ValueError("Layer weight shape " + pv.shape + " " + ("not compatible with provided weight shape " + w.shape));
}
weightValueTuples.push([p, w]);
}
batchSetValue(weightValueTuples);
});
};
/**
 * Creates a LayerVariable, initializes it and registers it with this layer.
 *
 * @param {string} name weight name; must be unique within this layer.
 * @param shape weight shape passed to the initializer.
 * @param {string} [dtype="float32"]
 * @param initializer object with `apply(shape, dtype)`. Replaced by a zeros initializer
 *   when fastWeightInitDuringBuild is set.
 * @param [regularizer] if given, a loss thunk `regularizer.apply(weight.read())` is added.
 * @param {boolean} [trainable=true] null/undefined mean trainable.
 * @param [constraint] forwarded to LayerVariable.
 * @returns {LayerVariable} the new weight.
 * @throws ValueError on a duplicate weight name.
 */
Layer2.prototype.addWeight = function(name, shape, dtype, initializer, regularizer, trainable, constraint) {
  if (this._addedWeightNames.indexOf(name) !== -1) {
    throw new ValueError("Duplicate weight name " + name + " for layer " + this.name);
  }
  this._addedWeightNames.push(name);
  if (dtype == null) {
    dtype = "float32";
  }
  // Resolve the default before creating the variable. Previously a null `trainable` built a
  // non-trainable LayerVariable that was nevertheless filed under _trainableWeights.
  if (trainable == null) {
    trainable = true;
  }
  if (this.fastWeightInitDuringBuild) {
    initializer = getInitializer("zeros");
  }
  var initValue = initializer.apply(shape, dtype);
  var weight;
  try {
    weight = new LayerVariable(initValue, dtype, name, trainable, constraint);
  } finally {
    // The initial value is only needed to seed the variable. Release it even if
    // construction throws.
    initValue.dispose();
  }
  if (regularizer != null) {
    this.addLoss(function() {
      return regularizer.apply(weight.read());
    });
  }
  if (trainable) {
    this._trainableWeights.push(weight);
  } else {
    this._nonTrainableWeights.push(weight);
  }
  return weight;
};
/**
 * Sets the flag that makes addWeight substitute the "zeros" initializer.
 * @param value stored as-is on `fastWeightInitDuringBuild`.
 */
Layer2.prototype.setFastWeightInitDuringBuild = function(value) {
  this.fastWeightInitDuringBuild = value;
};
/**
 * Appends loss function(s) to this layer's `losses` list.
 * Does nothing for null/undefined or an empty array.
 * Also does nothing when `_losses` is null/undefined.
 * @param losses a single loss function or an array of them.
 */
Layer2.prototype.addLoss = function(losses) {
  const nothingToAdd = losses == null || (Array.isArray(losses) && losses.length === 0);
  if (nothingToAdd) {
    return;
  }
  const lossList = toList(losses);
  if (this._losses != null) {
    this.losses.push(...lossList);
  }
};
/**
 * Default shape inference: the output shape equals the input shape.
 * @param inputShape returned unchanged.
 */
Layer2.prototype.computeOutputShape = function(inputShape) {
  const outputShape = inputShape;
  return outputShape;
};
/**
 * Propagates the input mask.
 * Layers that support masking return `mask` unchanged. Other layers return null,
 * after rejecting any non-null mask (or any non-null element of a mask array).
 * @throws TypeError when a mask is supplied to a layer without masking support.
 */
Layer2.prototype.computeMask = function(inputs, mask) {
  if (this.supportsMasking) {
    return mask;
  }
  if (mask != null) {
    const receivedMask = Array.isArray(mask) ? mask.some((element) => element != null) : true;
    if (receivedMask) {
      throw new TypeError("Layer " + this.name + " does not support masking, but was passed an inputMask.");
    }
  }
  return null;
};
/**
 * Records a call of this layer as a new Node.
 * Each output tensor is then stamped with this layer as its source, the index of
 * the newest inbound node, and its position among the outputs.
 * @param kwargs extra call arguments for the Node; defaults to null.
 */
Layer2.prototype.addInboundNode = function(inputTensors, outputTensors, inputMasks, outputMasks, inputShapes, outputShapes, kwargs) {
  const callKwargs = kwargs === void 0 ? null : kwargs;
  const inputTensorList = toList(inputTensors);
  const outputTensorList = toList(outputTensors);
  const inputMaskList = toList(inputMasks);
  const outputMaskList = toList(outputMasks);
  const inputShapeList = normalizeShapeList(inputShapes);
  const outputShapeList = normalizeShapeList(outputShapes);
  const inboundLayers = [];
  const nodeIndices = [];
  const tensorIndices = [];
  for (const tensor of inputTensorList) {
    inboundLayers.push(tensor.sourceLayer);
    nodeIndices.push(tensor.nodeIndex);
    tensorIndices.push(tensor.tensorIndex);
  }
  new Node({
    outboundLayer: this,
    inboundLayers,
    nodeIndices,
    tensorIndices,
    inputTensors: inputTensorList,
    outputTensors: outputTensorList,
    inputMasks: inputMaskList,
    outputMasks: outputMaskList,
    inputShapes: inputShapeList,
    outputShapes: outputShapeList
  }, callKwargs);
  outputTensorList.forEach((tensor, position) => {
    tensor.sourceLayer = this;
    tensor.nodeIndex = this.inboundNodes.length - 1;
    tensor.tensorIndex = position;
  });
};
/**
 * Builds the serializable base config.
 * The result always contains `name` and `trainable`; `batchInputShape` and `dtype`
 * are added only when non-null.
 */
Layer2.prototype.getConfig = function() {
  const config = {name: this.name, trainable: this.trainable};
  if (this.batchInputShape != null) config["batchInputShape"] = this.batchInputShape;
  if (this.dtype != null) config["dtype"] = this.dtype;
  return config;
};
/**
 * Calls dispose() on every weight of this layer.
 * @returns the number of weights (read from `this.weights` after disposal).
 */
Layer2.prototype.disposeWeights = function() {
  for (const variable of this.weights) {
    variable.dispose();
  }
  return this.weights.length;
};
/**
 * Guards against using a layer whose reference count has reached zero.
 * @throws Error when `_refCount` is exactly 0.
 */
Layer2.prototype.assertNotDisposed = function() {
  if (this._refCount !== 0) {
    return;
  }
  throw new Error("Layer '" + this.name + "' is already disposed.");
};
/**
 * Decrements the layer's reference count.
 * The weights are disposed only when the count reaches zero.
 * @returns {refCountAfterDispose, numDisposedVariables}.
 * @throws Error if the layer is unbuilt, was never used (`_refCount === null`), or is already disposed.
 */
Layer2.prototype.dispose = function() {
  if (!this.built) {
    throw new Error("Cannot dispose Layer " + this.name + " because it has not been built yet.");
  }
  if (this._refCount === null) {
    throw new Error("Cannot dispose Layer " + this.name + " because it has not been used yet.");
  }
  this.assertNotDisposed();
  this._refCount -= 1;
  const numDisposedVariables = this._refCount === 0 ? this.disposeWeights() : 0;
  return {refCountAfterDispose: this._refCount, numDisposedVariables};
};
return Layer2;
}(tfc.serialization.Serializable);
/**
 * Collects the `shape` of each input tensor.
 * @param inputTensors a single tensor or an array of tensors.
 * @returns the lone shape for one input, otherwise the shapes as produced by singletonOrArray.
 */
function collectInputShape(inputTensors) {
  const shapes = toList(inputTensors).map((tensor) => tensor.shape);
  return singletonOrArray(shapes);
}
/**
 * Output dtype guess for a layer call.
 * The inputs are ignored and the answer is always "float32".
 */
function guessOutputDType(inputTensors) {
  const guessed = "float32";
  return guessed;
}
/**
 * Walks the layer graph backwards from `tensor` and returns the distinct source
 * tensors it depends on.
 * @param tensor starting symbolic tensor.
 * @param layer optional layer to start from. It is used only when given and
 *   `nodeIndex` is null/undefined or <= 0; otherwise the tensor's own
 *   sourceLayer/nodeIndex are used.
 * @param nodeIndex optional inbound-node index into `layer`.
 * @returns `[tensor]` if the layer has no inbound nodes. If the node has no inbound
 *   layers, returns that node's `inputTensors`. Otherwise returns the de-duplicated
 *   union of the recursive results, in first-seen order.
 */
function getSourceInputs(tensor, layer, nodeIndex) {
  const useTensorOrigin = layer == null || (nodeIndex != null && nodeIndex > 0);
  const originLayer = useTensorOrigin ? tensor.sourceLayer : layer;
  const originNodeIndex = useTensorOrigin ? tensor.nodeIndex : nodeIndex;
  if (originLayer.inboundNodes.length === 0) {
    return [tensor];
  }
  const node = originLayer.inboundNodes[originNodeIndex];
  if (node.inboundLayers.length === 0) {
    return node.inputTensors;
  }
  const collected = [];
  node.inboundLayers.forEach((inboundLayer, idx) => {
    const upstream = getSourceInputs(node.inputTensors[idx], inboundLayer, node.nodeIndices[idx]);
    for (const source of upstream) {
      if (collected.indexOf(source) === -1) {
        collected.push(source);
      }
    }
  });
  return collected;
}
/**
* @license
* Copyright 2018 Google LLC
*
* Use of this source code is governed by an MIT-style
* license that can be found in the LICENSE file or at
* https://opensource.org/licenses/MIT.
* =============================================================================
*/
var InputLayer = function(_super) {
  __extends(InputLayer2, _super);
  /**
   * Non-trainable entry layer that owns one placeholder SymbolicTensor.
   * @param args needs exactly one of `inputShape` or `batchInputShape`. The optional
   *   fields are `batchSize` (only together with `inputShape`), `dtype` (default
   *   "float32"), `sparse` and `name` (default: a getUid("input") id).
   * @throws ValueError for conflicting or missing shape arguments.
   */
  function InputLayer2(args) {
    const self = _super.call(this, {
      dtype: args.dtype,
      name: args.name != null ? args.name : getUid("input").toString()
    }) || this;
    // Optional fields are normalised in place, so the caller's `args` object is updated too.
    if (args.batchSize == null) {
      args.batchSize = null;
    }
    if (args.sparse == null) {
      args.sparse = false;
    }
    self.trainable = false;
    self.built = true;
    self.sparse = args.sparse;
    if (args.inputShape != null && args.batchInputShape != null) {
      throw new ValueError("Only provide the inputShape OR batchInputShape argument to inputLayer, not both at the same time.");
    }
    let batchInputShape = args.batchInputShape;
    if (batchInputShape != null) {
      if (args.batchSize != null) {
        throw new ValueError("Cannot specify batchSize if batchInputShape is specified when creating an InputLayer.");
      }
    } else if (args.inputShape == null) {
      throw new ValueError("An InputLayer should be passed either a `batchInputShape` or an `inputShape`.");
    } else {
      batchInputShape = [args.batchSize].concat(args.inputShape);
    }
    self.batchInputShape = batchInputShape;
    self.dtype = args.dtype || "float32";
    self.inputSpec = [{shape: batchInputShape}];
    const placeholder = new SymbolicTensor(self.dtype, self.batchInputShape, self, [], {}, self.name);
    placeholder.nodeIndex = 0;
    placeholder.tensorIndex = 0;
    // Created only for its side effect; Input() later reads this layer's inboundNodes[0],
    // so presumably the Node constructor registers itself on `outboundLayer`.
    new Node({
      outboundLayer: self,
      inboundLayers: [],
      nodeIndices: [],
      tensorIndices: [],
      inputTensors: [placeholder],
      outputTensors: [placeholder],
      inputMasks: [null],
      outputMasks: [null],
      inputShapes: [batchInputShape],
      outputShapes: [batchInputShape]
    });
    return self;
  }
  /** Always throws: an InputLayer cannot be applied to inputs. */
  InputLayer2.prototype.apply = function(inputs, kwargs) {
    throw new ValueError("Cannot pass any input to an " + ("InputLayer's apply() method. InputLayer name: " + this.name));
  };
  /** Owns no variables; reports the current reference count and zero disposals. */
  InputLayer2.prototype.dispose = function() {
    return {refCountAfterDispose: this._refCount, numDisposedVariables: 0};
  };
  /** Serializable config: batchInputShape, dtype, sparse and name. */
  InputLayer2.prototype.getConfig = function() {
    return {
      batchInputShape: this.batchInputShape,
      dtype: this.dtype,
      sparse: this.sparse,
      name: this.name
    };
  };
  InputLayer2.className = "InputLayer";
  return InputLayer2;
}(Layer);
tfc.serialization.registerClass(InputLayer);
/**
 * Creates an InputLayer and returns its placeholder output tensor.
 * @param config requires exactly one of `shape` (the batch dimension is excluded;
 *   null is prepended) or `batchShape`. Optional fields: `dtype` (default "float32"),
 *   `name` and `sparse`.
 * @throws Error when neither shape is given; ValueError when both are given.
 */
function Input(config) {
  const hasShape = config.shape != null;
  const hasBatchShape = config.batchShape != null;
  if (!hasBatchShape && !hasShape) {
    throw new Error("Please provide to Input either a `shape` or a `batchShape` argument. Note that `shape` does not include the batch dimension.");
  }
  if (hasBatchShape && hasShape) {
    throw new ValueError("Please provide either a `shape` or `batchShape` argument to Input, but not both.");
  }
  const batchShape = hasBatchShape ? config.batchShape : [null].concat(config.shape);
  const dtype = config.dtype == null ? "float32" : config.dtype;
  const layer = new InputLayer({
    batchInputShape: batchShape,
    name: config.name,
    dtype,
    sparse: config.sparse
  });
  return layer.inboundNodes[0].outputTensors[0];
}
/**
* @license
* Copyright 2018 Google LLC
*
* Use of this source code is governed by an MIT-style
* license that can be found in the LICENSE file or at
* https://opensource.org/licenses/MIT.
* =============================================================================
*/
/**
 * Replaces every non-number value in `logs` with the first element of its data(),
 * mutating `logs` in place.
 * The tensors that were read are disposed afterwards.
 * @param logs key/value map; null/undefined is a no-op.
 * @returns a Promise that resolves once all values have been read.
 */
function resolveScalarsInLogs(logs) {
  return __awaiter(this, void 0, void 0, function* () {
    if (logs == null) {
      return;
    }
    const pendingData = [];
    const tensorKeys = [];
    const scalars = [];
    for (const key in logs) {
      const value = logs[key];
      if (typeof value !== "number") {
        pendingData.push(value.data());
        tensorKeys.push(key);
        scalars.push(value);
      }
    }
    if (pendingData.length === 0) {
      return;
    }
    const resolved = yield Promise.all(pendingData);
    resolved.forEach((data, i) => {
      logs[tensorKeys[i]] = data[0];
    });
    tfc.dispose(scalars);
  });
}
/**
 * Calls dispose() on every non-number value of `logs` (enumerated with for...in).
 * @param logs key/value map; null/undefined is a no-op.
 */
function disposeTensorsInLogs(logs) {
  if (logs == null) {
    return;
  }
  for (const name in logs) {
    const entry = logs[name];
    if (typeof entry === "number") {
      continue;
    }
    entry.dispose();
  }
}
/**
* @license
* Copyright 2018 Google LLC
*
* Use of this source code is governed by an MIT-style
* license that can be found in the LICENSE file or at
* https://opensource.org/licenses/MIT.
* =============================================================================
*/
// Two-way enum: ModelLoggingVerbosity.SILENT === 0 and ModelLoggingVerbosity[0] === "SILENT".
var ModelLoggingVerbosity;
(function(levels) {
  ["SILENT", "VERBOSE"].forEach(function(label, value) {
    levels[levels[label] = value] = label;
  });
})(ModelLoggingVerbosity || (ModelLoggingVerbosity = {}));
// Default yieldEvery used by CustomCallback when "auto" is requested.
var DEFAULT_YIELD_EVERY_MS = 125;
var BaseCallback = function() {
  /** Base training callback: stores params, and every hook is an async no-op. */
  function BaseCallback2() {
    this.validationData = null;
  }
  // Shared default hook body: an async task that completes with undefined.
  function idle(self) {
    return __awaiter(self, void 0, void 0, function* () {});
  }
  BaseCallback2.prototype.setParams = function(params) {
    this.params = params;
  };
  BaseCallback2.prototype.onEpochBegin = function(epoch, logs) {
    return idle(this);
  };
  BaseCallback2.prototype.onEpochEnd = function(epoch, logs) {
    return idle(this);
  };
  BaseCallback2.prototype.onBatchBegin = function(batch, logs) {
    return idle(this);
  };
  BaseCallback2.prototype.onBatchEnd = function(batch, logs) {
    return idle(this);
  };
  BaseCallback2.prototype.onTrainBegin = function(logs) {
    return idle(this);
  };
  BaseCallback2.prototype.onTrainEnd = function(logs) {
    return idle(this);
  };
  BaseCallback2.prototype.setModel = function(model2) {
  };
  return BaseCallback2;
}();
var CallbackList = function() {
  /**
   * Ordered container that forwards each training event to every callback in turn.
   * @param callbacks2 callbacks to notify; null/undefined means an empty list.
   * @param queueLength stored as-is, default 10; none of the methods below read it.
   */
  function CallbackList2(callbacks2, queueLength) {
    this.callbacks = callbacks2 == null ? [] : callbacks2;
    this.queueLength = queueLength === void 0 ? 10 : queueLength;
  }
  // Calls `invoke(callback)` for each registered callback and awaits each result
  // before starting the next one, so hooks never overlap.
  function runInOrder(list, invoke) {
    return __awaiter(list, void 0, void 0, function* () {
      const registered = list.callbacks;
      for (let i = 0; i < registered.length; i++) {
        yield invoke(registered[i]);
      }
    });
  }
  CallbackList2.prototype.append = function(callback) {
    this.callbacks.push(callback);
  };
  CallbackList2.prototype.setParams = function(params) {
    for (const callback of this.callbacks) {
      callback.setParams(params);
    }
  };
  CallbackList2.prototype.setModel = function(model2) {
    for (const callback of this.callbacks) {
      callback.setModel(model2);
    }
  };
  // Each hook below defaults missing logs to one fresh {} that all callbacks share.
  CallbackList2.prototype.onEpochBegin = function(epoch, logs) {
    const payload = logs == null ? {} : logs;
    return runInOrder(this, (callback) => callback.onEpochBegin(epoch, payload));
  };
  CallbackList2.prototype.onEpochEnd = function(epoch, logs) {
    const payload = logs == null ? {} : logs;
    return runInOrder(this, (callback) => callback.onEpochEnd(epoch, payload));
  };
  CallbackList2.prototype.onBatchBegin = function(batch, logs) {
    const payload = logs == null ? {} : logs;
    return runInOrder(this, (callback) => callback.onBatchBegin(batch, payload));
  };
  CallbackList2.prototype.onBatchEnd = function(batch, logs) {
    const payload = logs == null ? {} : logs;
    return runInOrder(this, (callback) => callback.onBatchEnd(batch, payload));
  };
  CallbackList2.prototype.onTrainBegin = function(logs) {
    const payload = logs == null ? {} : logs;
    return runInOrder(this, (callback) => callback.onTrainBegin(payload));
  };
  CallbackList2.prototype.onTrainEnd = function(logs) {
    const payload = logs == null ? {} : logs;
    return runInOrder(this, (callback) => callback.onTrainEnd(payload));
  };
  return CallbackList2;
}();
var BaseLogger = function(_super) {
  __extends(BaseLogger2, _super);
  /**
   * Callback that sums logged values weighted by batch `size`.
   * At epoch end it writes `total / seen` into the logs for each key in params.metrics.
   */
  function BaseLogger2() {
    return _super.call(this) || this;
  }
  BaseLogger2.prototype.onEpochBegin = function(epoch) {
    return __awaiter(this, void 0, void 0, function* () {
      this.seen = 0;
      this.totals = {};
    });
  };
  BaseLogger2.prototype.onBatchEnd = function(batch, logs) {
    return __awaiter(this, void 0, void 0, function* () {
      const entries = logs == null ? {} : logs;
      const batchSize = entries["size"] == null ? 0 : entries["size"];
      this.seen += batchSize;
      for (const key in entries) {
        const value = entries[key];
        if (typeof value === "number") {
          if (!this.totals.hasOwnProperty(key)) {
            this.totals[key] = 0;
          }
          this.totals[key] = this.totals[key] + value * batchSize;
          continue;
        }
        // Tensor value: build the new running total and dispose the previous one afterwards.
        let previous;
        if (key in this.totals) {
          previous = this.totals[key];
        } else {
          this.totals[key] = 0;
        }
        this.totals[key] = tfc.tidy(() => tfc.add(this.totals[key], tfc.mul(value, batchSize)));
        if (previous != null) {
          previous.dispose();
        }
      }
    });
  };
  BaseLogger2.prototype.onEpochEnd = function(epoch, logs) {
    return __awaiter(this, void 0, void 0, function* () {
      if (logs == null) {
        return;
      }
      for (const key of this.params["metrics"]) {
        const total = this.totals[key];
        if (total == null) {
          continue;
        }
        if (typeof total === "number") {
          logs[key] = total / this.seen;
        } else {
          tfc.tidy(() => {
            const log = tfc.mul(tfc.div(1, this.seen), this.totals[key]);
            logs[key] = log;
            this.totals[key].dispose();
            // Keep the averaged tensor alive past tidy(); it is handed out via `logs`.
            tfc.keep(logs[key]);
          });
        }
      }
    });
  };
  return BaseLogger2;
}(BaseCallback);
var History = function(_super) {
  __extends(History2, _super);
  /** Callback that records epoch indices and every logged value per key. */
  function History2() {
    return _super !== null && _super.apply(this, arguments) || this;
  }
  History2.prototype.onTrainBegin = function(logs) {
    return __awaiter(this, void 0, void 0, function* () {
      this.epoch = [];
      this.history = {};
    });
  };
  History2.prototype.onEpochEnd = function(epoch, logs) {
    return __awaiter(this, void 0, void 0, function* () {
      const entries = logs == null ? {} : logs;
      this.epoch.push(epoch);
      for (const key in entries) {
        if (this.history[key] == null) {
          this.history[key] = [];
        }
        this.history[key].push(entries[key]);
      }
    });
  };
  /**
   * Replaces each non-number history entry with the first element of its data(),
   * disposing the original value.
   */
  History2.prototype.syncData = function() {
    return __awaiter(this, void 0, void 0, function* () {
      const pending = [];
      const locations = [];
      for (const key in this.history) {
        const series = this.history[key];
        for (let i = 0; i < series.length; ++i) {
          if (typeof series[i] !== "number") {
            pending.push(series[i].data());
            locations.push([key, i]);
          }
        }
      }
      const resolved = yield Promise.all(pending);
      resolved.forEach((data, n) => {
        const [key, index] = locations[n];
        this.history[key][index].dispose();
        this.history[key][index] = data[0];
      });
    });
  };
  return History2;
}(BaseCallback);
var CustomCallback = function(_super) {
  __extends(CustomCallback2, _super);
  /**
   * Adapts an object of optional user hooks (onTrainBegin, onTrainEnd, onEpochBegin,
   * onEpochEnd, onBatchBegin, onBatchEnd, onYield) to the BaseCallback interface.
   * @param args the hook object.
   * @param yieldEvery falsy values become "auto", which maps to DEFAULT_YIELD_EVERY_MS.
   *   "epoch"/"batch" await tfc.nextFrame() after each epoch/batch. A number wraps
   *   maybeWait with debounce(..., yieldEvery).
   * @throws Error if yieldEvery is "never" while onYield is provided.
   */
  function CustomCallback2(args, yieldEvery) {
    const self = _super.call(this) || this;
    self.currentEpoch = 0;
    self.yieldEvery = yieldEvery || "auto";
    if (self.yieldEvery === "auto") {
      self.yieldEvery = DEFAULT_YIELD_EVERY_MS;
    }
    if (self.yieldEvery === "never" && args.onYield != null) {
      throw new Error("yieldEvery is `never` but you provided an `onYield` callback. Either change `yieldEvery` or remove the callback");
    }
    if (tfc.util.isNumber(self.yieldEvery)) {
      self.maybeWait = debounce(self.maybeWait.bind(self), self.yieldEvery);
    }
    self.trainBegin = args.onTrainBegin;
    self.trainEnd = args.onTrainEnd;
    self.epochBegin = args.onEpochBegin;
    self.epochEnd = args.onEpochEnd;
    self.batchBegin = args.onBatchBegin;
    self.batchEnd = args.onBatchEnd;
    self.yield = args.onYield;
    return self;
  }
  /** Runs onYield (with resolved logs) if set, together with tfc.nextFrame(). */
  CustomCallback2.prototype.maybeWait = function(epoch, batch, logs) {
    return __awaiter(this, void 0, void 0, function* () {
      const pending = [];
      if (this.yield != null) {
        yield resolveScalarsInLogs(logs);
        pending.push(this.yield(epoch, batch, logs));
      }
      pending.push(tfc.nextFrame());
      yield Promise.all(pending);
    });
  };
  CustomCallback2.prototype.onEpochBegin = function(epoch, logs) {
    return __awaiter(this, void 0, void 0, function* () {
      this.currentEpoch = epoch;
      if (this.epochBegin != null) {
        yield resolveScalarsInLogs(logs);
        yield this.epochBegin(epoch, logs);
      }
    });
  };
  CustomCallback2.prototype.onEpochEnd = function(epoch, logs) {
    return __awaiter(this, void 0, void 0, function* () {
      const pending = [];
      if (this.epochEnd != null) {
        yield resolveScalarsInLogs(logs);
        pending.push(this.epochEnd(epoch, logs));
      }
      if (this.yieldEvery === "epoch") {
        pending.push(tfc.nextFrame());
      }
      yield Promise.all(pending);
    });
  };
  CustomCallback2.prototype.onBatchBegin = function(batch, logs) {
    return __awaiter(this, void 0, void 0, function* () {
      if (this.batchBegin != null) {
        yield resolveScalarsInLogs(logs);
        yield this.batchBegin(batch, logs);
      }
    });
  };
  CustomCallback2.prototype.onBatchEnd = function(batch, logs) {
    return __awaiter(this, void 0, void 0, function* () {
      const pending = [];
      if (this.batchEnd != null) {
        yield resolveScalarsInLogs(logs);
        pending.push(this.batchEnd(batch, logs));
      }
      if (this.yieldEvery === "batch") {
        pending.push(tfc.nextFrame());
      } else if (tfc.util.isNumber(this.yieldEvery)) {
        pending.push(this.maybeWait(this.currentEpoch, batch, logs));
      }
      yield Promise.all(pending);
    });
  };
  CustomCallback2.prototype.onTrainBegin = function(logs) {
    return __awaiter(this, void 0, void 0, function* () {
      if (this.trainBegin != null) {
        yield resolveScalarsInLogs(logs);
        yield this.trainBegin(logs);
      }
    });
  };
  CustomCallback2.prototype.onTrainEnd = function(logs) {
    return __awaiter(this, void 0, void 0, function* () {
      if (this.trainEnd != null) {
        yield resolveScalarsInLogs(logs);
        yield this.trainEnd(logs);
      }
    });
  };
  return CustomCallback2;
}(BaseCallback);
/**
 * Normalizes the user's callbacks argument into an array of BaseCallback instances.
 * - A BaseCallback instance is wrapped in an array.
 * - An array whose first element is a BaseCallback is returned as-is.
 * - Anything else (one config or several; null becomes {}) is mapped to CustomCallbacks.
 */
function standardizeCallbacks(callbacks2, yieldEvery) {
  const candidate = callbacks2 == null ? {} : callbacks2;
  if (candidate instanceof BaseCallback) {
    return [candidate];
  }
  if (Array.isArray(candidate) && candidate[0] instanceof BaseCallback) {
    return candidate;
  }
  return toList(candidate).map((config) => new CustomCallback(config, yieldEvery));
}
var CallbackConstructorRegistry = function() {
  /** Static registry mapping an integer verbosity level to callback constructors. */
  function CallbackConstructorRegistry2() {
  }
  const registry = CallbackConstructorRegistry2;
  /**
   * Registers `callbackConstructor` under `verbosityLevel`.
   * @throws when the level is not an integer >= 0 (via tfc.util.assert), or
   *   ValueError if the constructor is already registered under any level.
   */
  registry.registerCallbackConstructor = function(verbosityLevel, callbackConstructor) {
    tfc.util.assert(verbosityLevel >= 0 && Number.isInteger(verbosityLevel), () => "Verbosity level is expected to be an integer >= 0, " + ("but got " + verbosityLevel));
    registry.checkForDuplicate(callbackConstructor);
    const table = registry.constructors;
    if (table[verbosityLevel] == null) {
      table[verbosityLevel] = [];
    }
    table[verbosityLevel].push(callbackConstructor);
  };
  /** @throws ValueError when `callbackConstructor` is already registered under any level. */
  registry.checkForDuplicate = function(callbackConstructor) {
    const table = registry.constructors;
    for (const levelName in table) {
      if (table[+levelName].some((ctor) => ctor === callbackConstructor)) {
        throw new ValueError("Duplicate callback constructor.");
      }
    }
  };
  /** Drops all registrations. */
  registry.clear = function() {
    registry.constructors = {};
  };
  /** Instantiates every constructor registered at a level <= verbosityLevel. */
  registry.createCallbacks = function(verbosityLevel) {
    const ctors = [];
    for (const levelName in registry.constructors) {
      const level = +levelName;
      if (verbosityLevel >= level) {
        ctors.push(...registry.constructors[level]);
      }
    }
    return ctors.map((Ctor) => new Ctor());
  };
  registry.constructors = {};
  return CallbackConstructorRegistry2;
}();
/**
 * Assembles the CallbackList used for a training run.
 * The order is: a BaseLogger, then the registry callbacks for `verbose`, then the
 * user callbacks, then a History.
 * @returns {callbackList, history}; `history` is the History instance inside the list.
 */
function configureCallbacks(callbacks2, verbose, epochs, initialEpoch, numTrainSamples, stepsPerEpoch, batchSize, doValidation, callbackMetrics) {
  const history = new History();
  const actualCallbacks = [new BaseLogger(), ...CallbackConstructorRegistry.createCallbacks(verbose)];
  if (callbacks2 != null) {
    actualCallbacks.push(...callbacks2);
  }
  actualCallbacks.push(history);
  const callbackList = new CallbackList(actualCallbacks);
  const params = {
    epochs,
    initialEpoch,
    samples: numTrainSamples,
    steps: stepsPerEpoch,
    batchSize,
    verbose,
    doValidation,
    metrics: callbackMetrics
  };
  callbackList.setParams(params);
  return {callbackList, history};
}
/**
* @license
* Copyright 2018 Google LLC
*
* Use of this source code is governed by an MIT-style
* license that can be found in the LICENSE file or at
* https://opensource.org/licenses/MIT.
* =============================================================================
*/
/**
 * Deserializes a layer config.
 * The class lookup uses the tfc serialization registry's classNameMap.
 * @param customObjects defaults to {}.
 * @param fastWeightInit defaults to false.
 */
function deserialize(config, customObjects, fastWeightInit) {
  const objects = customObjects === void 0 ? {} : customObjects;
  const fast = fastWeightInit === void 0 ? false : fastWeightInit;
  const classNameMap = tfc.serialization.SerializationMap.getMap().classNameMap;
  return deserializeKerasObject(config, classNameMap, objects, "layer", fast);
}
/**
* @license
* Copyright 2018 Google LLC
*
* Use of this source code is governed by an MIT-style
* license that can be found in the LICENSE file or at
* https://opensource.org/licenses/MIT.
* =============================================================================
*/
/**
 * Divides `x` (cast to float32 if needed) by its L2 norm along `axis`.
 * The squared sum is floored at epsilon() before the sqrt, to avoid dividing by zero.
 */
function l2Normalize(x, axis) {
  return tfc.tidy(() => {
    const input = x.dtype === "float32" ? x : x.asType("float32");
    const sumOfSquares = tfc.sum(square(input), axis, true);
    const floor = tfc.fill(sumOfSquares.shape, epsilon());
    return tfc.div(input, tfc.sqrt(tfc.maximum(sumOfSquares, floor)));
  });
}
/** Mean of (yPred - yTrue)^2 over the last axis. */
function meanSquaredError(yTrue, yPred) {
  return tfc.tidy(() => {
    const residual = tfc.sub(yPred, yTrue);
    return tfc.mean(square(residual), -1);
  });
}
/** Mean of |yPred - yTrue| over the last axis. */
function meanAbsoluteError(yTrue, yPred) {
  return tfc.tidy(() => {
    const residual = tfc.sub(yPred, yTrue);
    return tfc.mean(tfc.abs(residual), -1);
  });
}
/**
 * 100 * mean(|yTrue - yPred| / |yTrue|) over the last axis.
 * The denominator is clipped to at least epsilon().
 */
function meanAbsolutePercentageError(yTrue, yPred) {
  return tfc.tidy(() => {
    const safeDenominator = tfc.clipByValue(tfc.abs(yTrue), epsilon(), Number.MAX_VALUE);
    const relativeError = tfc.abs(tfc.div(tfc.sub(yTrue, yPred), safeDenominator));
    return tfc.mul(100, tfc.mean(relativeError, -1));
  });
}
/**
 * Mean over the last axis of (log(1 + yPred) - log(1 + yTrue))^2.
 * Both inputs are clipped to at least epsilon() first.
 */
function meanSquaredLogarithmicError(yTrue, yPred) {
  return tfc.tidy(() => {
    const logOnePlus = (t) => tfc.log(tfc.add(1, tfc.clipByValue(t, epsilon(), Number.MAX_VALUE)));
    return tfc.mean(square(tfc.sub(logOnePlus(yPred), logOnePlus(yTrue))), -1);
  });
}
/** Mean over the last axis of max(0, 1 - yTrue * yPred)^2. */
function squaredHinge(yTrue, yPred) {
  return tfc.tidy(() => {
    const margin = tfc.maximum(0, tfc.sub(1, tfc.mul(yTrue, yPred)));
    return tfc.mean(square(margin), -1);
  });
}
/** Mean over the last axis of max(0, 1 - yTrue * yPred). */
function hinge(yTrue, yPred) {
  return tfc.tidy(() => {
    const margin = tfc.maximum(0, tfc.sub(1, tfc.mul(yTrue, yPred)));
    return tfc.mean(margin, -1);
  });
}
/**
 * max(0, 1 + neg - pos), where:
 * - pos = sum(yTrue * yPred) over the last axis;
 * - neg = max((1 - yTrue) * yPred) over the last axis.
 */
function categoricalHinge(yTrue, yPred) {
  return tfc.tidy(() => {
    const positiveScore = tfc.sum(tfc.mul(yTrue, yPred), -1);
    const negativeScore = tfc.max(tfc.mul(tfc.sub(1, yTrue), yPred), -1);
    const shifted = tfc.add(1, tfc.sub(negativeScore, positiveScore));
    return tfc.maximum(0, shifted);
  });
}
/**
 * Mean over the last axis of log(cosh(yPred - yTrue)).
 * It is evaluated as d + softplus(-2d) - log(2) rather than by calling cosh directly.
 */
function logcosh(yTrue, yPred) {
  return tfc.tidy(() => {
    const diff = tfc.sub(yPred, yTrue);
    const value = tfc.sub(tfc.add(diff, tfc.softplus(tfc.mul(-2, diff))), Math.log(2));
    return tfc.mean(value, -1);
  });
}
/**
 * Categorical crossentropy along the last axis: -sum(target * log(p)).
 * `p` is softmax(output) when `fromLogits` is truthy; otherwise it is output divided
 * by its sum along the last axis. `p` is clipped to [eps, 1 - eps] before the log.
 * @param fromLogits defaults to false.
 */
function categoricalCrossentropy(target, output, fromLogits) {
  const useLogits = fromLogits === void 0 ? false : fromLogits;
  return tfc.tidy(() => {
    const lastAxis = output.shape.length - 1;
    let probs = useLogits ? tfc.softmax(output) : tfc.div(output, tfc.sum(output, lastAxis, true));
    probs = tfc.clipByValue(probs, epsilon(), 1 - epsilon());
    const weightedLog = tfc.mul(target.toFloat(), tfc.log(probs));
    return tfc.neg(tfc.sum(weightedLog, probs.shape.length - 1));
  });
}
/**
 * Sparse categorical crossentropy.
 * `target` holds class indices (floored and cast to int). They are one-hot encoded to
 * `output`'s shape and passed to categoricalCrossentropy.
 * @param target class indices; they are flattened, then reshaped to `output`'s shape
 *   after one-hot encoding.
 * @param output probabilities, or logits when `fromLogits` is truthy.
 * @param fromLogits defaults to false.
 */
function sparseCategoricalCrossentropy(target, output, fromLogits) {
  if (fromLogits === void 0) {
    fromLogits = false;
  }
  return tfc.tidy(function() {
    var flatTarget = tfc.floor(flatten(target)).toInt();
    // Clipping to [eps, 1 - eps] only makes sense for probabilities. Logits are
    // unbounded; clipping them would compress every score into that tiny interval
    // before the softmax and distort the resulting distribution. With logits,
    // categoricalCrossentropy applies softmax and then clips the probabilities itself.
    var preds = fromLogits ? output : tfc.clipByValue(output, epsilon(), 1 - epsilon());
    var outputShape = preds.shape;
    var oneHotTarget = tfc.oneHot(flatTarget, outputShape[outputShape.length - 1]).reshape(outputShape);
    return categoricalCrossentropy(oneHotTarget, preds, fromLogits);
  });
}
/**
 * Element-wise sigmoid crossentropy: relu(x) - x * z + log1p(exp(-|x|)),
 * with x = logits and z = labels.
 * @throws ValueError when the two shapes differ.
 */
function sigmoidCrossEntropyWithLogits(labels, logits) {
  const sameShape = tfc.util.arraysEqual(labels.shape, logits.shape);
  if (!sameShape) {
    throw new ValueError("logits and labels must have the same shape, but got shapes " + (JSON.stringify(labels.shape) + " and " + JSON.stringify(logits.shape)));
  }
  return tfc.tidy(() => {
    const positivePart = logits.relu();
    const softTerm = logits.abs().neg().exp().log1p();
    return positivePart.sub(logits.mul(labels)).add(softTerm);
  });
}
/**
 * Binary crossentropy.
 * Probabilities are clipped to [eps, 1 - eps] and converted to logits log(p / (1 - p)).
 * The loss is the mean of sigmoidCrossEntropyWithLogits over the last axis.
 */
function binaryCrossentropy(yTrue, yPred) {
  return tfc.tidy(() => {
    const p = tfc.clipByValue(yPred, epsilon(), 1 - epsilon());
    const logits = tfc.log(tfc.div(p, tfc.sub(1, p)));
    return tfc.mean(sigmoidCrossEntropyWithLogits(yTrue, logits), -1);
  });
}
/**
 * sum(yTrue * log(clip(yTrue) / clip(yPred))) over the last axis.
 * Clipping is to [eps, 1]; the multiplier is the unclipped yTrue.
 */
function kullbackLeiblerDivergence(yTrue, yPred) {
  return tfc.tidy(() => {
    const ratio = tfc.div(tfc.clipByValue(yTrue, epsilon(), 1), tfc.clipByValue(yPred, epsilon(), 1));
    return tfc.sum(tfc.mul(yTrue, tfc.log(ratio)), -1);
  });
}
/** Mean over the last axis of yPred - yTrue * log(yPred + eps). */
function poisson(yTrue, yPred) {
  return tfc.tidy(() => {
    const logRate = tfc.log(tfc.add(epsilon(), yPred));
    const perElement = tfc.sub(yPred, tfc.mul(yTrue, logRate));
    return tfc.mean(perElement, -1);
  });
}
/** Negative cosine similarity along the last axis (inputs are L2-normalized first). */
function cosineProximity(yTrue, yPred) {
  return tfc.tidy(() => {
    const product = tfc.mul(l2Normalize(yTrue, -1), l2Normalize(yPred, -1));
    return tfc.neg(tfc.sum(product, -1));
  });
}
// String name -> loss function, as resolved by get(). The entry order is kept stable.
var lossesMap = {
  // regression
  meanSquaredError,
  meanAbsoluteError,
  meanAbsolutePercentageError,
  meanSquaredLogarithmicError,
  // margin-based
  squaredHinge,
  hinge,
  categoricalHinge,
  logcosh,
  // probabilistic
  categoricalCrossentropy,
  sparseCategoricalCrossentropy,
  binaryCrossentropy,
  kullbackLeiblerDivergence,
  poisson,
  cosineProximity
};
/**
 * Resolves a loss identifier.
 * @param identifierOrFn a key of lossesMap, or a loss function (returned unchanged).
 * @returns the loss function.
 * @throws ValueError for an unknown string name. The message carries a hint when the
 *   name contains "softmaxcrossentropy" (case-insensitive).
 */
function get(identifierOrFn) {
  if (typeof identifierOrFn !== "string") {
    return identifierOrFn;
  }
  // Own-property check: `in` would also match names inherited from Object.prototype
  // (e.g. "toString", "constructor") and return a builtin instead of rejecting them.
  if (Object.prototype.hasOwnProperty.call(lossesMap, identifierOrFn)) {
    return lossesMap[identifierOrFn];
  }
  var errMsg = "Unknown loss " + identifierOrFn;
  if (identifierOrFn.toLowerCase().includes("softmaxcrossentropy")) {
    errMsg = "Unknown loss " + identifierOrFn + '. Use "categoricalCrossentropy" as the string name for tf.losses.softmaxCrossEntropy';
  }
  throw new ValueError(errMsg);
}
/**
* @license
* Copyright 2018 Google LLC
*
* Use of this source code is governed by an MIT-style
* license that can be found in the LICENSE file or at
* https://opensource.org/licenses/MIT.
* =============================================================================
*/
/**
 * Fraction of matches along the last axis.
 * yPred is thresholded at 0.5 (strictly greater) and cast to yTrue's dtype before comparing.
 */
function binaryAccuracy(yTrue, yPred) {
  return tfc.tidy(() => {
    const cutoff = tfc.mul(0.5, tfc.onesLike(yPred));
    const predictedLabels = cast(tfc.greater(yPred, cutoff), yTrue.dtype);
    const matches = tfc.equal(yTrue, predictedLabels);
    return tfc.mean(matches, -1);
  });
}
/** 1.0 where argMax(yTrue) equals argMax(yPred) along the last axis, else 0.0. */
function categoricalAccuracy(yTrue, yPred) {
  return tfc.tidy(() => {
    const trueClass = tfc.argMax(yTrue, -1);
    const predClass = tfc.argMax(yPred, -1);
    return cast(tfc.equal(trueClass, predClass), "float32");
  });
}
/** Count (as float32) of positions where yTrue == 1 and yPred == 1. */
function truePositives(yTrue, yPred) {
  return tfc.tidy(() => {
    const hits = tfc.logicalAnd(yTrue.equal(1), yPred.equal(1));
    return hits.sum().cast("float32");
  });
}
/** Count (as float32) of positions where yTrue == 1 and yPred == 0. */
function falseNegatives(yTrue, yPred) {
  return tfc.tidy(() => {
    const misses = tfc.logicalAnd(yTrue.equal(1), yPred.equal(0));
    return misses.sum().cast("float32");
  });
}
/** Counts positions where yTrue equals 0 but yPred equals 1; returns a float32 scalar tensor. */
function falsePositives(yTrue, yPred) {
  return tfc.tidy(() => {
    const spurious = tfc.logicalAnd(yTrue.equal(0), yPred.equal(1));
    return spurious.sum().cast("float32");
  });
}
/**
 * Precision = TP / (TP + FP), reported as 0 when TP + FP is 0.
 * Returns a float32 tensor.
 */
function precision(yTrue, yPred) {
  return tfc.tidy(() => {
    const tp = truePositives(yTrue, yPred);
    const fp = falsePositives(yTrue, yPred);
    const predictedPositives = tp.add(fp);
    // Select 0 instead of the 0/0 result when nothing was predicted positive.
    const hasPositives = tfc.greater(predictedPositives, 0);
    const ratio = tp.div(predictedPositives);
    return tfc.where(hasPositives, ratio, 0).cast("float32");
  });
}
/**
 * Recall = TP / (TP + FN), reported as 0 when TP + FN is 0.
 * Returns a float32 tensor.
 */
function recall(yTrue, yPred) {
  return tfc.tidy(() => {
    const tp = truePositives(yTrue, yPred);
    const fn = falseNegatives(yTrue, yPred);
    const actualPositives = tp.add(fn);
    // Select 0 instead of the 0/0 result when there are no actual positives.
    const hasPositives = tfc.greater(actualPositives, 0);
    const ratio = tp.div(actualPositives);
    return tfc.where(hasPositives, ratio, 0).cast("float32");
  });
}
/** Metric wrapper that delegates directly to the binaryCrossentropy loss. */
function binaryCrossentropy$1(yTrue, yPred) {
  const lossValue = binaryCrossentropy(yTrue, yPred);
  return lossValue;
}
/**
 * Sparse categorical accuracy: compares integer class labels in yTrue with
 * the argmax over the last axis of yPred.
 *
 * When yTrue has the same rank as yPred, its last axis is squeezed first
 * (presumably a trailing size-1 label axis; TODO confirm against callers).
 * The predicted labels are cast to yTrue's dtype before comparison.
 *
 * Runs inside tfc.tidy like the other metrics in this file, so the squeeze,
 * argMax and cast intermediates are released even when the metric is called
 * outside an enclosing tidy().
 *
 * @returns Element-wise 1.0/0.0 match indicator as a float32 tensor.
 */
function sparseCategoricalAccuracy(yTrue, yPred) {
  return tfc.tidy(function() {
    var labels = yTrue;
    if (labels.rank === yPred.rank) {
      labels = labels.squeeze([labels.rank - 1]);
    }
    var predictedLabels = yPred.argMax(-1);
    if (predictedLabels.dtype !== labels.dtype) {
      predictedLabels = predictedLabels.asType(labels.dtype);
    }
    return tfc.equal(labels, predictedLabels).asType("float32");
  });
}
// Short aliases that may be used as metric identifiers in addition to the
// full function names.
var mse = meanSquaredError;
var MSE = meanSquaredError;
var mae = meanAbsoluteError;
var MAE = meanAbsoluteError;
var mape = meanAbsolutePercentageError;
var MAPE = meanAbsolutePercentageError;
var categoricalCrossentropy$1 = categoricalCrossentropy;
var cosine = cosineProximity;
var sparseCategoricalCrossentropy$1 = sparseCategoricalCrossentropy;
// Registry of built-in metrics keyed by the string identifiers get$1() accepts.
// Because getLossOrMetricName() checks lossesMap first, aliases such as `mse`
// that point at a loss function resolve to the loss name ("meanSquaredError").
// NOTE(review): recall, sparseCategoricalAccuracy and binaryCrossentropy$1 are
// defined above but are not registered here; confirm this is intentional.
var metricsMap = {
binaryAccuracy,
categoricalAccuracy,
precision,
categoricalCrossentropy: categoricalCrossentropy$1,
sparseCategoricalCrossentropy: sparseCategoricalCrossentropy$1,
mse,
MSE,
mae,
MAE,
mape,
MAPE,
cosine
};
/**
 * Resolves a metric identifier to a metric function.
 *
 * @param identifier The string name of a built-in metric (an own key of
 *   `metricsMap`) or a non-null metric function, which is returned unchanged.
 * @returns The resolved metric function.
 * @throws ValueError for unknown string names and for null/undefined.
 */
function get$1(identifier) {
  // Own-property check so inherited Object.prototype members ("toString",
  // "hasOwnProperty", ...) are rejected rather than returned as metrics.
  if (typeof identifier === "string" && Object.prototype.hasOwnProperty.call(metricsMap, identifier)) {
    return metricsMap[identifier];
  } else if (typeof identifier !== "string" && identifier != null) {
    return identifier;
  } else {
    throw new ValueError("Unknown metric " + identifier);
  }
}
/**
 * Returns a display name for a loss or metric.
 *
 * Strings are returned as-is. For functions, the first key (in Object.keys
 * order) of lossesMap, then of metricsMap, whose value is `fn` is returned.
 * Otherwise the result falls back to `fn.name`.
 *
 * @param fn Loss/metric name or function; must not be null.
 */
function getLossOrMetricName(fn) {
  assert(fn !== null, "Unknown LossOrMetricFn " + fn);
  if (typeof fn === "string") {
    return fn;
  }
  const keyFor = (registry) => Object.keys(registry).find((key) => registry[key] === fn);
  const lossName = keyFor(lossesMap);
  if (lossName !== void 0) {
    return lossName;
  }
  const metricName = keyFor(metricsMap);
  return metricName !== void 0 ? metricName : fn.name;
}
/**
* @license
* Copyright 2018 Google LLC
*
* Use of this source code is governed by an MIT-style
* license that can be found in the LICENSE file or at
* https://opensource.org/licenses/MIT.
* =============================================================================
*/
/**
 * Creates an optimizer from its registered string name.
 *
 * Accepted names are "Adagrad", "Adadelta", "Adam", "Adamax", "RMSProp" and
 * "SGD", plus their all-lowercase aliases. Each call invokes the matching
 * factory below, which passes the hyperparameters written there to the
 * corresponding `tfc.train.*` constructor.
 *
 * @param identifier Optimizer name.
 * @returns Whatever the matching `tfc.train.*` factory returns.
 * @throws ValueError if `identifier` is not a registered name.
 */
function getOptimizer(identifier) {
  var optimizerMap = {
    Adagrad: function() {
      return tfc.train.adagrad(0.01);
    },
    Adadelta: function() {
      return tfc.train.adadelta(1, 0.95, epsilon());
    },
    Adam: function() {
      return tfc.train.adam(1e-3, 0.9, 0.999, epsilon());
    },
    Adamax: function() {
      return tfc.train.adamax(2e-3, 0.9, 0.999, epsilon(), 0);
    },
    RMSProp: function() {
      return tfc.train.rmsprop(1e-3, 0.9, 0, epsilon());
    },
    SGD: function() {
      return tfc.train.sgd(0.01);
    }
  };
  optimizerMap["adagrad"] = optimizerMap["Adagrad"];
  optimizerMap["adadelta"] = optimizerMap["Adadelta"];
  optimizerMap["adam"] = optimizerMap["Adam"];
  optimizerMap["adamax"] = optimizerMap["Adamax"];
  optimizerMap["rmsprop"] = optimizerMap["RMSProp"];
  optimizerMap["sgd"] = optimizerMap["SGD"];
  // Own-property lookup: with `in`, names inherited from Object.prototype
  // (e.g. "toString") would be invoked as factories instead of rejected.
  if (Object.prototype.hasOwnProperty.call(optimizerMap, identifier)) {
    return optimizerMap[identifier]();
  }
  throw new ValueError("Unknown Optimizer " + identifier);
}
/**
* @license
* Copyright 2019 Google LLC
*
* Use of this source code is governed by an MIT-style
* license that can be found in the LICENSE file or at
* https://opensource.org/licenses/MIT.
* =============================================================================
*/
// Soft limit (in JSON string length, i.e. UTF-16 code units) above which
// checkUserDefinedMetadata() warns.
var MAX_USER_DEFINED_METADATA_SERIALIZED_LENGTH = 1 * 1024 * 1024;
/**
 * Validates user-defined model metadata.
 *
 * @param userDefinedMetadata Must be a non-null plain object (prototype
 *   Object.prototype) that passes plainObjectCheck().
 * @param modelName Used only in the size warning.
 * @param checkSize When truthy, warns via console.warn if the JSON
 *   serialization is longer than MAX_USER_DEFINED_METADATA_SERIALIZED_LENGTH.
 * @throws Error if the metadata is not a plain JSON object.
 */
function checkUserDefinedMetadata(userDefinedMetadata, modelName, checkSize) {
  const isPlainJsonObject = userDefinedMetadata != null &&
      typeof userDefinedMetadata === "object" &&
      Object.getPrototypeOf(userDefinedMetadata) === Object.prototype &&
      plainObjectCheck(userDefinedMetadata);
  if (!isPlainJsonObject) {
    throw new Error("User-defined metadata is expected to be a JSON object, but is not.");
  }
  if (!checkSize) {
    return;
  }
  const serialized = JSON.stringify(userDefinedMetadata);
  if (serialized.length > MAX_USER_DEFINED_METADATA_SERIALIZED_LENGTH) {
    console.warn('User-defined metadata of model "' + modelName + '" is too large in ' + ("size (length=" + serialized.length + " when serialized). It is not ") + "recommended to store such large objects in user-defined metadata. Please make sure its serialized length is <= " + (MAX_USER_DEFINED_METADATA_SERIALIZED_LENGTH + "."));
  }
}
/**
 * Recursively checks that `x` is plain JSON-like data.
 *
 * Accepts null, strings, numbers, booleans, arrays of accepted values, and
 * objects whose prototype is exactly Object.prototype with accepted values.
 * Rejects undefined, functions, symbols, bigints, class instances and
 * null-prototype objects.
 *
 * @returns true when every nested value is accepted.
 */
function plainObjectCheck(x) {
  if (x === null) {
    return true;
  }
  const kind = typeof x;
  if (kind !== "object") {
    return kind === "string" || kind === "number" || kind === "boolean";
  }
  if (Object.getPrototypeOf(x) === Object.prototype) {
    for (const key of Object.keys(x)) {
      if (typeof key !== "string" || !plainObjectCheck(x[key])) {
        return false;
      }
    }
    return true;
  }
  if (!Array.isArray(x)) {
    return false;
  }
  // Index loop (not every()) so holes in sparse arrays are visited as
  // undefined and therefore rejected.
  for (let idx = 0; idx < x.length; idx++) {
    if (!plainObjectCheck(x[idx])) {
      return false;
    }
  }
  return true;
}
/**
* @license
* Copyright 2018 Google LLC
*
* Use of this source code is governed by an MIT-style
* license that can be found in the LICENSE file or at
* https://opensource.org/licenses/MIT.
* =============================================================================
*/
/**
 * Prints a Keras-style textual summary of a model.
 *
 * @param model2 Model exposing layers, nodesByDepth, nonTrainableWeights and
 *   checkTrainableWeightsConsistency().
 * @param lineLength Total line width; defaults to 65 for sequential-like
 *   models and 98 otherwise (a falsy value selects the default).
 * @param positions Column stops. If the last stop is <= 1 they are treated as
 *   fractions of lineLength and floored to absolute columns.
 * @param printFn Line sink; defaults to console.log.
 */
function printSummary(model2, lineLength, positions, printFn) {
  const emit = printFn === void 0 ? console.log : printFn;
  const sequentialLike = isModelSequentialLike(model2);
  const header = ["Layer (type)", "Output shape", "Param #"];
  let width;
  let columnStops;
  if (sequentialLike) {
    width = lineLength || 65;
    columnStops = positions || [0.45, 0.85, 1];
  } else {
    width = lineLength || 98;
    columnStops = positions || [0.33, 0.55, 0.67, 1];
  }
  if (columnStops[columnStops.length - 1] <= 1) {
    columnStops = columnStops.map((fraction) => Math.floor(width * fraction));
  }
  // Non-sequential models get an extra "Receives inputs" column, restricted
  // to the nodes that belong to this model.
  let relevantNodes;
  if (!sequentialLike) {
    header.push("Receives inputs");
    relevantNodes = [];
    for (const depth in model2.nodesByDepth) {
      relevantNodes.push(...model2.nodesByDepth[depth]);
    }
  }
  emit("_".repeat(width));
  printRow(header, columnStops, emit);
  emit("=".repeat(width));
  const layers = model2.layers;
  for (let idx = 0; idx < layers.length; idx++) {
    if (sequentialLike) {
      printLayerSummary(layers[idx], columnStops, emit);
    } else {
      printLayerSummaryWithConnections(layers[idx], columnStops, relevantNodes, emit);
    }
    const separator = idx === layers.length - 1 ? "=" : "_";
    emit(separator.repeat(width));
  }
  model2.checkTrainableWeightsConsistency();
  const trainableCount = countTrainableParams(model2);
  const nonTrainableCount = countParamsInWeights(model2.nonTrainableWeights);
  emit("Total params: " + (trainableCount + nonTrainableCount));
  emit("Trainable params: " + trainableCount);
  emit("Non-trainable params: " + nonTrainableCount);
  emit("_".repeat(width));
}
/**
 * Counts trainable parameters, using model2.collectedTrainableWeights when it
 * is set (non-null) and model2.trainableWeights otherwise.
 */
function countTrainableParams(model2) {
  const weights = model2.collectedTrainableWeights != null ? model2.collectedTrainableWeights : model2.trainableWeights;
  return countParamsInWeights(weights);
}
/**
 * Decides whether a model's graph is a simple chain, for summary formatting.
 *
 * Returns false if any depth in model2.nodesByDepth holds more than one node,
 * or a single node with more than one inbound layer. It also returns false if
 * any layer has more than one inbound node among those collected nodes.
 */
function isModelSequentialLike(model2) {
  const chainNodes = [];
  for (const depth in model2.nodesByDepth) {
    const atDepth = model2.nodesByDepth[depth];
    const branches = atDepth.length > 1;
    if (branches || (atDepth.length === 1 && atDepth[0].inboundLayers.length > 1)) {
      return false;
    }
    chainNodes.push(...atDepth);
  }
  const layers = model2.layers;
  for (let li = 0; li < layers.length; li++) {
    const inbound = layers[li].inboundNodes;
    let seenOne = false;
    for (let ni = 0; ni < inbound.length; ni++) {
      if (chainNodes.indexOf(inbound[ni]) === -1) {
        continue;
      }
      if (seenOne) {
        return false;
      }
      seenOne = true;
    }
  }
  return true;
}
/**
 * Prints one fixed-width table row.
 *
 * Each field is appended and the row is truncated/padded to end exactly at
 * positions[col]. Before appending any field after the first, the last
 * character of the row so far is overwritten with a space. This guarantees a
 * separator even when the previous field was truncated.
 *
 * @param printFn Line sink; defaults to console.log.
 */
function printRow(fields, positions, printFn) {
  const emit = printFn === void 0 ? console.log : printFn;
  let row = "";
  for (let col = 0; col < fields.length; col++) {
    const stop = positions[col];
    const prefix = col === 0 ? row : row.slice(0, row.length - 1) + " ";
    const clipped = (prefix + fields[col]).slice(0, stop);
    row = clipped + " ".repeat(stop - clipped.length);
  }
  emit(row);
}
/**
 * Prints the summary row for one layer of a sequential-like model:
 * "name (ClassName)", JSON output shape, and parameter count.
 * The shape falls back to "multiple" if reading or serializing
 * layer.outputShape throws.
 */
function printLayerSummary(layer, positions, printFn) {
  let shapeText;
  try {
    shapeText = JSON.stringify(layer.outputShape);
  } catch (err) {
    shapeText = "multiple";
  }
  const label = layer.name + " (" + layer.getClassName() + ")";
  printRow([label, shapeText, layer.countParams().toString()], positions, printFn);
}
/**
 * Prints the summary rows for one layer of a non-sequential model.
 *
 * Each inbound connection is rendered as "layerName[nodeIndex][tensorIndex]".
 * When relevantNodes is non-empty, only inbound nodes contained in it are
 * listed. The first connection shares the layer's row; each further one gets
 * its own row with empty leading columns.
 */
function printLayerSummaryWithConnections(layer, positions, relevantNodes, printFn) {
  let shapeText;
  try {
    shapeText = JSON.stringify(layer.outputShape);
  } catch (err) {
    shapeText = "multiple";
  }
  const filterNodes = relevantNodes != null && relevantNodes.length > 0;
  const connections = [];
  for (const node of layer.inboundNodes) {
    if (filterNodes && relevantNodes.indexOf(node) === -1) {
      continue;
    }
    for (let k = 0; k < node.inboundLayers.length; k++) {
      connections.push(node.inboundLayers[k].name + "[" + node.nodeIndices[k] + "][" + node.tensorIndices[k] + "]");
    }
  }
  const firstConnection = connections.length > 0 ? connections[0] : "";
  const label = layer.name + " (" + layer.getClassName() + ")";
  printRow([label, shapeText, layer.countParams().toString(), firstConnection], positions, printFn);
  connections.slice(1).forEach((connection) => {
    printRow(["", "", "", connection], positions, printFn);
  });
}
/**
* @license
* Copyright 2018 Google LLC
*
* Use of this source code is governed by an MIT-style
* license that can be found in the LICENSE file or at
* https://opensource.org/licenses/MIT.
* =============================================================================
*/
/**
 * True for the first (index 0) string element of an "inboundNodes",
 * "outputLayers" or "inputLayers" array. The config converters in this file
 * copy such items verbatim instead of case-converting them.
 */
function isArrayItemInputOrOutputName(key, index, value) {
  if (index !== 0 || typeof value !== "string") {
    return false;
  }
  return key === "inboundNodes" || key === "outputLayers" || key === "inputLayers";
}
/**
 * Recursively converts a Python-style (snake_case) config into a TS-style
 * (camelCase) one.
 *
 * - null and undefined become null. Previously undefined fell through to
 *   Object.keys() and threw; this now matches convertTsToPythonic.
 * - Strings are passed through toCamelCase.
 * - Numbers and booleans are returned unchanged.
 * - Arrays are converted element-wise, except items for which
 *   isArrayItemInputOrOutputName(key, index, item) holds, which are kept as-is.
 * - Other values are treated as dicts: keys are camelCased, except a "name"
 *   key with a string value, whose key and value are both kept verbatim.
 *
 * @param pythonicConfig Value to convert.
 * @param key Key under which this value was found (used for array items).
 */
function convertPythonicToTs(pythonicConfig, key) {
  if (pythonicConfig === null || pythonicConfig === void 0) {
    return null;
  } else if (typeof pythonicConfig === "string") {
    return toCamelCase(pythonicConfig);
  } else if (typeof pythonicConfig === "number" || typeof pythonicConfig === "boolean") {
    return pythonicConfig;
  } else if (pythonicConfig instanceof Array) {
    var tsArray = [];
    var arrayLength = pythonicConfig.length;
    for (var i = 0; i < arrayLength; ++i) {
      var item = pythonicConfig[i];
      if (isArrayItemInputOrOutputName(key, i, item)) {
        tsArray.push(item);
      } else {
        tsArray.push(convertPythonicToTs(item, key));
      }
    }
    return tsArray;
  } else {
    var tsDict = {};
    for (var _i = 0, _a = Object.keys(pythonicConfig); _i < _a.length; _i++) {
      var pythonicKey = _a[_i];
      var pythonicValue = pythonicConfig[pythonicKey];
      if (pythonicKey === "name" && typeof pythonicValue === "string") {
        tsDict[pythonicKey] = pythonicValue;
      } else {
        var tsKey = toCamelCase(pythonicKey);
        tsDict[tsKey] = convertPythonicToTs(pythonicValue, tsKey);
      }
    }
    return tsDict;
  }
}
/**
 * Recursively converts a TS-style (camelCase) config into a Python-style
 * (snake_case) one.
 *
 * - null and undefined become null.
 * - Strings are passed through toSnakeCase; numbers/booleans are unchanged.
 * - Arrays are converted element-wise (holes are visited as undefined and so
 *   become null). Items matched by isArrayItemInputOrOutputName are kept as-is.
 * - Any other value is treated as a dict: keys are snake_cased. String values
 *   under "name" or "className" are kept verbatim; others are converted
 *   recursively.
 */
function convertTsToPythonic(tsConfig, key) {
  if (tsConfig === null || tsConfig === void 0) {
    return null;
  }
  switch (typeof tsConfig) {
    case "string":
      return toSnakeCase(tsConfig);
    case "number":
    case "boolean":
      return tsConfig;
  }
  if (tsConfig instanceof Array) {
    return Array.from(tsConfig, (item, idx) =>
      isArrayItemInputOrOutputName(key, idx, item) ? item : convertTsToPythonic(item, key));
  }
  const pyDict = {};
  Object.keys(tsConfig).forEach((tsKey) => {
    const tsValue = tsConfig[tsKey];
    const pyKey = toSnakeCase(tsKey);
    const keepVerbatim = (tsKey === "name" || tsKey === "className") && typeof tsValue === "string";
    pyDict[pyKey] = keepVerbatim ? tsValue : convertTsToPythonic(tsValue, tsKey);
  });
  return pyDict;
}
/** @license See the LICENSE file. */
// Library version string; presumably the version of the bundled tfjs-layers
// sources this section comes from (TODO confirm).
var version = "2.7.0";
/**
* @license
* Copyright 2018 Google LLC
*
* Use of this source code is governed by an MIT-style
* license that can be found in the LICENSE file or at
* https://opensource.org/licenses/MIT.
* =============================================================================
*/
/**
 * Makes a fed value compatible with its SymbolicTensor key.
 *
 * Returns `val` unchanged when key.dtype is null/undefined or already equals
 * val.dtype. Otherwise returns tfc.cast(val, key.dtype).
 *
 * @throws ValueError if the cast throws.
 */
function assertFeedCompatibility(key, val) {
  const needsCast = key.dtype != null && key.dtype !== val.dtype;
  if (!needsCast) {
    return val;
  }
  try {
    return tfc.cast(val, key.dtype);
  } catch (err) {
    throw new ValueError("The dtype of the feed (" + val.dtype + ") can not be cast to the dtype " + ("of the key '" + key.name + "' (" + key.dtype + ")."));
  }
}
var FeedDict = function() {
  /**
   * Mapping from SymbolicTensors to concrete values and optional masks, used
   * when executing a symbolic graph.
   *
   * Values and masks are stored by tensor id. `name2Id` maps tensor names to
   * ids so values can also be looked up by name.
   *
   * @param feeds Optional: another FeedDict to copy (values, masks and the
   *   name index), or an array of `{key, value}` feeds passed to add().
   */
  function FeedDict2(feeds) {
    this.id2Value = {};
    this.id2Mask = {};
    this.name2Id = {};
    if (feeds instanceof FeedDict2) {
      for (var id in feeds.id2Value) {
        this.id2Value[id] = feeds.id2Value[id];
        if (id in feeds.id2Mask) {
          this.id2Mask[id] = feeds.id2Mask[id];
        }
      }
      // Copy the name index as well. Without it the copy reports no names()
      // and string-keyed getValue()/getMask() fail for the copied feeds.
      for (var name in feeds.name2Id) {
        this.name2Id[name] = feeds.name2Id[name];
      }
    } else {
      if (feeds == null) {
        return;
      }
      for (var _i = 0; _i < feeds.length; _i++) {
        var feed = feeds[_i];
        this.add(feed.key, feed.value);
      }
    }
  }
  /**
   * Adds a value (passed through assertFeedCompatibility, which may cast it to
   * key.dtype) and an optional mask for `key`.
   *
   * @returns this, for chaining.
   * @throws ValueError if a value for key.id is already present.
   */
  FeedDict2.prototype.add = function(key, value, mask) {
    if (this.id2Value[key.id] == null) {
      this.id2Value[key.id] = assertFeedCompatibility(key, value);
      this.name2Id[key.name] = key.id;
      if (mask != null) {
        this.id2Mask[key.id] = mask;
      }
    } else {
      throw new ValueError("Duplicate key: name=" + key.name + ", id=" + key.id);
    }
    return this;
  };
  /** Adds a `{key, value}` feed; masks cannot be supplied through this method. */
  FeedDict2.prototype.addFeed = function(feed) {
    this.add(feed.key, feed.value);
  };
  /** True when a non-null value is stored for key.id. */
  FeedDict2.prototype.hasKey = function(key) {
    return this.id2Value[key.id] != null;
  };
  /** Names of all tensors that have been added (or copied) into this dict. */
  FeedDict2.prototype.names = function() {
    return Object.keys(this.name2Id);
  };
  /**
   * Looks up a value by SymbolicTensor (via its id) or by tensor name.
   * @throws ValueError if the key is not present.
   */
  FeedDict2.prototype.getValue = function(key) {
    if (key instanceof SymbolicTensor) {
      if (this.id2Value[key.id] == null) {
        throw new ValueError("Nonexistent key: " + key.name);
      } else {
        return this.id2Value[key.id];
      }
    } else {
      var id = this.name2Id[key];
      if (id == null) {
        throw new ValueError("Feed dict has no SymbolicTensor name: " + key);
      }
      return this.id2Value[id];
    }
  };
  /**
   * Looks up the mask (possibly undefined) by SymbolicTensor or tensor name.
   * @throws ValueError if no value is present for the key.
   */
  FeedDict2.prototype.getMask = function(key) {
    if (key instanceof SymbolicTensor) {
      if (this.id2Value[key.id] == null) {
        throw new ValueError("Nonexistent key: " + key.name);
      } else {
        return this.id2Mask[key.id];
      }
    } else {
      var id = this.name2Id[key];
      if (id == null) {
        throw new ValueError("Feed dict has no SymbolicTensor name: " + key);
      }
      return this.id2Mask[id];
    }
  };
  /** Disposes every stored mask, including masks copied from another FeedDict. */
  FeedDict2.prototype.disposeMasks = function() {
    if (this.id2Mask != null) {
      tfc.dispose(this.id2Mask);
    }
  };
  return FeedDict2;
}();
// Module-level caches keyed by "<fetch names>|<feed names>" (built in execute()).
// cachedSorted holds the topological order of the tensors to compute, and
// cachedRecipientCounts how many consumers each tensor has in that order.
var cachedSorted = {};
var cachedRecipientCounts = {};
/**
 * Executes the part of the symbolic graph needed to compute `fetches`, given
 * the concrete values in `feedDict`.
 *
 * @param fetches A SymbolicTensor or an array of SymbolicTensors to compute.
 * @param feedDict FeedDict supplying values (and optional masks) for some tensors.
 * @param kwargs Optional kwargs forwarded to every layer's apply(). When
 *   kwargs.training is truthy, intermediate tensors are not disposed early.
 * @param probe Optional object. When given, probe.maxNumTensors/minNumTensors are set
 *   to the max/min of tfc.memory().numTensors sampled at the start of each step.
 * @returns An array of values when `fetches` is an array, otherwise a single value.
 */
function execute(fetches, feedDict, kwargs, probe) {
var training = kwargs == null ? false : kwargs["training"];
var arrayFetches = Array.isArray(fetches);
var fetchArray = arrayFetches ? fetches : [fetches];
var outputNames = fetchArray.map(function(t) {
return t.name;
});
// Fetches that are themselves fed are answered straight from the feed;
// the rest start as null and are filled in as the graph is executed.
var finalOutputs = [];
var feedNames = feedDict.names();
for (var _i = 0, outputNames_1 = outputNames; _i < outputNames_1.length; _i++) {
var outputName = outputNames_1[_i];
if (feedNames.indexOf(outputName) !== -1) {
finalOutputs.push(feedDict.getValue(outputName));
} else {
finalOutputs.push(null);
}
}
if (probe != null) {
probe.maxNumTensors = -Infinity;
probe.minNumTensors = Infinity;
}
// NOTE(review): the cache key is built from tensor names only, so it assumes
// names identify tensors uniquely across the graphs executed here — confirm.
var fetchAndFeedKey = outputNames.join(",") + "|" + feedDict.names().join(",");
var sorted;
var recipientCounts;
if (cachedSorted[fetchAndFeedKey] == null) {
var out = getTopologicalSortAndRecipientCounts(fetchArray, feedDict);
sorted = out.sorted;
recipientCounts = out.recipientCounts;
cachedSorted[fetchAndFeedKey] = sorted;
cachedRecipientCounts[fetchAndFeedKey] = recipientCounts;
}
sorted = cachedSorted[fetchAndFeedKey];
// Per-call copy, because the counts are decremented below. In training mode
// the copy stays empty and no early disposal happens.
recipientCounts = {};
if (!training) {
Object.assign(recipientCounts, cachedRecipientCounts[fetchAndFeedKey]);
}
// Intermediate results go into a copy so the caller's feedDict is not extended.
var internalFeedDict = new FeedDict(feedDict);
for (var i = 0; i < sorted.length; ++i) {
if (probe != null) {
var numTensors = tfc.memory().numTensors;
if (numTensors > probe.maxNumTensors) {
probe.maxNumTensors = numTensors;
}
if (numTensors < probe.minNumTensors) {
probe.minNumTensors = numTensors;
}
}
var symbolic = sorted[i];
var srcLayer = symbolic.sourceLayer;
// Tensors produced by InputLayers are not computed here; they are skipped.
if (srcLayer instanceof InputLayer) {
continue;
}
var inputValues = [];
var inputMasks = [];
var tensorsToDispose = [];
var maskExists = false;
for (var _a = 0, _b = symbolic.inputs; _a < _b.length; _a++) {
var input2 = _b[_a];
var value = internalFeedDict.getValue(input2);
var mask = internalFeedDict.getMask(input2);
inputValues.push(value);
inputMasks.push(mask);
if (mask != null) {
maskExists = true;
}
if (!training) {
// Dispose an input once its last consumer has run, unless it was fed,
// is a requested output, is already disposed, or comes from a stateful layer.
recipientCounts[input2.name]--;
if (recipientCounts[input2.name] === 0 && !feedDict.hasKey(input2) && outputNames.indexOf(input2.name) === -1 && !value.isDisposed && input2.sourceLayer.stateful !== true) {
tensorsToDispose.push(value);
}
}
}
// Only the first input's mask is forwarded to the layer.
// NOTE(review): kwargs is mutated in place (the caller's object when one was
// passed), so a mask set here also stays set for later layers in this loop.
if (maskExists) {
kwargs = kwargs || {};
kwargs["mask"] = inputMasks[0];
}
var outputTensors = toList(srcLayer.apply(inputValues, kwargs));
var outputMask = null;
if (srcLayer.supportsMasking) {
outputMask = srcLayer.computeMask(inputValues, inputMasks);
}
var layerOutputs = getNodeOutputs(symbolic);
var outputSymbolicTensors = Array.isArray(layerOutputs) ? layerOutputs : [layerOutputs];
for (var i_1 = 0; i_1 < outputSymbolicTensors.length; ++i_1) {
// When computeMask returns an array, only its first element is stored, for every output.
if (!internalFeedDict.hasKey(outputSymbolicTensors[i_1])) {
internalFeedDict.add(outputSymbolicTensors[i_1], outputTensors[i_1], Array.isArray(outputMask) ? outputMask[0] : outputMask);
}
var index = outputNames.indexOf(outputSymbolicTensors[i_1].name);
if (index !== -1) {
finalOutputs[index] = outputTensors[i_1];
}
}
if (!training) {
tfc.dispose(tensorsToDispose);
}
}
// NOTE(review): internalFeedDict's masks include those copied from the
// caller's feedDict, so caller-provided masks are disposed here too.
internalFeedDict.disposeMasks();
return arrayFetches ? finalOutputs : finalOutputs[0];
}
/**
 * Topologically sorts the tensors needed for all `fetches` and counts how
 * many distinct recipients consume each tensor.
 *
 * For several fetches, the per-fetch orders are concatenated, keeping only
 * the first occurrence of each tensor name. The recipient sets are unioned
 * before being turned into counts.
 *
 * @returns {{sorted, recipientCounts}}
 */
function getTopologicalSortAndRecipientCounts(fetches, feedDict) {
  tfc.util.assert(fetches != null && fetches.length > 0, () => "Expected at least one fetch, got none");
  if (fetches.length === 1) {
    const only = getTopologicalSortAndRecipientCountsForOneFetch(fetches[0], feedDict);
    return {
      sorted: only.sorted,
      recipientCounts: recipientMap2Counts(only.recipientMap)
    };
  }
  const mergedSorted = [];
  const mergedRecipients = {};
  const emittedNames = new Set();
  for (let f = 0; f < fetches.length; f++) {
    const {sorted, recipientMap} = getTopologicalSortAndRecipientCountsForOneFetch(fetches[f], feedDict);
    for (const tensor of sorted) {
      if (!emittedNames.has(tensor.name)) {
        emittedNames.add(tensor.name);
        mergedSorted.push(tensor);
      }
    }
    for (const tensorName in recipientMap) {
      if (mergedRecipients[tensorName] == null) {
        mergedRecipients[tensorName] = new Set();
      }
      const target = mergedRecipients[tensorName];
      recipientMap[tensorName].forEach((recipient) => target.add(recipient));
    }
  }
  return {
    sorted: mergedSorted,
    recipientCounts: recipientMap2Counts(mergedRecipients)
  };
}
/**
 * Converts a map of tensor name -> Set of recipient names into a map of
 * tensor name -> number of recipients.
 */
function recipientMap2Counts(recipientMap) {
  const counts = {};
  for (const tensorName in recipientMap) {
    const recipients = recipientMap[tensorName];
    counts[tensorName] = recipients.size;
  }
  return counts;
}
/**
 * Iterative depth-first post-order walk from `fetch2` along `inputs` edges.
 *
 * Tensor names present in `feedDict` are pre-marked as visited. The walk
 * therefore stops at fed tensors, and they never appear in `sorted`.
 *
 * @returns {{sorted, recipientMap}} `sorted` lists each reached tensor after
 *   all of its unfed inputs. `recipientMap[name]` is the Set of names of the
 *   tensors that take tensor `name` as an input, including edges into fed or
 *   already-visited tensors.
 */
function getTopologicalSortAndRecipientCountsForOneFetch(fetch2, feedDict) {
var visited = new Set();
var sorted = [];
var recipientMap = {};
for (var _i = 0, _a = feedDict.names(); _i < _a.length; _i++) {
var key = _a[_i];
visited.add(key);
}
var stack = [];
// `marks` holds the stack indices of tensors whose inputs have already been
// pushed; when such a tensor is back on top, all its inputs are done.
var marks = [];
stack.push(fetch2);
while (stack.length > 0) {
var top_1 = stack[stack.length - 1];
if (visited.has(top_1.name)) {
// Duplicate stack entry (or fed tensor): already handled.
stack.pop();
continue;
}
var topIsMarked = marks[marks.length - 1] === stack.length - 1;
if (top_1.inputs.length === 0 || topIsMarked) {
// Leaf, or all inputs processed: emit in post-order.
stack.pop();
sorted.push(top_1);
visited.add(top_1.name);
if (topIsMarked) {
marks.pop();
}
} else {
// First visit: remember this position and push the unvisited inputs.
marks.push(stack.length - 1);
for (var _b = 0, _c = top_1.inputs; _b < _c.length; _b++) {
var input2 = _c[_b];
if (recipientMap[input2.name] == null) {
recipientMap[input2.name] = new Set();
}
recipientMap[input2.name].add(top_1.name);
if (visited.has(input2.name)) {
continue;
}
stack.push(input2);
}
}
}
return {sorted, recipientMap};
}
/**
 * Returns the output(s) of the source-layer node that produced `fetch2`.
 *
 * With exactly one inbound node this is sourceLayer.output. Otherwise it is
 * sourceLayer.getOutputAt(i), where i is the last inbound node whose
 * outputTensors contain a tensor with fetch2's id. If there is no such node,
 * it is getOutputAt(null).
 */
function getNodeOutputs(fetch2) {
  const layer = fetch2.sourceLayer;
  if (layer.inboundNodes.length === 1) {
    return layer.output;
  }
  // Scan from the end and stop at the first hit, which is the last match.
  let nodeIndex = null;
  for (let i = layer.inboundNodes.length - 1; i >= 0 && nodeIndex === null; i--) {
    const producedTensors = layer.inboundNodes[i].outputTensors;
    for (let j = 0; j < producedTensors.length; j++) {
      if (producedTensors[j].id === fetch2.id) {
        nodeIndex = i;
        break;
      }
    }
  }
  return layer.getOutputAt(nodeIndex);
}
/**
* @license
* Copyright 2018 Google LLC
*
* Use of this source code is governed by an MIT-style
* license that can be found in the LICENSE file or at
* https://opensource.org/licenses/MIT.
* =============================================================================
*/
var Container = function(_super) {
__extends(Container2, _super);
function Container2(args) {
var _this = _super.call(this, {}) || this;
_this.containerNodes = new Set();
_this.name = args.name;
if (_this.name == null) {
var prefix = _this.getClassName().toLowerCase();
_this.name = getUid(prefix);
}
_this.supportsMasking = false;
_this.trainable_ = true;
if (Array.isArray(args.inputs)) {
_this.inputs = args.inputs.slice();
} else {
_this.inputs = [args.inputs];
}
if (Array.isArray(args.outputs)) {
_this.outputs = args.outputs.slice();
} else {
_this.outputs = [args.outputs];
}
if (unique(_this.inputs).length !== _this.inputs.length) {
throw new ValueError("The list of inputs passed to the model is redundant. All inputs should only appear once. Found: " + ("" + _this.inputs.map(function(x2) {
return x2.name;
})));
}
if (unique(_this.outputs).length !== _this.outputs.length) {
console.warn("The list of outputs passed to the model is redundant. All outputs should only appear once. Found: " + ("" + _this.outputs.map(function(x2) {
return x2.name;
})));
}
_this.inputLayers = [];
_this.inputLayersNodeIndices = [];
_this.inputLayersTensorIndices = [];
_this.outputLayers = [];
_this.outputLayersNodeIndices = [];
_this.outputLayersTensorIndices = [];
_this.layers = [];
_this.internalContainerRefs = [];
for (var _i = 0, _a = _this.outputs; _i < _a.length; _i++) {
var x = _a[_i];
var layer = x.sourceLayer;
var nodeIndex = x.nodeIndex;
var tensorIndex = x.tensorIndex;
_this.outputLayers.push(layer);
_this.outputLayersNodeIndices.push(nodeIndex);
_this.outputLayersTensorIndices.push(tensorIndex);
}
for (var _b = 0, _c = _this.inputs; _b < _c.length; _b++) {
var x = _c[_b];
var layer = x.sourceLayer;
var nodeIndex = x.nodeIndex;
var tensorIndex = x.tensorIndex;
assert(nodeIndex === 0, "input layer has >1 nodes");
assert(tensorIndex === 0, "input layer has >1 tensors");
_this.inputLayers.push(layer);
_this.inputLayersNodeIndices.push(nodeIndex);
_this.inputLayersTensorIndices.push(tensorIndex);
}
_this.inputNames = [];
_this.outputNames = [];
_this.feedInputShapes = [];
_this.feedInputNames = [];
_this.feedOutputNames = [];
for (var i = 0; i < _this.inputLayers.length; i++) {
var layer = _this.inputLayers[i];
if (!(layer instanceof InputLayer)) {
throw new TypeError("Input layers to a LayersModel must be InputLayer objects. " + ("Received inputs: " + args.inputs + ". ") + ("Input " + i + " (0-based) originates ") + ("from layer type " + layer.getClassName() + "."));
}
_this.inputNames.push(layer.name);
_this.feedInputShapes.push(layer.batchInputShape);
_this.feedInputNames.push(layer.name);
}
for (var _d = 0, _e = _this.outputLayers; _d < _e.length; _d++) {
var layer = _e[_d];
_this.outputNames.push(layer.name);
}
_this.internalInputShapes = _this.inputs.map(function(x2) {
return x2.shape;
});
_this.internalOutputShapes = _this.outputs.map(function(x2) {
return x2.shape;
});
var nodesDepths = {};
var nodeIDToNode = {};
var layersDepths = {};
var layerIDToLayer = {};
var layerIndices = {};
var nodesInDecreasingDepth = [];
var buildMapOfGraph = function(tensor, finishedNodes2, nodesInProgress2, layer2, nodeIndex2, tensorIndex2) {
if (layer2 == null || nodeIndex2 == null || tensorIndex2 == null) {
layer2 = tensor.sourceLayer;
nodeIndex2 = tensor.nodeIndex;
tensorIndex2 = tensor.tensorIndex;
}
var node2 = layer2.inboundNodes[nodeIndex2];
if (nodesInProgress2.indexOf(node2) !== -1) {
throw new RuntimeError("The tensor " + tensor.name + ' at layer "' + layer2.name + '" is part of a cycle.');
}
if (finishedNodes2.indexOf(node2) !== -1) {
return;
}
_this.containerNodes.add(Container2.nodeKey(layer2, nodeIndex2));
if (!(layer2.id in layerIndices)) {
layerIndices[layer2.id] = Object.keys(layerIndices).length;
}
if (nodesInProgress2.indexOf(node2) === -1) {
nodesInProgress2.push(node2);
}
var numInboundLayers = node2.inboundLayers.length;
for (var i2 = 0; i2 < numInboundLayers; i2++) {
var x2 = node2.inputTensors[i2];
var layer_1 = node2.inboundLayers[i2];
var nodeIndex_1 = node2.nodeIndices[i2];
var tensorIndex_1 = node2.tensorIndices[i2];
buildMapOfGraph(x2, finishedNodes2, nodesInProgress2, layer_1, nodeIndex_1, tensorIndex_1);
}
finishedNodes2.push(node2);
while (nodesInProgress2.indexOf(node2) >= 0) {
nodesInProgress2.splice(nodesInProgress2.indexOf(node2), 1);
}
nodesInDecreasingDepth.push(node2);
};
var finishedNodes = [];
var nodesInProgress = [];
for (var _f = 0, _g = _this.outputs; _f < _g.length; _f++) {
var x = _g[_f];
buildMapOfGraph(x, finishedNodes, nodesInProgress);
}
var reversedNodesInDecreasingDepth = nodesInDecreasingDepth.slice().reverse();
for (var _h = 0, reversedNodesInDecreasingDepth_1 = reversedNodesInDecreasingDepth; _h < reversedNodesInDecreasingDepth_1.length; _h++) {
var node = reversedNodesInDecreasingDepth_1[_h];
nodeIDToNode[node.id] = node;
if (!(node.id in nodesDepths)) {
nodesDepths[node.id] = 0;
}
var depth = nodesDepths[node.id];
var previousDepth = layersDepths[node.outboundLayer.id] == null ? 0 : layersDepths[node.outboundLayer.id];
depth = Math.max(depth, previousDepth);
layersDepths[node.outboundLayer.id] = depth;
layerIDToLayer[node.outboundLayer.id] = node.outboundLayer;
nodesDepths[node.id] = depth;
for (var i = 0; i < node.inboundLayers.length; i++) {
var inboundLayer = node.inboundLayers[i];
var nodeIndex = node.nodeIndices[i];
var inboundNode = inboundLayer.inboundNodes[nodeIndex];
var previousDepth_1 = nodesDepths[inboundNode.id] == null ? 0 : nodesDepths[inboundNode.id];
nodesDepths[inboundNode.id] = Math.max(depth + 1, previousDepth_1);
nodeIDToNode[inboundNode.id] = inboundNode;
}
}
var nodesByDepth = {};
for (var nodeID in nodesDepths) {
var depth = nodesDepths[nodeID];
if (!(depth in nodesByDepth)) {
nodesByDepth[depth] = [];
}
nodesByDepth[depth].push(nodeIDToNode[nodeID]);
}
var layersByDepth = {};
for (var layerID in layersDepths) {
var depth = layersDepths[layerID];
if (!(depth in layersByDepth)) {
layersByDepth[depth] = [];
}
layersByDepth[depth].push(layerIDToLayer[layerID]);
}
var depthKeys = Object.keys(layersByDepth).map(function(x2) {
return parseInt(x2, 10);
}).sort(reverseNumberCompare);
_this.layers = [];
for (var _j = 0, depthKeys_1 = depthKeys; _j < depthKeys_1.length; _j++) {
var depth = depthKeys_1[_j];
var layersForDepth = layersByDepth[depth];
layersForDepth.sort(function(a, b) {
var aIndex = layerIndices[a.id];
var bIndex = layerIndices[b.id];
if (aIndex < bIndex) {
return -1;
}
if (aIndex > bIndex) {
return 1;
}
return 0;
});
for (var _k = 0, layersForDepth_1 = layersForDepth; _k < layersForDepth_1.length; _k++) {
var layer = layersForDepth_1[_k];
if (layer instanceof Container2) {
_this.internalContainerRefs.push(layer);
}
_this.layers.push(layer);
}
}
_this.layersByDepth = layersByDepth;
depthKeys = Object.keys(nodesByDepth).map(function(x2) {
return parseInt(x2, 10);
}).sort(reverseNumberCompare);
var computableTensors = _this.inputs.slice();
var layersWithCompleteInput = [];
for (var _l = 0, depthKeys_2 = depthKeys; _l < depthKeys_2.length; _l++) {
var depth = depthKeys_2[_l];
for (var _m = 0, _o = nodesByDepth[depth]; _m < _o.length; _m++) {
var node = _o[_m];
var layer = node.outboundLayer;
if (layer != null) {
for (var _p = 0, _q = node.inputTensors; _p < _q.length; _p++) {
var x = _q[_p];
if (computableTensors.indexOf(x) === -1) {
throw new RuntimeError("Graph disconnected: cannot obtain value for tensor " + x + (' at layer "' + layer.name + '". ') + "The following previous layers were accessed without " + ("issue: " + layersWithCompleteInput));
}
}
for (var _r = 0, _s = node.outputTensors; _r < _s.length; _r++) {
var x = _s[_r];
computableTensors.push(x);
}
layersWithCompleteInput.push(layer.name);
}
}
}
_this.nodesByDepth = nodesByDepth;
var allNames = _this.layers.map(function(x2) {
return x2.name;
});
var _loop_1 = function(name_12) {
var numOccurrences = allNames.filter(function(x2) {
return x2 === name_12;
}).length;
if (numOccurrences !== 1) {
throw new RuntimeError('The name "' + name_12 + '" is used ' + numOccurrences + " times in the model. All layer names should be unique. Layer names: " + JSON.stringify(allNames));
}
};
for (var _t = 0, allNames_1 = allNames; _t < allNames_1.length; _t++) {
var name_1 = allNames_1[_t];
_loop_1(name_1);
}
_this.outboundNodes = [];
_this.inboundNodes = [];
new Node({
outboundLayer: _this,
inboundLayers: [],
nodeIndices: [],
tensorIndices: [],
inputTensors: _this.inputs,
outputTensors: _this.outputs,
inputMasks: _this.inputs.map(function(x2) {
return null;
}),
outputMasks: _this.outputs.map(function(x2) {
return null;
}),
inputShapes: _this.inputs.map(function(x2) {
return x2.shape;
}),
outputShapes: _this.outputs.map(function(x2) {
return x2.shape;
})
});
_this.built = true;
_this._refCount = 1;
return _this;
}
Container2.prototype.assertNotDisposed = function() {
if (this._refCount === 0) {
throw new Error("Container '" + this.name + "' is already disposed.");
}
};
Container2.prototype.dispose = function() {
this.assertNotDisposed();
var result = {refCountAfterDispose: null, numDisposedVariables: 0};
if (--this._refCount === 0) {
for (var _i = 0, _a = this.layers; _i < _a.length; _i++) {
var layer = _a[_i];
result.numDisposedVariables += layer.dispose().numDisposedVariables;
}
for (var _b = 0, _c = this.internalContainerRefs; _b < _c.length; _b++) {
var container = _c[_b];
result.numDisposedVariables += container.dispose().numDisposedVariables;
}
}
result.refCountAfterDispose = this._refCount;
return result;
};
Object.defineProperty(Container2.prototype, "trainable", {
get: function() {
return this.trainable_;
},
set: function(trainable) {
this.layers.forEach(function(layer) {
layer._trainableWeights.forEach(function(w) {
return w.trainable = trainable;
});
});
this.trainable_ = trainable;
},
enumerable: true,
configurable: true
});
// `trainableWeights` of a Container is the concatenation of its layers'
// trainable weights (empty when the container itself is not trainable).
// The container must never hold trainable weights of its own.
Object.defineProperty(Container2.prototype, "trainableWeights", {
  get() {
    if (this._trainableWeights.length > 0) {
      throw new ValueError("Container instance unexpectedly contains _trainableWeights.The trainable weights of a Container are a union of the trainable weights of its consituent Layers. Its own _trainableWeights must remain an empty Array.");
    }
    if (!this.trainable) {
      return [];
    }
    return this.layers.reduce((collected, layer) => collected.concat(layer.trainableWeights), []);
  },
  enumerable: true,
  configurable: true
});
// `nonTrainableWeights` of a Container: every layer's non-trainable weights.
// When the container is frozen (trainable falsy), the layers' trainable
// weights are reported as non-trainable too, placed before the others.
Object.defineProperty(Container2.prototype, "nonTrainableWeights", {
  get() {
    const gather = (propertyName) => {
      const gathered = [];
      for (const layer of this.layers) {
        Array.prototype.push.apply(gathered, layer[propertyName]);
      }
      return gathered;
    };
    const frozenWeights = gather("nonTrainableWeights");
    if (this.trainable) {
      return frozenWeights;
    }
    return gather("trainableWeights").concat(frozenWeights);
  },
  enumerable: true,
  configurable: true
});
// `weights`: all trainable weights followed by all non-trainable weights.
Object.defineProperty(Container2.prototype, "weights", {
  get() {
    const trainable = this.trainableWeights;
    const nonTrainable = this.nonTrainableWeights;
    return trainable.concat(nonTrainable);
  },
  enumerable: true,
  configurable: true
});
/**
 * Assigns values to this container's weights, matched by `originalName`.
 *
 * A provided name that matches nothing directly is retried with the two path
 * segments before its last segment removed (e.g. "a/b/c/d" -> "a/d").
 *
 * @param weights Map from weight name to value.
 * @param strict Defaults to true. When true, every provided name must match a
 *   weight and every weight must receive a value.
 * @throws ValueError on duplicate weight names, and in strict mode on
 *   unmatched or missing weights.
 */
Container2.prototype.loadWeights = function(weights, strict) {
  const isStrict = strict === undefined ? true : strict;
  const nameToWeight = {};
  let totalWeightsCount = 0;
  for (const layer of this.layers) {
    for (const weight of layer.weights) {
      const key = weight.originalName;
      if (nameToWeight[key] != null) {
        throw new ValueError("Duplicate weight name: " + key);
      }
      nameToWeight[key] = weight;
      totalWeightsCount += 1;
    }
  }
  const weightValueTuples = [];
  for (const providedName in weights) {
    let targetName = providedName;
    if (nameToWeight[providedName] == null) {
      const parts = providedName.split("/");
      targetName = parts.slice(0, -2).concat([parts[parts.length - 1]]).join("/");
    }
    const target = nameToWeight[targetName];
    if (target != null) {
      weightValueTuples.push([target, weights[providedName]]);
    } else if (isStrict) {
      throw new ValueError("Provided weight data has no target variable: " + providedName);
    }
    // Remove the match so the strict check below only sees unset weights.
    delete nameToWeight[targetName];
  }
  if (isStrict) {
    const unsetNames = [];
    for (const leftover in nameToWeight) {
      unsetNames.push(leftover);
    }
    if (unsetNames.length > 0) {
      throw new ValueError(`${unsetNames.length} of ${totalWeightsCount} weights are not set: ${unsetNames}`);
    }
  }
  batchSetValue(weightValueTuples);
};
/**
 * Wraps getConfig() with class name, library version and backend metadata.
 * @returns {{className, config, kerasVersion, backend}}
 */
Container2.prototype.updatedConfig = function() {
  const theConfig = this.getConfig();
  return {
    className: this.getClassName(),
    config: theConfig,
    kerasVersion: "tfjs-layers " + version,
    backend: "TensorFlow.js"
  };
};
/**
 * Serializes the model topology in the Python-style key convention.
 *
 * @param unused Ignored.
 * @param returnString Defaults to true; when falsy the converted config
 *   object is returned instead of its JSON string.
 */
Container2.prototype.toJSON = function(unused, returnString) {
  const asString = returnString === undefined ? true : returnString;
  const modelConfig = convertTsToPythonic(this.updatedConfig());
  if (asString) {
    return JSON.stringify(modelConfig);
  }
  return modelConfig;
};
/**
 * Runs the container on concrete inputs. Inside a tidy scope, it binds each
 * symbolic input to the matching concrete input and executes the graph up to
 * this.outputs.
 *
 * @param inputs A tensor or list of tensors, positionally matched to this.inputs.
 * @param kwargs Forwarded to execute().
 */
Container2.prototype.call = function(inputs, kwargs) {
  return tfc.tidy(() => {
    const inputList = toList(inputs);
    const feedDict = new FeedDict();
    this.inputs.forEach((symbolicInput, position) => {
      feedDict.add(symbolicInput, inputList[position]);
    });
    return execute(this.outputs, feedDict, kwargs);
  });
};
/**
 * Computes the output masks by running the internal graph.
 *
 * @param inputs A tensor or list of tensors.
 * @param mask A mask or list of masks; when null every input is unmasked.
 * @returns The masks part (index 1) of runInternalGraph()'s result.
 */
Container2.prototype.computeMask = function(inputs, mask) {
  return tfc.tidy(() => {
    const inputList = toList(inputs);
    const maskList = mask == null ? pyListRepeat(null, inputList.length) : toList(mask);
    const graphResult = this.runInternalGraph(inputList, maskList);
    return graphResult[1];
  });
};
/**
 * Infers this container's output shape(s) from the given input shape(s). It
 * walks the layer graph depth by depth and calls each layer's own
 * computeOutputShape().
 *
 * @param inputShape A shape or list of shapes. After normalizeShapeList() it
 *   must have one entry per input layer.
 * @returns singletonOrArray() of the shapes recorded for each output layer.
 * @throws ValueError when the number of input shapes differs from the number
 *   of input layers.
 */
Container2.prototype.computeOutputShape = function(inputShape) {
var inputShapes = normalizeShapeList(inputShape);
if (inputShapes.length !== this.inputLayers.length) {
throw new ValueError("Invalid inputShape argument " + inputShape + ": " + ("model has " + this.inputLayers.length + " tensor inputs."));
}
// Shapes are keyed as "<layerName>_<nodeIndex>_<tensorIndex>".
// Each input layer is seeded under node 0 / tensor 0.
var layersToOutputShapes = {};
for (var i = 0; i < inputShapes.length; i++) {
var layer = this.inputLayers[i];
var inputShape_1 = inputShapes[i];
var shapeKey = layer.name + "_0_0";
layersToOutputShapes[shapeKey] = inputShape_1;
}
// Depth keys are parsed back to numbers and ordered with reverseNumberCompare.
// Presumably this is descending, so that producers are visited before the
// nodes that consume them (TODO confirm against reverseNumberCompare).
var depthKeys = Object.keys(this.nodesByDepth).map(function(x) {
return parseInt(x, 10);
}).sort(reverseNumberCompare);
// NOTE(review): with a single depth level no propagation happens. Only the
// seeded input-layer shapes are then available to the output lookup below.
if (depthKeys.length > 1) {
for (var _i = 0, depthKeys_3 = depthKeys; _i < depthKeys_3.length; _i++) {
var depth = depthKeys_3[_i];
var nodes = this.nodesByDepth[depth];
for (var _a = 0, nodes_1 = nodes; _a < nodes_1.length; _a++) {
var node = nodes_1[_a];
var layer = node.outboundLayer;
// Input layers were already seeded above; skip them.
if (this.inputLayers.map(function(x) {
return x.id;
}).indexOf(layer.id) !== -1) {
continue;
}
// Collect the previously recorded shape of every tensor feeding this node.
var inputShapes_1 = [];
for (var j = 0; j < node.inboundLayers.length; j++) {
var inboundLayer = node.inboundLayers[j];
var nodeIndex_2 = node.nodeIndices[j];
var tensorIndex = node.tensorIndices[j];
var shapeKey = inboundLayer.name + "_" + nodeIndex_2 + "_" + tensorIndex;
var inputShape_2 = layersToOutputShapes[shapeKey];
inputShapes_1.push(inputShape_2);
}
// Record each output shape under this node's position in layer.inboundNodes.
var outputShape = layer.computeOutputShape(singletonOrArray(inputShapes_1));
var outputShapes_1 = normalizeShapeList(outputShape);
var nodeIndex = layer.inboundNodes.indexOf(node);
for (var j = 0; j < outputShapes_1.length; j++) {
var shapeKey = layer.name + "_" + nodeIndex + "_" + j;
layersToOutputShapes[shapeKey] = outputShapes_1[j];
}
}
}
}
// Look up the shape for every declared output endpoint. Each one must have
// been recorded, otherwise assert() fails.
var outputShapes = [];
var outputShapeKeys = [];
for (var i = 0; i < this.outputLayers.length; i++) {
var layer = this.outputLayers[i];
var nodeIndex = this.outputLayersNodeIndices[i];
var tensorIndex = this.outputLayersTensorIndices[i];
var shapeKey = layer.name + "_" + nodeIndex + "_" + tensorIndex;
outputShapeKeys.push(shapeKey);
}
for (var i = 0; i < outputShapeKeys.length; i++) {
var key = outputShapeKeys[i];
assert(key in layersToOutputShapes);
outputShapes.push(layersToOutputShapes[key]);
}
return singletonOrArray(outputShapes);
};
/**
 * Executes the container's layer graph on concrete tensors.
 *
 * Nodes are visited depth by depth in the order given by
 * reverseNumberCompare. A node runs once every one of its reference input
 * tensors has a computed value. Its outputs and masks are then recorded for
 * downstream nodes.
 *
 * @param inputs Concrete tensors, positionally matched to this.inputs.
 * @param masks Masks matching `inputs`; when null every input is unmasked.
 * @returns [outputTensors, outputMasks, outputShapes], ordered like this.outputs.
 * @throws NotImplementedError if a layer has an activity regularizer.
 */
Container2.prototype.runInternalGraph = function(inputs, masks) {
  if (masks == null) {
    masks = pyListRepeat(null, inputs.length);
  }
  // Symbolic tensor id -> [concrete tensor, mask].
  const tensorMap = {};
  for (let i = 0; i < this.inputs.length; ++i) {
    tensorMap[this.inputs[i].id] = [inputs[i], masks[i]];
  }
  const depthKeys = Object.keys(this.nodesByDepth).map((key) => parseInt(key, 10)).sort(reverseNumberCompare);
  for (const depth of depthKeys) {
    for (const node of this.nodesByDepth[depth]) {
      const layer = node.outboundLayer;
      const referenceInputTensors = node.inputTensors;
      const referenceOutputTensors = node.outputTensors;
      const computedData = [];
      for (const x of referenceInputTensors) {
        if (x.id in tensorMap) {
          computedData.push(tensorMap[x.id]);
        }
      }
      // Skip nodes whose inputs are not all available yet.
      if (computedData.length !== referenceInputTensors.length) {
        continue;
      }
      // Work on a shallow copy of the stored call arguments. Injecting "mask"
      // into node.callArgs itself would make every later run reuse the first
      // run's mask (the `== null` check below would no longer fire). It would
      // also leak mask tensors into what getConfig() serializes.
      const kwargs = node.callArgs != null ? Object.assign({}, node.callArgs) : {};
      let layerOutputs;
      let layerMasks;
      if (computedData.length === 1) {
        const computedTensor = computedData[0][0];
        const computedMask = computedData[0][1];
        if (kwargs["mask"] == null) {
          kwargs["mask"] = computedMask;
        }
        layerOutputs = toList(layer.call(computedTensor, kwargs));
        layerMasks = toList(layer.computeMask(computedTensor, computedMask));
      } else {
        const computedTensors = computedData.map((entry) => entry[0]);
        const computedMasks = computedData.map((entry) => entry[1]);
        if (kwargs["mask"] == null) {
          kwargs["mask"] = computedMasks;
        }
        layerOutputs = toList(layer.call(computedTensors, kwargs));
        layerMasks = toList(layer.computeMask(computedTensors, computedMasks));
      }
      if (layer.activityRegularizer) {
        throw new NotImplementedError("LayersModel invocation with concrete Tensor value(s) in the presence of activity regularizer(s) is not supported yet.");
      }
      for (let i = 0; i < referenceOutputTensors.length; ++i) {
        tensorMap[referenceOutputTensors[i].id] = [layerOutputs[i], layerMasks[i]];
      }
    }
  }
  const outputTensors = [];
  const outputMasks = [];
  const outputShapes = [];
  for (const x of this.outputs) {
    assert(x.id in tensorMap, "Could not compute output " + x.name + " : " + x.id);
    const tensor = tensorMap[x.id][0];
    const mask = tensorMap[x.id][1];
    outputShapes.push(tensor.shape);
    outputTensors.push(tensor);
    outputMasks.push(mask);
  }
  return [outputTensors, outputMasks, outputShapes];
};
/**
 * Renumbers inbound-node indices so that only nodes registered in
 * this.containerNodes are counted.
 *
 * @param layers Layers to map. Previously this argument was ignored; it is
 *   now honored, with this.layers used when it is null/undefined (the only
 *   in-file caller passes this.layers, so its result is unchanged).
 * @returns Map from Container2.nodeKey(layer, originalIndex) to the new index.
 */
Container2.prototype.buildNodeConversionMap = function(layers) {
  const layerList = layers != null ? layers : this.layers;
  const nodeConversionMap = {};
  for (const layer of layerList) {
    // Nested containers start counting at 1, presumably because they already
    // own a node linking their inputs to their outputs — NOTE(review): confirm.
    let keptNodes = layer instanceof Container2 ? 1 : 0;
    for (let originalNodeIndex = 0; originalNodeIndex < layer.inboundNodes.length; originalNodeIndex++) {
      const nodeKey = Container2.nodeKey(layer, originalNodeIndex);
      if (this.containerNodes.has(nodeKey)) {
        nodeConversionMap[nodeKey] = keptNodes;
        keptNodes += 1;
      }
    }
  }
  return nodeConversionMap;
};
/**
 * Retrieves a layer by index (preferred when given) or by name.
 *
 * @param name Layer name; used only when `index` is null/undefined.
 * @param index Position in this.layers.
 * @returns The matching layer.
 * @throws ValueError if the index is out of range or not a valid array
 *   position, if neither argument is given, or if no layer has that name.
 */
Container2.prototype.getLayer = function(name, index) {
  if (index != null) {
    if (this.layers.length <= index) {
      throw new ValueError("Was asked to retrieve layer at index " + index + ", but model only " + ("has " + this.layers.length + " layer(s)."));
    }
    const layerAtIndex = this.layers[index];
    // Negative, fractional or non-numeric indices pass the bounds check above
    // but address no element; reject them instead of returning undefined.
    if (layerAtIndex === undefined) {
      throw new ValueError("Invalid layer index " + index + ": expected an integer in [0, " + this.layers.length + ").");
    }
    return layerAtIndex;
  }
  if (name == null) {
    throw new ValueError("Provide either a layer name or layer index");
  }
  for (const layer of this.layers) {
    if (layer.name === name) {
      return layer;
    }
  }
  throw new ValueError("No such layer: " + name);
};
/**
 * Collects layer losses inside a tidy scope. A layer's calculateLosses()
 * result is appended once for every one of its inbound nodes that belongs to
 * this container.
 */
Container2.prototype.calculateLosses = function() {
  return tfc.tidy(() => {
    const losses = [];
    const ownsNode = (layer, nodeIndex) => this.containerNodes.has(Container2.nodeKey(layer, nodeIndex));
    for (const layer of this.layers) {
      for (let nodeIndex = 0; nodeIndex < layer.inboundNodes.length; nodeIndex += 1) {
        if (ownsNode(layer, nodeIndex)) {
          Array.prototype.push.apply(losses, layer.calculateLosses());
        }
      }
    }
    return losses;
  });
};
/**
 * Serializes this container's topology into a plain config object. The
 * object has `name`, `layers` (each entry with name, className, config and
 * the inbound nodes that belong to this container), `inputLayers` and
 * `outputLayers`.
 *
 * Node indices are renumbered through buildNodeConversionMap(), so only
 * nodes registered in this.containerNodes are counted. A missing mapping
 * falls back to index 0.
 */
Container2.prototype.getConfig = function() {
var config = {name: this.name};
var nodeConversionMap = this.buildNodeConversionMap(this.layers);
var layerConfigs = [];
for (var _i = 0, _a = this.layers; _i < _a.length; _i++) {
var layer = _a[_i];
var layerClassName = layer.getClassName();
var layerConfig = layer.getConfig();
var filteredInboundNodes = [];
for (var originalNodeIndex = 0; originalNodeIndex < layer.inboundNodes.length; originalNodeIndex++) {
var node = layer.inboundNodes[originalNodeIndex];
var nodeKey = Container2.nodeKey(layer, originalNodeIndex);
var kwargs = {};
if (this.containerNodes.has(nodeKey)) {
// Keep call arguments only if they survive JSON.stringify. Otherwise warn
// and serialize an empty object in their place.
if (node.callArgs) {
try {
JSON.stringify(node.callArgs);
kwargs = node.callArgs;
} catch (err) {
console.warn("Layer " + layer.name + " was passed non-serializable keyword arguments: " + (node.callArgs + ". They will not be included ") + "in the serialized model (and thus will be missing at deserialization time).");
kwargs = {};
}
}
// Each inbound node is stored as a list of
// [inboundLayerName, renumberedNodeIndex, tensorIndex, kwargs] entries.
if (node.inboundLayers.length > 0) {
var nodeData = [];
for (var i = 0; i < node.inboundLayers.length; i++) {
var inboundLayer = node.inboundLayers[i];
var nodeIndex = node.nodeIndices[i];
var tensorIndex = node.tensorIndices[i];
var nodeKey_1 = Container2.nodeKey(inboundLayer, nodeIndex);
var newNodeIndex = nodeConversionMap[nodeKey_1];
if (newNodeIndex == null) {
newNodeIndex = 0;
}
nodeData.push([inboundLayer.name, newNodeIndex, tensorIndex, kwargs]);
}
filteredInboundNodes.push(nodeData);
}
}
}
var dict = {};
dict["name"] = layer.name;
dict["className"] = layerClassName;
dict["config"] = layerConfig;
dict["inboundNodes"] = filteredInboundNodes;
layerConfigs.push(dict);
}
config["layers"] = layerConfigs;
// Model inputs as [layerName, renumberedNodeIndex, tensorIndex]. Endpoints
// whose node is not part of this container are skipped.
var modelInputs = [];
for (var i = 0; i < this.inputLayers.length; i++) {
var layer = this.inputLayers[i];
var nodeIndex = this.inputLayersNodeIndices[i];
var nodeKey = Container2.nodeKey(layer, nodeIndex);
if (!this.containerNodes.has(nodeKey)) {
continue;
}
var newNodeIndex = nodeConversionMap[nodeKey];
if (newNodeIndex === null || newNodeIndex === void 0) {
newNodeIndex = 0;
}
var tensorIndex = this.inputLayersTensorIndices[i];
modelInputs.push([layer.name, newNodeIndex, tensorIndex]);
}
config["inputLayers"] = modelInputs;
// Model outputs, in the same [layerName, renumberedNodeIndex, tensorIndex] form.
var modelOutputs = [];
for (var i = 0; i < this.outputLayers.length; i++) {
var layer = this.outputLayers[i];
var nodeIndex = this.outputLayersNodeIndices[i];
var nodeKey = Container2.nodeKey(layer, nodeIndex);
if (!this.containerNodes.has(nodeKey)) {
continue;
}
var newNodeIndex = nodeConversionMap[nodeKey];
if (newNodeIndex === null || newNodeIndex === void 0) {
newNodeIndex = 0;
}
var tensorIndex = this.outputLayersTensorIndices[i];
modelOutputs.push([layer.name, newNodeIndex, tensorIndex]);
}
config["outputLayers"] = modelOutputs;
return config;
};
/**
 * Reconstructs a Container (or subclass) from a serialized config such as
 * the one produced by getConfig().
 *
 * Every layer is deserialized first. Their inbound nodes are then replayed,
 * by calling layer.apply() on the referenced output tensors, in repeated
 * passes. A node whose inbound layer or inbound node does not exist yet is
 * deferred to the next pass.
 *
 * @param cls Constructor invoked as `new cls({inputs, outputs, name})`.
 * @param config Object with `name`, `layers`, `inputLayers`, `outputLayers`
 *   and optionally `customObjects`.
 * @param customObjects Not read here; custom objects come from
 *   config["customObjects"] instead.
 * @param fastWeightInit Forwarded to setFastWeightInitDuringBuild(); defaults
 *   to false.
 * @throws ValueError if a layer's inbound-node entry is not an array. Also
 *   thrown if a full pass connects no node, meaning the config references
 *   layers or nodes that can never be produced. Without that check the loop
 *   would spin forever.
 */
Container2.fromConfig = function(cls, config, customObjects, fastWeightInit) {
  if (fastWeightInit === void 0) {
    fastWeightInit = false;
  }
  const createdLayers = {};
  // Layer name -> node data entries that still have to be connected.
  const unprocessedNodes = {};
  function addUnprocessedNode(layer, nodeData) {
    if (!(layer.name in unprocessedNodes)) {
      unprocessedNodes[layer.name] = [nodeData];
    } else {
      unprocessedNodes[layer.name].push(nodeData);
    }
  }
  // Connects one inbound node of `layer`. Returns true when the node was
  // consumed, or false when it was re-queued because a dependency is missing.
  function processNode(layer, nodeData) {
    const inputTensors = [];
    let kwargs;
    for (const inputData of nodeData) {
      const inboundLayerName = inputData[0];
      const inboundNodeIndex = inputData[1];
      const inboundTensorIndex = inputData[2];
      kwargs = inputData[3] == null ? {} : inputData[3];
      if (!(inboundLayerName in createdLayers)) {
        addUnprocessedNode(layer, nodeData);
        return false;
      }
      const inboundLayer = createdLayers[inboundLayerName];
      if (inboundLayer.inboundNodes.length <= inboundNodeIndex) {
        addUnprocessedNode(layer, nodeData);
        return false;
      }
      inputTensors.push(inboundLayer.inboundNodes[inboundNodeIndex].outputTensors[inboundTensorIndex]);
    }
    if (inputTensors.length > 0) {
      layer.apply(singletonOrArray(inputTensors), kwargs);
    }
    return true;
  }
  // Deserializes one layer and queues all of its inbound nodes.
  function processLayer(layerData) {
    const layerName = layerData["name"];
    const layer = deserialize(layerData, config["customObjects"] != null ? config["customObjects"] : {});
    layer.setFastWeightInitDuringBuild(fastWeightInit);
    createdLayers[layerName] = layer;
    layerData["inboundNodes"].forEach(function(nodeData) {
      if (!(nodeData instanceof Array)) {
        throw new ValueError("Corrupted configuration, expected array for nodeData: " + nodeData);
      }
      addUnprocessedNode(layer, nodeData);
    });
  }
  // Resolves a [layerName, nodeIndex, tensorIndex] endpoint to its tensor.
  function resolveEndpoint(endpoint) {
    const layerName = endpoint[0];
    assert(layerName in createdLayers);
    return createdLayers[layerName].inboundNodes[endpoint[1]].outputTensors[endpoint[2]];
  }
  const name = config["name"];
  const layersFromConfig = config["layers"];
  for (const layerData of layersFromConfig) {
    processLayer(layerData);
  }
  while (!isObjectEmpty(unprocessedNodes)) {
    let madeProgress = false;
    for (const layerData of layersFromConfig) {
      const layer = createdLayers[layerData["name"]];
      if (!(layer.name in unprocessedNodes)) {
        continue;
      }
      const pending = unprocessedNodes[layer.name];
      delete unprocessedNodes[layer.name];
      for (const nodeData of pending) {
        if (processNode(layer, nodeData)) {
          madeProgress = true;
        }
      }
    }
    // If no node was connected in a full pass, the state is unchanged and
    // every further pass would fail the same way.
    if (!madeProgress) {
      throw new ValueError("Corrupted configuration: could not connect inbound nodes of layer(s) " + Object.keys(unprocessedNodes).join(", ") + "; they reference layers or nodes that are never created.");
    }
  }
  const inputTensors = config["inputLayers"].map(function(endpoint) {
    return resolveEndpoint(endpoint);
  });
  const outputTensors = config["outputLayers"].map(function(endpoint) {
    return resolveEndpoint(endpoint);
  });
  return new cls({inputs: inputTensors, outputs: outputTensors, name});
};
// `stateful` of a Container: true as soon as any contained layer is stateful.
// The container's own `_stateful` must stay falsy.
Object.defineProperty(Container2.prototype, "stateful", {
  get() {
    if (this._stateful) {
      throw new ValueError("Container instance unexpectedly has _stateful = true. The statefulness of a Container is determined by the Layers it contains. Its _stateful property must remain the default false.");
    }
    return this.layers.some((layer) => layer.stateful);
  },
  enumerable: true,
  configurable: true
});
/**
 * Calls resetStates() on every stateful layer, inside a tidy scope.
 */
Container2.prototype.resetStates = function() {
  tfc.tidy(() => {
    for (const layer of this.layers) {
      if (layer.stateful) {
        layer.resetStates();
      }
    }
  });
};
return Container2;
}(Layer);
/**
* @license
* Copyright 2018 Google LLC
*
* Use of this source code is governed by an MIT-style
* license that can be found in the LICENSE file or at
* https://opensource.org/licenses/MIT.
* =============================================================================
*/
/**
 * Normalizes user-supplied sample/class weights into one entry per output.
 *
 * @param xWeight null, an array, an object keyed by output name, or a single
 *   weight spec.
 * @param outputNames The model's output names.
 * @param weightType Label used in error messages.
 * @returns An array aligned with outputNames (null where nothing was given).
 * @throws Error when an array has the wrong length, or when a multi-output
 *   model receives a value that is neither an array nor an object of objects.
 */
function standardizeSampleOrClassWeights(xWeight, outputNames, weightType) {
  const numOutputs = outputNames.length;
  const nothingProvided = xWeight == null || (Array.isArray(xWeight) && xWeight.length === 0);
  if (nothingProvided) {
    return outputNames.map(() => null);
  }
  if (numOutputs === 1) {
    if (Array.isArray(xWeight) && xWeight.length === 1) {
      return xWeight;
    }
    if (typeof xWeight === "object" && outputNames[0] in xWeight) {
      return [xWeight[outputNames[0]]];
    }
    return [xWeight];
  }
  if (Array.isArray(xWeight)) {
    if (xWeight.length !== numOutputs) {
      throw new Error(`Provided ${weightType} is an array of ${xWeight.length} element(s), but the model has ${numOutputs} outputs. Make sure a set of weights is provided for each model output.`);
    }
    return xWeight;
  }
  const isPerOutputMap = typeof xWeight === "object" && Object.keys(xWeight).length > 0 && typeof xWeight[Object.keys(xWeight)[0]] === "object";
  if (!isPerOutputMap) {
    throw new Error(`The model has multiple (${numOutputs}) outputs, so ${weightType} must be either an array with ${numOutputs} elements or an object with ${outputNames} keys. Provided ${weightType} not understood: ${JSON.stringify(xWeight)}`);
  }
  return outputNames.map((outputName) => (outputName in xWeight ? xWeight[outputName] : null));
}
/**
 * Normalizes `classWeight` into one entry per model output. Error messages
 * refer to the argument as "classWeight".
 */
function standardizeClassWeights(classWeight, outputNames) {
  const weightType = "classWeight";
  return standardizeSampleOrClassWeights(classWeight, outputNames, weightType);
}
/**
 * Converts per-class weights into a per-example weight tensor for targets `y`.
 *
 * This is a compiled async function (__awaiter/__generator state machine).
 * It resolves to tfc.tensor1d(weights, "float32") with one weight per
 * example, or to null when classWeight is null/undefined.
 *
 * @param y Target tensor. Rank 1 values are used directly as class indices.
 *   Rank 2 is reduced with argMax over axis 1 when it has more than one
 *   column, and reshaped to rank 1 when it has exactly one column.
 * @param sampleWeight Not supported; must be null/undefined.
 * @param classWeight Map from class index to weight. It must contain every
 *   class index found in y.
 * @param sampleWeightMode Not supported; must be null/undefined.
 * Rejects with Error for sampleWeight/sampleWeightMode, for an unsupported
 * rank or last-dimension size of y, and for classes missing from classWeight.
 */
function standardizeWeights(y, sampleWeight, classWeight, sampleWeightMode) {
return __awaiter(this, void 0, void 0, function() {
var yClasses, yClassIndices, _a, _b, classSampleWeight_1;
return __generator(this, function(_c) {
switch (_c.label) {
case 0:
if (sampleWeight != null || sampleWeightMode != null) {
throw new Error("Support sampleWeight is not implemented yet");
}
// No class weights: jump to label 2, which resolves to null.
if (!(classWeight != null))
return [3, 2];
// Reduce y to one class index per example.
yClasses = tfc.tidy(function() {
if (y.shape.length === 1) {
return y.clone();
} else if (y.shape.length === 2) {
if (y.shape[1] > 1) {
var axis = 1;
return y.argMax(axis);
} else if (y.shape[1] === 1) {
return y.reshape([y.shape[0]]);
} else {
throw new Error("Encountered unexpected last-dimension size (" + y.shape[1] + ") during handling of class weights. The size is expected to be >= 1.");
}
} else {
throw new Error("Unexpected rank of target (y) tensor (" + y.rank + ") during handling of class weights. The rank is expected to be 1 or 2.");
}
});
// Await yClasses.data(); the result is fed to Array.from in the next state.
_b = (_a = Array).from;
return [4, yClasses.data()];
case 1:
yClassIndices = _b.apply(_a, [_c.sent()]);
tfc.dispose(yClasses);
// Look up the weight of each example's class.
classSampleWeight_1 = [];
yClassIndices.forEach(function(classIndex) {
if (classWeight[classIndex] == null) {
throw new Error("classWeight must contain all classes in the training data. " + ("The class " + classIndex + " exists in the data but not in ") + "classWeight");
} else {
classSampleWeight_1.push(classWeight[classIndex]);
}
});
return [2, tfc.tensor1d(classSampleWeight_1, "float32")];
case 2:
return [2, null];
}
});
});
}
/**
 * Scales per-example losses by their sample weights using tfc.mul.
 */
function computeWeightedLoss(losses, sampleWeights) {
  const weightedLosses = tfc.mul(losses, sampleWeights);
  return weightedLosses;
}
/**
* @license
* Copyright 2018 Google LLC
*
* Use of this source code is governed by an MIT-style
* license that can be found in the LICENSE file or at
* https://opensource.org/licenses/MIT.
* =============================================================================
*/
// Batch size fitDataset() passes to model.evaluate() when validating on
// tensor (non-Dataset) validationData and args.validationBatchSize is not set.
var DEFAULT_VALIDATION_BATCH_SIZE = 32;
/**
 * Validates one `{xs, ys}` element produced by a dataset iterator. Each side
 * is flattened into an array of tensors matching the model's inputs/outputs.
 *
 * @returns {{xs: Array, ys: Array}}
 * @throws (via tfc.util.assert) when xs/ys are missing, when the tensor
 *   counts differ from the model, or when any first dimension differs from
 *   the first input's batch size.
 */
function standardizeDataIteratorOutput(model2, iteratorOut) {
  const xs = iteratorOut["xs"];
  const ys = iteratorOut["ys"];
  tfc.util.assert(xs != null && ys != null, () => "A Dataset iterator for fitDataset() is expected to generate objects of the form `{xs: xVal, ys: yVal}`, where the two values may be `tf.Tensor`, an array of Tensors, or a map of string to Tensor. The provided Dataset instead generates " + ("" + iteratorOut));
  const flattenedXs = flattenTensorOrArrayOrMap("input", model2.inputNames, xs);
  const flattenedYs = flattenTensorOrArrayOrMap("output", model2.outputNames, ys);
  const batchSize = flattenedXs[0].shape[0];
  tfc.util.assert(flattenedXs.length === model2.inputs.length, () => "LayersModel has " + model2.inputs.length + " inputs, but the dataset " + ("provides " + flattenedXs.length + " inputs. (Expected input keys: ") + (JSON.stringify(model2.inputNames) + ")"));
  tfc.util.assert(flattenedYs.length === model2.outputs.length, () => "LayersModel has " + model2.outputs.length + " outputs, but the dataset " + ("provides " + flattenedYs.length + " outputs. (Expected output keys: ") + (JSON.stringify(model2.outputNames) + ")"));
  for (let xIndex = 0; xIndex < flattenedXs.length; xIndex += 1) {
    tfc.util.assert(flattenedXs[xIndex].shape[0] === batchSize, () => "Batch size mismatch: input " + (model2.inputNames[xIndex] + " has " + flattenedXs[xIndex].shape[0] + "; ") + ("expected " + batchSize + " based on input " + model2.inputNames[0] + "."));
  }
  for (let yIndex = 0; yIndex < flattenedYs.length; yIndex += 1) {
    tfc.util.assert(flattenedYs[yIndex].shape[0] === batchSize, () => "Batch size mismatch: output " + (model2.outputNames[yIndex] + " has " + flattenedYs[yIndex].shape[0] + "; ") + ("expected " + batchSize + " based on input " + model2.inputNames[0] + "."));
  }
  return {xs: flattenedXs, ys: flattenedYs};
}
/**
 * Turns a tensor, an array of tensors, or a name->tensor map into an array
 * ordered like `names`.
 *
 * @param inputOrOutput "input" or "output"; used only in error messages.
 * @throws (via tfc.util.assert) when an array's length differs from names.
 * @throws ValueError when a map lacks one of the required names.
 */
function flattenTensorOrArrayOrMap(inputOrOutput, names, values) {
  if (values instanceof tfc.Tensor) {
    return [values];
  }
  if (Array.isArray(values)) {
    tfc.util.assert(values.length === names.length, () => "Received an array of " + values.length + " Tensors, but expected " + names.length + " to match the " + inputOrOutput + " keys " + names + ".");
    return values;
  }
  return names.map((key) => {
    if (values[key] == null) {
      throw new ValueError("The feature data generated by the dataset lacks the required " + (inputOrOutput + " key '" + key + "'."));
    }
    return values[key];
  });
}
/**
 * Splits tensor validation data `[xs, ys]` into `{xs, ys}`.
 * @throws NotImplementedError for the three-element `[xs, ys, weights]` form.
 */
function standardizeTensorValidationData(data) {
  const hasSampleWeights = data.length === 3;
  if (hasSampleWeights) {
    throw new NotImplementedError("Validation with sample weights is not implemented yet.");
  }
  const xs = data[0];
  const ys = data[1];
  return {xs, ys};
}
/**
 * Trains `model2` on batches drawn from `dataset` for args.epochs epochs.
 *
 * This is a compiled async function (__awaiter/__generator state machine).
 * `return [3, N]` jumps to label N and `return [4, p]` awaits p. The try
 * entry pushed at label 1 ([1, , 26, 27]) makes label 26 the finally block,
 * per tslib's __generator encoding, so model2.isTraining is always cleared.
 *
 * @param model2 A compiled LayersModel (model2.optimizer must be set).
 * @param dataset Object with iterator() yielding {xs, ys} elements.
 * @param args Required. Uses epochs, batchesPerEpoch, initialEpoch,
 *   validationData, validationBatches, validationBatchSize, classWeight,
 *   callbacks, yieldEvery and verbose. validationSplit must not be set.
 * @returns Promise resolving to model2.history.
 * @throws (via tfc.util.assert / Error) on invalid args, on an uncompiled
 *   model, or when another fit() call is already running.
 */
function fitDataset(model2, dataset, args) {
return __awaiter(this, void 0, void 0, function() {
var hasBatchesPerEpoch, doValidation, valXs, valYs, validationData, trainFunction, outLabels, callbackMetrics, callbacks2, verbose, _a, callbackList, history_1, epoch, dataIterator, epochLogs, stepsDone, batchIndex, iteratorOut, _b, xs, ys, batchLogs, sampleWeights, standardClassWeights, i, _c, _d, ins, outs, i, label, out, valOuts, _e, i;
return __generator(this, function(_f) {
switch (_f.label) {
// Label 0: validate arguments and mark the model as training.
case 0:
hasBatchesPerEpoch = args.batchesPerEpoch != null;
tfc.util.assert(model2.optimizer != null, function() {
return "You must compile a model before training/testing. Use LayersModel.compile(modelCompileConfig).";
});
tfc.util.assert(args != null, function() {
return "For fitDataset(), the 2nd argument (config) is required, but it is not provided in this call.";
});
tfc.util.assert(args.epochs != null && args.epochs > 0 && Number.isInteger(args.epochs), function() {
return "For fitDataset(), config.epochs is expected to be a positive " + ("integer, but got " + args.epochs);
});
tfc.util.assert(!hasBatchesPerEpoch || args.batchesPerEpoch > 0 && Number.isInteger(args.batchesPerEpoch), function() {
return "For fitDataset(), config.batchesPerEpoch is expected to be a " + ("positive integer if specified, but got " + args.batchesPerEpoch);
});
tfc.util.assert(args["validationSplit"] == null, function() {
return "`validationSplit` is not supported by `fitDataset()`. Use validationData instead.";
});
if (model2.isTraining) {
throw new Error("Cannot start training because another fit() call is ongoing.");
}
model2.isTraining = true;
_f.label = 1;
// Label 1: start of the try region. Prepare validation data, the train
// function, metric labels and callbacks, then await onTrainBegin().
case 1:
_f.trys.push([1, , 26, 27]);
doValidation = args.validationData != null;
valXs = void 0;
valYs = void 0;
if (doValidation) {
if (isDatasetObject(args.validationData)) {
tfc.util.assert(args.validationBatches == null || args.validationBatches > 0 && Number.isInteger(args.validationBatches), function() {
return "For fitDataset() with dataset-based validation, config.validationBatches is expected not to be provided, or to be a positive integer, " + ("but got " + args.validationBatches);
});
} else {
validationData = standardizeTensorValidationData(args.validationData);
valXs = validationData.xs;
valYs = validationData.ys;
}
}
trainFunction = model2.makeTrainFunction();
outLabels = model2.getDedupedMetricsNames();
callbackMetrics = void 0;
if (doValidation) {
callbackMetrics = outLabels.slice().concat(outLabels.map(function(n) {
return "val_" + n;
}));
} else {
callbackMetrics = outLabels.slice();
}
callbacks2 = standardizeCallbacks(args.callbacks, args.yieldEvery);
verbose = args.verbose == null ? 1 : args.verbose;
_a = configureCallbacks(callbacks2, verbose, args.epochs, null, null, getStepsPerEpoch(dataset, args), null, doValidation, callbackMetrics), callbackList = _a.callbackList, history_1 = _a.history;
callbackList.setModel(model2);
model2.history = history_1;
return [4, callbackList.onTrainBegin()];
// Labels 2-3: create the initial iterator. With batchesPerEpoch this single
// iterator is shared across epochs; otherwise label 5 replaces it each epoch.
case 2:
_f.sent();
model2.stopTraining_ = false;
epoch = args.initialEpoch == null ? 0 : args.initialEpoch;
return [4, dataset.iterator()];
case 3:
dataIterator = _f.sent();
_f.label = 4;
// Label 4: epoch loop condition; label 23 ends training.
case 4:
if (!(epoch < args.epochs))
return [3, 23];
epochLogs = {};
return [4, callbackList.onEpochBegin(epoch)];
case 5:
_f.sent();
stepsDone = 0;
batchIndex = 0;
if (!!hasBatchesPerEpoch)
return [3, 7];
return [4, dataset.iterator()];
case 6:
dataIterator = _f.sent();
_f.label = 7;
// Label 7: batch loop condition; label 21 ends the epoch.
case 7:
if (!(hasBatchesPerEpoch ? stepsDone < args.batchesPerEpoch : true))
return [3, 21];
return [4, dataIterator.next()];
// Label 8: an exhausted iterator with batchesPerEpoch set ends the epoch
// with a warning. An element without a value skips straight to label 15.
case 8:
iteratorOut = _f.sent();
if (hasBatchesPerEpoch && iteratorOut.done) {
console.warn("You provided `batchesPerEpoch` as " + (args.batchesPerEpoch + ", ") + "but your dataset iterator ran out of data after " + (stepsDone + " batches; ") + "interrupting training. Make sure that your dataset can generate at least `batchesPerEpoch * epochs` batches (in this case, " + (args.batchesPerEpoch * args.epochs + " batches). ") + "You may need to use the repeat() function when building your dataset.");
return [3, 21];
}
if (!(iteratorOut.value != null))
return [3, 15];
_b = standardizeDataIteratorOutput(model2, iteratorOut.value), xs = _b.xs, ys = _b.ys;
batchLogs = {};
batchLogs["batch"] = batchIndex;
batchLogs["size"] = xs[0].shape[0];
return [4, callbackList.onBatchBegin(batchIndex, batchLogs)];
// Labels 9-12: with classWeight, await one sample-weight tensor per output
// (standardizeWeights on ys[i]) and collect them in sampleWeights.
case 9:
_f.sent();
sampleWeights = [];
if (!(args.classWeight != null))
return [3, 13];
standardClassWeights = standardizeClassWeights(args.classWeight, model2.outputNames);
i = 0;
_f.label = 10;
case 10:
if (!(i < standardClassWeights.length))
return [3, 13];
_d = (_c = sampleWeights).push;
return [4, standardizeWeights(ys[i], null, standardClassWeights[i])];
case 11:
_d.apply(_c, [_f.sent()]);
_f.label = 12;
case 12:
++i;
return [3, 10];
// Label 13: run one training step. The batch inputs are disposed right
// away; the metric outputs are kept until onBatchEnd has seen them.
case 13:
ins = xs.concat(ys).concat(sampleWeights);
outs = trainFunction(ins);
tfc.dispose(ins);
for (i = 0; i < outLabels.length; ++i) {
label = outLabels[i];
out = outs[i];
batchLogs[label] = out;
tfc.keep(out);
}
return [4, callbackList.onBatchEnd(batchIndex, batchLogs)];
case 14:
_f.sent();
disposeTensorsInLogs(batchLogs);
batchIndex++;
stepsDone++;
_f.label = 15;
// Label 15: end-of-epoch check. Either batchesPerEpoch is reached or the
// iterator is done; if so, optionally validate and then close the epoch.
case 15:
if (!(hasBatchesPerEpoch ? stepsDone >= args.batchesPerEpoch : iteratorOut.done))
return [3, 20];
if (!doValidation)
return [3, 19];
valOuts = void 0;
if (!isDatasetObject(args.validationData))
return [3, 17];
_e = toList;
return [4, model2.evaluateDataset(args.validationData, {batches: args.validationBatches})];
case 16:
valOuts = _e.apply(void 0, [_f.sent()]);
return [3, 18];
case 17:
valOuts = toList(model2.evaluate(valXs, valYs, {
batchSize: args.validationBatchSize == null ? DEFAULT_VALIDATION_BATCH_SIZE : args.validationBatchSize,
verbose: 0
}));
_f.label = 18;
case 18:
for (i = 0; i < model2.metricsNames.length; ++i) {
epochLogs["val_" + model2.metricsNames[i]] = valOuts[i];
}
_f.label = 19;
case 19:
return [3, 21];
// Label 20: leave the batch loop if a callback requested a stop.
case 20:
if (model2.stopTraining_) {
return [3, 21];
}
return [3, 7];
// Labels 21-22: close the epoch and honor stopTraining_.
case 21:
return [4, callbackList.onEpochEnd(epoch, epochLogs)];
case 22:
_f.sent();
epoch++;
if (model2.stopTraining_) {
return [3, 23];
}
return [3, 4];
// Labels 23-25: finish callbacks, sync history and resolve with it.
case 23:
return [4, callbackList.onTrainEnd()];
case 24:
_f.sent();
return [4, model2.history.syncData()];
case 25:
_f.sent();
return [2, model2.history];
// Label 26: finally block — always clear the training flag.
case 26:
model2.isTraining = false;
return [7];
case 27:
return [2];
}
});
});
}
/**
 * Number of steps per epoch reported to the callbacks.
 *
 * @returns args.batchesPerEpoch when set, otherwise dataset.size when it is a
 *   finite number, otherwise null.
 */
function getStepsPerEpoch(dataset, args) {
  if (args.batchesPerEpoch != null) {
    return args.batchesPerEpoch;
  }
  return Number.isFinite(dataset.size) ? dataset.size : null;
}
/** Duck-typed Dataset check: true when `dataset.iterator` is a function. */
function isDatasetObject(dataset) {
  const iteratorFactory = dataset.iterator;
  return typeof iteratorFactory === "function";
}
/** Duck-typed iterator check: true when `iterator.next` is a function. */
function isLazyIteratorObject(iterator) {
  const nextFn = iterator.next;
  return typeof nextFn === "function";
}
/**
 * Evaluates model2.testFunction over a Dataset, or over an object that
 * already exposes next(). Returns the batch-size-weighted mean of each
 * test output.
 *
 * This is a compiled async function (__awaiter/__generator state machine).
 * For each batch, every output is multiplied by the batch size and added to
 * the running sums in `outs`. After the loop each sum is divided by the
 * total number of examples.
 *
 * @param model2 Model providing testFunction, plus the names used by
 *   standardizeDataIteratorOutput().
 * @param dataset A Dataset (has iterator()) or a lazy iterator (has next()).
 * @param args Optional. `batches` caps the number of batches read;
 *   `verbose` > 0 throws NotImplementedError.
 * @returns Promise of singletonOrArray(outs).
 * NOTE(review): if the iterator yields no batch, numExamples stays 0 and
 * `outs` stays empty.
 */
function evaluateDataset(model2, dataset, args) {
return __awaiter(this, void 0, void 0, function() {
var hasBatches, f, outs, dataIterator, _a, numExamples, batch, _loop_3, state_1, i, oldScalar;
return __generator(this, function(_b) {
switch (_b.label) {
case 0:
args = args || {};
hasBatches = args.batches != null;
f = model2.testFunction;
outs = [];
if (args.verbose > 0) {
throw new NotImplementedError("Verbose mode is not implemented yet.");
}
tfc.util.assert(!hasBatches || args.batches > 0 && Number.isInteger(args.batches), function() {
return "Test loop expects `batches` to be a positive integer, but " + ("received " + JSON.stringify(args.batches));
});
// Use `dataset` directly when it already is an iterator; otherwise await
// dataset.iterator() (labels 1-2).
if (!isLazyIteratorObject(dataset))
return [3, 1];
_a = dataset;
return [3, 3];
case 1:
return [4, dataset.iterator()];
case 2:
_a = _b.sent();
_b.label = 3;
case 3:
dataIterator = _a;
numExamples = 0;
batch = 0;
// One loop iteration: read a batch, run f on it, and add
// batchSize * output to the running sums. Resolves to "break" once the
// iterator reports done.
_loop_3 = function() {
var iteratorOut;
return __generator(this, function(_a2) {
switch (_a2.label) {
case 0:
return [4, dataIterator.next()];
case 1:
iteratorOut = _a2.sent();
outs = tfc.tidy(function() {
if (iteratorOut.value) {
var _a3 = standardizeDataIteratorOutput(model2, iteratorOut.value), xs = _a3.xs, ys = _a3.ys;
var xsAndYs_1 = xs.concat(ys);
var batchOuts = tfc.tidy(function() {
return f(xsAndYs_1);
});
tfc.dispose(xsAndYs_1);
// On the first batch, create one zero scalar accumulator per output.
if (batch === 0) {
for (var i2 = 0; i2 < batchOuts.length; ++i2) {
outs.push(tfc.scalar(0));
}
}
var batchSize_1 = xsAndYs_1[0].shape[0];
var _loop_4 = function(i3) {
var batchOut = batchOuts[i3];
var oldScalar2 = outs[i3];
outs[i3] = tfc.tidy(function() {
return tfc.add(outs[i3], tfc.mul(batchSize_1, batchOut));
});
if (batch > 0) {
tfc.dispose(oldScalar2);
}
};
for (var i2 = 0; i2 < batchOuts.length; ++i2) {
_loop_4(i2);
}
tfc.dispose(batchOuts);
numExamples += batchSize_1;
++batch;
}
return outs;
});
if (iteratorOut.done) {
if (hasBatches) {
console.warn("Your dataset iterator ran out of data during evaluateDataset(). Interrupting evalution. Make sure that your dataset can generate at least `batches` " + ("batches (in this case, " + args.batches + " batches). ") + "You may need to use the repeat() function when building your dataset.");
}
return [2, "break"];
}
return [2];
}
});
};
_b.label = 4;
// Label 4: keep looping until `batches` is reached (when set) or the
// iterator is exhausted.
case 4:
if (!(hasBatches ? batch < args.batches : true))
return [3, 6];
return [5, _loop_3()];
case 5:
state_1 = _b.sent();
if (state_1 === "break")
return [3, 6];
return [3, 4];
// Label 6: turn the weighted sums into means, disposing the old sums.
case 6:
for (i = 0; i < outs.length; ++i) {
oldScalar = outs[i];
outs[i] = tfc.div(outs[i], numExamples);
tfc.dispose(oldScalar);
}
return [2, singletonOrArray(outs)];
}
});
});
}
/**
* @license
* Copyright 2018 Google LLC
*
* Use of this source code is governed by an MIT-style
* license that can be found in the LICENSE file or at
* https://opensource.org/licenses/MIT.
* =============================================================================
*/
/**
 * Asserts (via tfc.util.assert) that `batchSize` is a positive integer.
 * @param batchSize Candidate batch size.
 */
function checkBatchSize(batchSize) {
  var isValid = batchSize > 0 && Number.isInteger(batchSize);
  tfc.util.assert(isValid, function() {
    return "batchSize is required to be a positive integer, but got " + batchSize;
  });
}
/**
 * Slices rows [start, stop) along the first axis of one tensor, or of every
 * tensor in an Array.
 * @returns The sliced tensor, an Array of sliced tensors, or `[null]` when
 *     `arrays` is null/undefined.
 */
function sliceArrays(arrays, start, stop) {
  if (arrays == null) {
    return [null];
  }
  var count = stop - start;
  if (!Array.isArray(arrays)) {
    return sliceAlongFirstAxis(arrays, start, count);
  }
  return arrays.map(function(item) {
    return sliceAlongFirstAxis(item, start, count);
  });
}
/**
 * Gathers the rows selected by `indices` from a tensor, or recursively from
 * every tensor in a (possibly nested) Array, inside a tfc.tidy scope.
 * Non-int32 indices are cast with `toInt()` before gathering.
 * @returns Gathered tensor(s), or null when `arrays` is null/undefined.
 */
function sliceArraysByIndices(arrays, indices) {
  return tfc.tidy(function() {
    if (arrays == null) {
      return null;
    }
    if (Array.isArray(arrays)) {
      return arrays.map(function(item) {
        return sliceArraysByIndices(item, indices);
      });
    }
    var intIndices = indices.dtype === "int32" ? indices : indices.toInt();
    return gather(arrays, intIndices);
  });
}
/**
 * Splits the index range [0, size) into consecutive [start, end) pairs of at
 * most `batchSize` elements. The final pair may be shorter.
 * @param size Total number of samples.
 * @param batchSize Maximum samples per batch.
 * @returns Array of [start, end) pairs.
 */
function makeBatches(size, batchSize) {
  var pairs = [];
  for (var start = 0; start < size; ) {
    var end = Math.min(start + batchSize, size);
    pairs.push([start, end]);
    start = end;
  }
  return pairs;
}
/**
 * Runs the epoch/batch training loop over in-memory tensors.
 *
 * This is down-leveled async code (`__awaiter`/`__generator` state machine).
 * Each epoch is delegated to `_loop_1` and each batch to `_loop_2`.
 *
 * @param model2 Model being trained. `stopTraining_` is reset here and polled
 *     after every batch and every epoch to stop early.
 * @param f Training function applied to the per-batch slice of `ins`. Its
 *     outputs are stored in the batch logs under the names in `outLabels`.
 * @param ins Tensors sliced along the first axis into batches.
 * @param outLabels Log names for the outputs of `f` (and of validation).
 * @param batchSize Defaults to 32.
 * @param epochs Epoch count; iteration runs from `initialEpoch` while
 *     `epoch < epochs`. Defaults to 1.
 * @param verbose Defaults to 1; forwarded to configureCallbacks.
 * @param callbacks2 User callbacks, forwarded to configureCallbacks.
 * @param valF Validation function passed to `model2.testLoop`.
 * @param valIns Validation tensors for `valF`.
 * @param shuffle Defaults to true. "batch" is rejected as not implemented.
 * @param callbackMetrics Metric names forwarded to configureCallbacks.
 * @param initialEpoch Defaults to 0.
 * @param stepsPerEpoch Step-wise training is not implemented. A non-null
 *     value throws once an epoch starts.
 * @param validationSteps When set, requires `stepsPerEpoch`.
 * @returns Promise of `model2.history` after `history.syncData()`.
 * @throws ValueError if `validationSteps` is set without `stepsPerEpoch`.
 * @throws NotImplementedError for stepsPerEpoch mode or `shuffle === "batch"`.
 */
function fitLoop(model2, f, ins, outLabels, batchSize, epochs, verbose, callbacks2, valF, valIns, shuffle, callbackMetrics, initialEpoch, stepsPerEpoch, validationSteps) {
return __awaiter(this, void 0, void 0, function() {
var doValidation, numTrainSamples, indexArray, _a, callbackList, history, _loop_1, epoch, state_1;
return __generator(this, function(_b) {
switch (_b.label) {
case 0:
if (batchSize == null) {
batchSize = 32;
}
if (epochs == null) {
epochs = 1;
}
if (shuffle == null) {
shuffle = true;
}
if (initialEpoch == null) {
initialEpoch = 0;
}
// Validation runs only when both a validation function and its inputs are provided.
doValidation = false;
if (valF != null && valIns != null) {
doValidation = true;
}
if (validationSteps != null) {
doValidation = true;
if (stepsPerEpoch == null) {
throw new ValueError("Can only use `validationSteps` when doing step-wise training, i.e., `stepsPerEpoch` must be set.");
}
}
numTrainSamples = model2.checkNumSamples(ins, batchSize, stepsPerEpoch, "steps_per_epoch");
// Sample indices [0, numTrainSamples), reshuffled in place each epoch when requested.
if (numTrainSamples != null) {
indexArray = range(0, numTrainSamples);
}
if (verbose == null) {
verbose = 1;
}
_a = configureCallbacks(callbacks2, verbose, epochs, initialEpoch, numTrainSamples, stepsPerEpoch, batchSize, doValidation, callbackMetrics), callbackList = _a.callbackList, history = _a.history;
callbackList.setModel(model2);
model2.history = history;
return [4, callbackList.onTrainBegin()];
case 1:
_b.sent();
// Callbacks may set this flag to request early termination.
model2.stopTraining_ = false;
_loop_1 = function(epoch2) {
var epochLogs, epochIndexArray1D_1, batches_1, _loop_2, batchIndex, state_2;
return __generator(this, function(_a2) {
switch (_a2.label) {
case 0:
return [4, callbackList.onEpochBegin(epoch2)];
case 1:
_a2.sent();
epochLogs = {};
if (!(stepsPerEpoch != null))
return [3, 2];
throw new NotImplementedError("stepsPerEpoch mode is not implemented yet.");
case 2:
if (shuffle === "batch") {
throw new NotImplementedError("batch shuffling is not implemneted yet");
} else if (shuffle) {
tfc.util.shuffle(indexArray);
}
epochIndexArray1D_1 = tfc.tensor1d(indexArray);
batches_1 = makeBatches(numTrainSamples, batchSize);
_loop_2 = function(batchIndex2) {
var batchLogs;
return __generator(this, function(_a3) {
switch (_a3.label) {
case 0:
batchLogs = {};
return [4, callbackList.onBatchBegin(batchIndex2, batchLogs)];
case 1:
_a3.sent();
// One training step inside tidy. Output tensors stored in the logs are
// tfc.keep()-ed so they survive the tidy scope.
tfc.tidy(function() {
var batchStart = batches_1[batchIndex2][0];
var batchEnd = batches_1[batchIndex2][1];
var batchIds = sliceAlongFirstAxis(epochIndexArray1D_1, batchStart, batchEnd - batchStart);
batchLogs["batch"] = batchIndex2;
batchLogs["size"] = batchEnd - batchStart;
var insBatch = sliceArraysByIndices(ins, batchIds);
var outs = f(insBatch);
for (var i = 0; i < outLabels.length; ++i) {
var label = outLabels[i];
var out = outs[i];
batchLogs[label] = out;
tfc.keep(out);
}
// Validate once per epoch, after the last batch. Results go into epochLogs as "val_<label>".
// NOTE(review): these kept val_* tensors are not disposed in this function.
// Confirm that another owner (e.g. the history/callbacks) releases them.
if (batchIndex2 === batches_1.length - 1) {
if (doValidation) {
var valOuts = model2.testLoop(valF, valIns, batchSize);
for (var i = 0; i < outLabels.length; ++i) {
var label = outLabels[i];
var out = valOuts[i];
tfc.keep(out);
epochLogs["val_" + label] = out;
}
}
}
});
return [4, callbackList.onBatchEnd(batchIndex2, batchLogs)];
case 2:
_a3.sent();
// Batch-level tensors are released once the callbacks have seen them.
disposeTensorsInLogs(batchLogs);
if (model2.stopTraining_) {
return [2, "break"];
}
return [2];
}
});
};
batchIndex = 0;
_a2.label = 3;
case 3:
if (!(batchIndex < batches_1.length))
return [3, 6];
return [5, _loop_2(batchIndex)];
case 4:
state_2 = _a2.sent();
if (state_2 === "break")
return [3, 6];
_a2.label = 5;
case 5:
++batchIndex;
return [3, 3];
case 6:
epochIndexArray1D_1.dispose();
_a2.label = 7;
case 7:
return [4, callbackList.onEpochEnd(epoch2, epochLogs)];
case 8:
_a2.sent();
if (model2.stopTraining_) {
return [2, "break"];
}
return [2];
}
});
};
epoch = initialEpoch;
_b.label = 2;
case 2:
if (!(epoch < epochs))
return [3, 5];
return [5, _loop_1(epoch)];
case 3:
state_1 = _b.sent();
if (state_1 === "break")
return [3, 5];
_b.label = 4;
case 4:
++epoch;
return [3, 2];
case 5:
return [4, callbackList.onTrainEnd()];
case 6:
_b.sent();
return [4, model2.history.syncData()];
case 7:
_b.sent();
return [2, model2.history];
}
});
});
}
/**
 * Trains `model2` on in-memory tensors `x`/`y`. Standardizes the data,
 * prepares optional validation inputs, and delegates to fitLoop.
 *
 * This is down-leveled async code. `_a.trys.push([1, , 7, 8])` presumably
 * registers a try/finally (tslib convention) whose finally block is label 7,
 * so cleanup runs on both success and error.
 *
 * @param model2 Model to train. It must not already be training.
 * @param x Input tensor(s).
 * @param y Target tensor(s).
 * @param args Fit config; defaults to {}. Fields read here: batchSize (default
 *     32), sampleWeight, classWeight, validationData, validationSplit,
 *     validationSteps, callbacks, yieldEvery, epochs, verbose, shuffle,
 *     initialEpoch.
 * @returns Promise of the History returned by fitLoop.
 * @throws Error if another fit() call is ongoing on this model.
 * @throws NotImplementedError if validationData has 3 entries.
 * @throws ValueError if validationData has a length other than 2 or 3.
 */
function fitTensors(model2, x, y, args) {
if (args === void 0) {
args = {};
}
return __awaiter(this, void 0, void 0, function() {
var inputs, targets, inputValX, inputValY, valX, valY, sampleWeights, batchSize, checkBatchAxis, standardizedOuts, doValidation, valIns, checkBatchAxis_1, valStandardized, splitAt, originalBatchSize, ins, trainFunction, outLabels, valFunction, callbackMetrics, callbacks2, out;
return __generator(this, function(_a) {
switch (_a.label) {
case 0:
// Reject concurrent fit() calls; the flag is cleared in the finally block (label 7).
if (model2.isTraining) {
throw new Error("Cannot start training because another fit() call is ongoing.");
}
model2.isTraining = true;
_a.label = 1;
case 1:
_a.trys.push([1, , 7, 8]);
batchSize = args.batchSize == null ? 32 : args.batchSize;
checkBatchSize(batchSize);
checkBatchAxis = false;
return [4, model2.standardizeUserData(x, y, args.sampleWeight, args.classWeight, checkBatchAxis, batchSize)];
case 2:
standardizedOuts = _a.sent();
inputs = standardizedOuts[0];
targets = standardizedOuts[1];
sampleWeights = standardizedOuts[2];
doValidation = false;
valIns = void 0;
// Explicit validationData takes precedence over validationSplit.
if (!(args.validationData != null && args.validationData.length > 0))
return [3, 4];
doValidation = true;
if (args.validationData.length === 2) {
inputValX = args.validationData[0];
inputValY = args.validationData[1];
} else if (args.validationData.length === 3) {
throw new NotImplementedError("validationData including sample weights is not supported yet.");
} else {
throw new ValueError("When passing validation data, it must contain 2 (valX, valY) or 3 (valX, valY, valSampleWeight) items; " + (args.validationData + " is invalid."));
}
checkBatchAxis_1 = true;
return [4, model2.standardizeUserData(inputValX, inputValY, null, null, checkBatchAxis_1, batchSize)];
case 3:
valStandardized = _a.sent();
valX = valStandardized[0];
valY = valStandardized[1];
valIns = valX.concat(valY);
return [3, 5];
case 4:
// Hold out the trailing `validationSplit` fraction of the samples for validation.
if (args.validationSplit != null && args.validationSplit > 0 && args.validationSplit < 1) {
doValidation = true;
splitAt = Math.floor(inputs[0].shape[0] * (1 - args.validationSplit));
originalBatchSize = inputs[0].shape[0];
valX = sliceArrays(inputs, splitAt, originalBatchSize);
inputs = sliceArrays(inputs, 0, splitAt);
valY = sliceArrays(targets, splitAt, originalBatchSize);
targets = sliceArrays(targets, 0, splitAt);
valIns = valX.concat(valY);
} else if (args.validationSteps != null) {
// NOTE(review): with only validationSteps set, valIns stays undefined and
// fitLoop receives validationSteps=null. fitLoop then sees valIns == null
// and runs no validation, even though val_* callback metrics are declared. Confirm.
doValidation = true;
}
_a.label = 5;
case 5:
ins = inputs.concat(targets).concat(sampleWeights);
model2.checkTrainableWeightsConsistency();
trainFunction = model2.makeTrainFunction();
outLabels = model2.getDedupedMetricsNames();
valFunction = void 0;
callbackMetrics = void 0;
if (doValidation) {
model2.makeTestFunction();
valFunction = model2.testFunction;
callbackMetrics = outLabels.slice().concat(outLabels.map(function(n) {
return "val_" + n;
}));
} else {
valFunction = null;
valIns = [];
callbackMetrics = outLabels.slice();
}
callbacks2 = standardizeCallbacks(args.callbacks, args.yieldEvery);
return [4, fitLoop(model2, trainFunction, ins, outLabels, batchSize, args.epochs, args.verbose, callbacks2, valFunction, valIns, args.shuffle, callbackMetrics, args.initialEpoch, null, null)];
case 6:
out = _a.sent();
return [2, out];
case 7:
// finally: clear the training flag. Then dispose only the tensors created by
// standardization/slicing. disposeNewTensors skips any tensor whose id
// matches the caller's x, y or validation inputs.
model2.isTraining = false;
disposeNewTensors(inputs, x);
disposeNewTensors(targets, y);
disposeNewTensors(valX, inputValX);
disposeNewTensors(valY, inputValY);
if (sampleWeights != null) {
tfc.dispose(sampleWeights);
}
return [7];
case 8:
return [2];
}
});
});
}
/**
 * Returns a new list in which every rank-1 tensor is expanded to rank 2 (by
 * inserting axis 1). Higher-rank tensors pass through unchanged.
 * @param tensors A single tensor or an Array of tensors.
 * @throws Error if any tensor is a scalar (rank 0).
 */
function ensureTensorsRank2OrHigher(tensors) {
  var list = tensors instanceof tfc.Tensor ? [tensors] : tensors;
  var result = [];
  for (var idx = 0; idx < list.length; ++idx) {
    var t = list[idx];
    if (t.rank === 0) {
      throw new Error("Expected tensor to be at least 1D, but received a 0D tensor (scalar).");
    }
    result.push(t.rank === 1 ? expandDims(t, 1) : t);
  }
  return result;
}
/**
 * Disposes tensors in `tensors` that are not among `refTensors`, comparing by
 * tensor id. Tensors that are already disposed are skipped.
 * @param tensors A tensor, an Array of tensors, or an object of tensors.
 *     Null/undefined is a no-op.
 * @param refTensors Tensors to preserve, in the same forms (may be null).
 */
function disposeNewTensors(tensors, refTensors) {
  if (tensors == null) {
    return;
  }
  var keepIds = [];
  if (refTensors instanceof tfc.Tensor) {
    keepIds.push(refTensors.id);
  } else if (Array.isArray(refTensors)) {
    refTensors.forEach(function(ref) {
      keepIds.push(ref.id);
    });
  } else if (refTensors != null) {
    for (var refKey in refTensors) {
      keepIds.push(refTensors[refKey].id);
    }
  }
  var candidates;
  if (tensors instanceof tfc.Tensor) {
    candidates = [tensors];
  } else if (Array.isArray(tensors)) {
    candidates = tensors;
  } else {
    candidates = [];
    for (var key in tensors) {
      candidates.push(tensors[key]);
    }
  }
  var doomed = [];
  candidates.forEach(function(t) {
    if (keepIds.indexOf(t.id) === -1) {
      doomed.push(t);
    }
  });
  doomed.forEach(function(t) {
    if (!t.isDisposed) {
      t.dispose();
    }
  });
}
/**
* @license
* Copyright 2018 Google LLC
*
* Use of this source code is governed by an MIT-style
* license that can be found in the LICENSE file or at
* https://opensource.org/licenses/MIT.
* =============================================================================
*/
/** True when `x` is a tf.Tensor instance (checked through the prototype chain). */
function isDataTensor(x) {
  var TensorClass = tfc.Tensor;
  return x instanceof TensorClass;
}
/** True when `x` is a JavaScript Array. */
function isDataArray(x) {
  var result = Array.isArray(x);
  return result;
}
/** True for anything that is neither a tensor nor an Array (including null). */
function isDataDict(x) {
  return !(isDataTensor(x) || isDataArray(x));
}
/**
 * Normalizes user-supplied model data into an ordered Array with one tensor
 * per entry of `names`, and validates it against `shapes`.
 *
 * @param data A single tensor, an Array of tensors, or an object keyed by name.
 * @param names Expected names. When null or empty, `data` must be
 *     null/undefined or an object with no own keys (even an empty Array is
 *     rejected).
 * @param shapes Optional expected shapes, one per name. Null entries are
 *     skipped, as are null or negative expected dimensions.
 * @param checkBatchAxis When false, dimension 0 is not compared (default true).
 * @param exceptionPrefix Text inserted into error messages (default "").
 * @returns Array of tensors (rank-1 tensors expanded to rank 2), `[]` when no
 *     data is expected, or one null per name when `data` is null.
 * @throws ValueError on missing, surplus or mis-shaped data.
 */
function standardizeInputData(data, names, shapes, checkBatchAxis, exceptionPrefix) {
  if (checkBatchAxis === void 0) {
    checkBatchAxis = true;
  }
  if (exceptionPrefix === void 0) {
    exceptionPrefix = "";
  }

  // True when a for..in scan finds at least one own property.
  function hasOwnKey(obj) {
    for (var prop in obj) {
      if (obj.hasOwnProperty(prop)) {
        return true;
      }
    }
    return false;
  }

  if (names == null || names.length === 0) {
    if (data == null) {
      return [];
    }
    var unexpected;
    if (isDataArray(data) && data.length > 0) {
      unexpected = true;
    } else {
      unexpected = isDataDict(data) ? hasOwnKey(data) : true;
    }
    if (unexpected) {
      throw new ValueError("Error when checking model " + exceptionPrefix + " expected no data, " + ("but got " + data));
    }
    return [];
  }

  if (data == null) {
    return names.map(function() {
      return null;
    });
  }

  var arrays;
  if (isDataDict(data)) {
    arrays = [];
    for (var n = 0; n < names.length; n++) {
      var key = names[n];
      if (data[key] == null) {
        throw new ValueError('No data provided for "' + key + '". Need data for each key in: ' + ("" + names));
      }
      arrays.push(data[key]);
    }
  } else if (isDataArray(data)) {
    if (data.length !== names.length) {
      throw new ValueError("Error when checking model " + exceptionPrefix + ": the Array of Tensors that you are passing to your model is not the size the " + ("model expected. Expected to see " + names.length + " Tensor(s), but ") + ("instead got the following list of Tensor(s): " + data));
    }
    arrays = data;
  } else {
    if (names.length > 1) {
      throw new ValueError("The model " + exceptionPrefix + " expects " + names.length + " Tensor(s), " + ("but only received one Tensor. Found: Tensor with shape " + data.shape));
    }
    arrays = [data];
  }
  arrays = ensureTensorsRank2OrHigher(arrays);

  if (shapes == null) {
    return arrays;
  }
  for (var idx = 0; idx < names.length; ++idx) {
    var expected = shapes[idx];
    if (expected == null) {
      continue;
    }
    var actual = arrays[idx];
    if (actual.shape.length !== expected.length) {
      throw new ValueError("Error when checking " + exceptionPrefix + ": expected " + names[idx] + " " + ("to have " + expected.length + " dimension(s). but got array with ") + ("shape " + actual.shape));
    }
    // Dimension 0 (the batch axis) is only compared when checkBatchAxis is set.
    for (var d = checkBatchAxis ? 0 : 1; d < expected.length; ++d) {
      var dim = actual.shape[d];
      var refDim = expected[d];
      if (refDim != null && refDim >= 0 && dim !== refDim) {
        throw new ValueError("Error when checking " + exceptionPrefix + ": expected " + names[idx] + " " + ("to have shape [" + expected + "], but got array with shape ") + ("[" + actual.shape + "]."));
      }
    }
  }
  return arrays;
}
/**
 * Verifies that all inputs share one sample count (dimension 0), that all
 * targets share one sample count, and that the two counts agree.
 * @param inputs Array of input tensors.
 * @param targets Array of target tensors.
 * @param weights Accepted for signature compatibility; not inspected.
 * @throws ValueError on any mismatch.
 */
function checkArrayLengths(inputs, targets, weights) {
  var sampleCount = function(t) {
    return t.shape[0];
  };
  var shapeOf = function(t) {
    return t.shape;
  };
  var xCounts = unique(inputs.map(sampleCount));
  xCounts.sort();
  var yCounts = unique(targets.map(sampleCount));
  yCounts.sort();
  if (xCounts.length > 1) {
    throw new ValueError("All input Tensors (x) should have the same number of samples. Got array shapes: " + ("" + JSON.stringify(inputs.map(shapeOf))));
  }
  if (yCounts.length > 1) {
    throw new ValueError("All target Tensors (y) should have the same number of samples. Got array shapes: " + ("" + JSON.stringify(targets.map(shapeOf))));
  }
  var bothPresent = xCounts.length > 0 && yCounts.length > 0;
  if (bothPresent && !tfc.util.arraysEqual(xCounts, yCounts)) {
    throw new ValueError("Input Tensors should have the same number of samples as target " + ("Tensors. Found " + xCounts[0] + " input sample(s) and " + yCounts[0] + " target ") + "sample(s).");
  }
}
/**
 * Checks that each target tensor is compatible with the loss used for its
 * output.
 *
 * - With categoricalCrossentropy, a target whose last dimension is 1 is
 *   rejected (one-hot targets are expected).
 * - With meanSquaredError, binaryCrossentropy or categoricalCrossentropy, every
 *   non-batch dimension of the target must equal the corresponding output
 *   dimension, unless that output dimension is null.
 *
 * @param targets Target tensors, one per output.
 * @param lossFns Loss functions, one per output. Null entries are skipped.
 * @param outputShapes Model output shapes, one per output.
 * @throws ValueError on an incompatible target.
 */
function checkLossAndTargetCompatibility(targets, lossFns, outputShapes) {
  var keyLosses = [
    meanSquaredError,
    binaryCrossentropy,
    categoricalCrossentropy
  ];
  for (var i = 0; i < targets.length; ++i) {
    var y = targets[i];
    var loss = lossFns[i];
    var shape = outputShapes[i];
    if (loss == null) {
      continue;
    }
    if (loss === categoricalCrossentropy) {
      if (y.shape[y.shape.length - 1] === 1) {
        // Fixed: the original message lacked a space before "expects".
        throw new ValueError("You are passing a target array of shape " + y.shape + " while using a loss 'categorical_crossentropy'. 'categorical_crossentropy' expects targets to be binary matrices (1s and 0s) of shape [samples, classes].");
      }
    }
    if (keyLosses.indexOf(loss) !== -1) {
      // Compare everything except the batch dimension.
      var slicedYShape = y.shape.slice(1);
      var slicedShape = shape.slice(1);
      for (var j = 0; j < slicedYShape.length; ++j) {
        var targetDim = slicedYShape[j];
        var outDim = slicedShape[j];
        if (outDim != null && targetDim !== outDim) {
          throw new ValueError("A target Tensor with shape " + y.shape + " was passed for an " + ("output of shape " + shape + ", while using a loss function that ") + "expects targets to have the same shape as the output.");
        }
      }
    }
  }
}
/**
 * Validates, without copying or reshaping, that `data` supplies one tensor
 * per expected name. When `shapes` is given, also checks that each tensor
 * matches its expected shape.
 *
 * @param data A single tensor or an Array of tensors.
 * @param names Expected input names.
 * @param shapes Optional expected shapes, one per name. Null entries and
 *     null expected dimensions are not checked.
 * @param checkBatchAxis When false, dimension 0 is not compared (default true).
 * @param exceptionPrefix Text inserted into error messages (default "").
 * @throws ValueError when the tensor count or a shape does not match.
 */
function checkInputData(data, names, shapes, checkBatchAxis, exceptionPrefix) {
  if (checkBatchAxis === void 0) {
    checkBatchAxis = true;
  }
  if (exceptionPrefix === void 0) {
    exceptionPrefix = "";
  }
  var arrays;
  if (Array.isArray(data)) {
    if (data.length !== names.length) {
      // Fixed message: previously read "the size the the model expected" and "Tensors(s)".
      throw new ValueError("Error when checking model " + exceptionPrefix + ": the Array of Tensors that you are passing to your model is not the size the " + ("model expected. Expected to see " + names.length + " Tensor(s),") + (" but instead got " + data.length + " Tensor(s)."));
    }
    arrays = data;
  } else {
    if (names.length > 1) {
      throw new ValueError("The model expects " + names.length + " " + exceptionPrefix + " Tensors, but only received one Tensor. Found: array with shape " + (JSON.stringify(data.shape) + "."));
    }
    arrays = [data];
  }
  if (shapes == null) {
    return;
  }
  for (var i = 0; i < names.length; ++i) {
    if (shapes[i] == null) {
      continue;
    }
    var array = arrays[i];
    if (array.shape.length !== shapes[i].length) {
      throw new ValueError("Error when checking " + exceptionPrefix + ": expected " + names[i] + " " + ("to have " + shapes[i].length + " dimension(s), but got array with ") + ("shape " + JSON.stringify(array.shape)));
    }
    for (var j = 0; j < shapes[i].length; ++j) {
      // Skip the batch axis unless it is explicitly requested.
      if (j === 0 && !checkBatchAxis) {
        continue;
      }
      var dim = array.shape[j];
      var refDim = shapes[i][j];
      if (refDim != null && refDim !== dim) {
        throw new ValueError("Error when checking " + exceptionPrefix + ": expected " + (names[i] + " to have shape " + JSON.stringify(shapes[i]) + " but ") + ("got array with shape " + JSON.stringify(array.shape) + "."));
      }
    }
  }
}
/**
 * Maps the user's `metrics` argument to one list of metrics per output name.
 *
 * - null or an empty Array: every output gets `[]`.
 * - A string or function: wrapped into a one-element list shared by every output.
 * - An Array: the same Array instance is shared by every output.
 * - An object keyed by output name: each output gets its own entry (wrapped
 *   into a list if needed), or `[]` when absent.
 *
 * @param metrics User-supplied metrics.
 * @param outputNames Model output names.
 * @returns Array of metric lists, aligned with `outputNames`.
 * @throws TypeError for any other type of `metrics`.
 */
function collectMetrics(metrics, outputNames) {
  if (metrics == null || Array.isArray(metrics) && metrics.length === 0) {
    return outputNames.map(function() {
      return [];
    });
  }
  var wrappedMetrics;
  if (typeof metrics === "string" || typeof metrics === "function") {
    wrappedMetrics = [metrics];
  } else if (Array.isArray(metrics) || typeof metrics === "object") {
    wrappedMetrics = metrics;
  } else {
    // Fixed message: previously "Expected an string,function, ...".
    throw new TypeError("Type of metrics argument not understood. Expected a string, " + ("function, Array, or Object, found: " + metrics));
  }
  if (Array.isArray(wrappedMetrics)) {
    return outputNames.map(function() {
      return wrappedMetrics;
    });
  }
  // Call hasOwnProperty via the prototype so dictionaries created with
  // Object.create(null) (or shadowing a "hasOwnProperty" key) still work.
  var hasOwn = Object.prototype.hasOwnProperty;
  var nestedMetrics = [];
  for (var i = 0; i < outputNames.length; i++) {
    var name = outputNames[i];
    var outputMetrics = hasOwn.call(wrappedMetrics, name) ? wrappedMetrics[name] : [];
    if (!Array.isArray(outputMetrics)) {
      outputMetrics = [outputMetrics];
    }
    nestedMetrics.push(outputMetrics);
  }
  return nestedMetrics;
}
// Format identifier string "layers-model".
// NOTE(review): presumably used as the `format` field when saving/loading
// LayersModel artifacts. Confirm at its use sites.
var LAYERS_MODEL_FORMAT_NAME = "layers-model";
var LayersModel = function(_super) {
__extends(LayersModel2, _super);
/**
 * LayersModel constructor. Delegates to the parent constructor (using the
 * object it returns, if any) and marks the model as not training.
 * @param args Constructor config, forwarded to the parent.
 */
function LayersModel2(args) {
  var self = _super.call(this, args) || this;
  // No fit() call is in progress on a freshly constructed model.
  self.isTraining = false;
  return self;
}
/**
 * Prints a textual summary of the model via printSummary.
 * @param lineLength Optional line width, forwarded to printSummary.
 * @param positions Optional column positions, forwarded to printSummary.
 * @param printFn Output function; defaults to console.log.
 * @throws ValueError if the model has not been built yet.
 */
LayersModel2.prototype.summary = function(lineLength, positions, printFn) {
  var print = printFn === void 0 ? console.log : printFn;
  if (!this.built) {
    throw new ValueError("This model has never been called, thus its weights have not been created yet. So no summary can be displayed. Build the model first (e.g., by calling it on some test data).");
  }
  printSummary(this, lineLength, positions, print);
};
/**
 * Configures the model for training: resolves the optimizer, one loss
 * function per output, and the metric names/functions.
 *
 * @param args Compile config.
 *     - `args.loss` may be a single loss (string or function), an Array with
 *       one loss per output, or an object keyed by output name. When null it
 *       is replaced in place by `[]`, which mutates the caller's args object.
 *     - `args.optimizer` is either a name resolved via getOptimizer (the model
 *       then owns it) or a tf.Optimizer instance (borrowed).
 *     - `args.metrics` is normalized per output by collectMetrics.
 * @throws ValueError for any of the following:
 *     - a non-Optimizer optimizer object;
 *     - an unknown key in a loss dictionary;
 *     - a loss Array whose length differs from the number of outputs.
 */
LayersModel2.prototype.compile = function(args) {
var _this = this;
// NOTE(review): a missing loss becomes `[]`. That empty Array then fails the
// per-output length check below unless the model has no outputs.
if (args.loss == null) {
args.loss = [];
}
this.loss = args.loss;
if (typeof args.optimizer === "string") {
this.optimizer_ = getOptimizer(args.optimizer);
this.isOptimizerOwned = true;
} else {
if (!(args.optimizer instanceof tfc.Optimizer)) {
throw new ValueError("User-defined optimizer must be an instance of tf.Optimizer.");
}
this.optimizer_ = args.optimizer;
this.isOptimizerOwned = false;
}
// Resolve one loss function per model output, in output order.
var lossFunctions = [];
if (!Array.isArray(args.loss) && typeof args.loss !== "string" && typeof args.loss !== "function") {
// Dictionary form: every key must be an output name.
args.loss = args.loss;
for (var name_3 in args.loss) {
if (this.outputNames.indexOf(name_3) === -1) {
throw new ValueError('Unknown entry in loss dictionary: "' + name_3 + '". ' + ("Only expected the following keys: " + this.outputNames));
}
}
for (var _i = 0, _a = this.outputNames; _i < _a.length; _i++) {
var name_4 = _a[_i];
// A missing entry only warns; get() is still called with the missing value.
if (args.loss[name_4] == null) {
console.warn('Output "' + name_4 + '" is missing from loss dictionary. We assume this was done on purpose, and we will not be expecting data ' + ("to be passed to " + name_4 + " during training"));
}
lossFunctions.push(get(args.loss[name_4]));
}
} else if (Array.isArray(args.loss)) {
// Array form: exactly one entry per output.
if (args.loss.length !== this.outputs.length) {
throw new ValueError("When passing an Array as loss, it should have one entry per " + ("model output. The model has " + this.outputs.length + " output(s), ") + ("but you passed loss=" + args.loss + "."));
}
var theLosses = args.loss;
lossFunctions = theLosses.map(function(l) {
return get(l);
});
} else {
// A single loss (string or function) shared by every output.
var lossFunction_1 = get(args.loss);
this.outputs.forEach(function(_) {
lossFunctions.push(lossFunction_1);
});
}
this.lossFunctions = lossFunctions;
// Every output is fed: record its name, internal shape and loss function.
this.feedOutputNames = [];
this.feedOutputShapes = [];
this.feedLossFns = [];
for (var i = 0; i < this.outputs.length; ++i) {
var shape = this.internalOutputShapes[i];
var name_5 = this.outputNames[i];
this.feedOutputNames.push(name_5);
this.feedOutputShapes.push(shape);
this.feedLossFns.push(this.lossFunctions[i]);
}
// Always empty here, so no output is skipped in the loops below.
var skipTargetIndices = [];
this.metrics = args.metrics;
// metricsNames starts with the total "loss". metricsTensors receives
// [value, outputIndex] pairs whose values are the loss/metric functions
// themselves (not tensors), despite the name.
this.metricsNames = ["loss"];
this.metricsTensors = [];
nameScope("loss", function() {
for (var i2 = 0; i2 < _this.outputs.length; ++i2) {
if (skipTargetIndices.indexOf(i2) !== -1) {
continue;
}
var weightedLoss = _this.lossFunctions[i2];
// Multi-output models also report each output's own loss as "<output>_loss".
if (_this.outputs.length > 1) {
_this.metricsTensors.push([weightedLoss, i2]);
_this.metricsNames.push(_this.outputNames[i2] + "_loss");
}
}
});
var nestedMetrics = collectMetrics(args.metrics, this.outputNames);
// Records a metric, prefixing its name with the output name when there are several outputs.
var appendMetric = function(outputIndex, metricName, metricTensor) {
if (_this.outputNames.length > 1) {
metricName = _this.outputNames[outputIndex] + "_" + metricName;
}
_this.metricsNames.push(metricName);
_this.metricsTensors.push([metricTensor, outputIndex]);
};
nameScope("metric", function() {
var _loop_1 = function(i3) {
if (skipTargetIndices.indexOf(i3) !== -1) {
return "continue";
}
var outputMetrics = nestedMetrics[i3];
var handleMetrics = function(metrics) {
var metricNamePrefix = "";
var metricName;
var accFn;
var weightedMetricFn;
var _loop_2 = function(metric2) {
// Shortcut names are resolved according to the output:
// - binary variants when the last output dim is 1 or the loss is binaryCrossentropy;
// - sparse variants when the loss is sparseCategoricalCrossentropy;
// - categorical variants otherwise.
// They are reported as "acc" or "ce".
if (typeof metric2 === "string" && ["accuracy", "acc", "crossentropy", "ce"].indexOf(metric2) !== -1) {
var outputShape = _this.internalOutputShapes[i3];
if (outputShape[outputShape.length - 1] === 1 || _this.lossFunctions[i3] === binaryCrossentropy) {
if (["accuracy", "acc"].indexOf(metric2) !== -1) {
accFn = binaryAccuracy;
} else if (["crossentropy", "ce"].indexOf(metric2) !== -1) {
accFn = binaryCrossentropy$1;
}
} else if (_this.lossFunctions[i3] === sparseCategoricalCrossentropy) {
if (["accuracy", "acc"].indexOf(metric2) !== -1) {
accFn = sparseCategoricalAccuracy;
} else if (["crossentropy", "ce"].indexOf(metric2) !== -1) {
accFn = sparseCategoricalCrossentropy$1;
}
} else {
if (["accuracy", "acc"].indexOf(metric2) !== -1) {
accFn = categoricalAccuracy;
} else if (["crossentropy", "ce"].indexOf(metric2) !== -1) {
accFn = categoricalCrossentropy$1;
}
}
var suffix = void 0;
if (["accuracy", "acc"].indexOf(metric2) !== -1) {
suffix = "acc";
} else if (["crossentropy", "ce"].indexOf(metric2) !== -1) {
suffix = "ce";
}
weightedMetricFn = accFn;
metricName = metricNamePrefix + suffix;
} else {
// Any other metric is resolved via get$1 and named by getLossOrMetricName.
var metricFn = get$1(metric2);
weightedMetricFn = metricFn;
metricName = metricNamePrefix + getLossOrMetricName(metric2);
}
var metricResult;
// The name scope only wraps the assignment; the metric function itself is stored.
nameScope(metricName, function() {
metricResult = weightedMetricFn;
});
appendMetric(i3, metricName, metricResult);
};
for (var _i2 = 0, metrics_1 = metrics; _i2 < metrics_1.length; _i2++) {
var metric = metrics_1[_i2];
_loop_2(metric);
}
};
handleMetrics(outputMetrics);
};
for (var i2 = 0; i2 < _this.outputs.length; ++i2) {
_loop_1(i2);
}
});
// Snapshot of the trainable weights, compared later by checkTrainableWeightsConsistency().
this.collectedTrainableWeights = this.trainableWeights;
};
/**
 * Warns when the number of trainable weights differs from the snapshot taken
 * by compile(). Does nothing if compile() has not recorded a snapshot.
 */
LayersModel2.prototype.checkTrainableWeightsConsistency = function() {
  var collected = this.collectedTrainableWeights;
  if (collected == null) {
    return;
  }
  var mismatch = this.trainableWeights.length !== collected.length;
  if (mismatch) {
    console.warn("Discrepancy between trainableweights and collected trainable weights. Did you set `model.trainable` without calling `model.compile()` afterwards?");
  }
};
/**
 * Computes loss and metric values on in-memory tensors.
 * @param x Input tensor(s).
 * @param y Target tensor(s).
 * @param args Optional config. Fields read: batchSize (default 32), verbose, steps.
 * @returns A single result or an Array of results (via singletonOrArray).
 * Tensors created while standardizing x/y are disposed before returning.
 */
LayersModel2.prototype.evaluate = function(x, y, args) {
  var config = args === void 0 ? {} : args;
  var batchSize = config.batchSize == null ? 32 : config.batchSize;
  checkBatchSize(batchSize);
  // Validate the data including the batch axis.
  var standardized = this.standardizeUserDataXY(x, y, true, batchSize);
  var xs = standardized[0];
  var ys = standardized[1];
  try {
    var ins = xs.concat(ys);
    this.makeTestFunction();
    var testOuts = this.testLoop(this.testFunction, ins, batchSize, config.verbose, config.steps);
    return singletonOrArray(testOuts);
  } finally {
    disposeNewTensors(xs, x);
    disposeNewTensors(ys, y);
  }
};
/**
 * Evaluates the model on a dataset or lazy iterator. Builds the test function
 * first so `this.testFunction` exists, then delegates to the module-level
 * evaluateDataset.
 * @param dataset Dataset (with `iterator()`) or lazy iterator (with `next()`).
 * @param args Optional evaluation config, forwarded unchanged.
 * @returns Promise of the evaluation result.
 */
LayersModel2.prototype.evaluateDataset = function(dataset, args) {
return __awaiter(this, void 0, void 0, function() {
return __generator(this, function(_a) {
this.makeTestFunction();
return [2, evaluateDataset(this, dataset, args)];
});
});
};
/**
 * Determines the number of samples to iterate over.
 * @param ins Input tensor or Array of tensors. The first dimension of the
 *     (first) tensor is used.
 * @param batchSize Must be null/undefined when `steps` is set.
 * @param steps When set, the sample count is unknown and null is returned.
 * @param stepsName Name used in error messages (default "steps").
 * @returns The sample count, or null in steps mode.
 * @throws ValueError if both `steps` and `batchSize` are set, or if neither
 *     `steps` nor `ins` is given.
 */
LayersModel2.prototype.checkNumSamples = function(ins, batchSize, steps, stepsName) {
  if (stepsName === void 0) {
    stepsName = "steps";
  }
  var numSamples;
  if (steps != null) {
    numSamples = null;
    if (batchSize != null) {
      // Fixed message: a space was missing before "Got batchSize".
      throw new ValueError("If " + stepsName + " is set, batchSize must be null or undefined. " + ("Got batchSize = " + batchSize));
    }
  } else if (ins != null) {
    numSamples = Array.isArray(ins) ? ins[0].shape[0] : ins.shape[0];
  } else {
    // Fixed typo: "shoud" -> "should".
    throw new ValueError("Either the input data should have a defined shape, or " + (stepsName + " should be specified."));
  }
  return numSamples;
};
/**
 * Runs the model graph on concrete inputs and returns the requested named
 * symbolic outputs.
 * @param inputs A tensor, an Array of tensors (one per model input), or an
 *     object keyed by model input name.
 * @param outputs An output tensor name or an Array of names (must be non-empty).
 * @returns An Array of tensors when `outputs` is an Array, else a single tensor.
 * @throws ValueError for an empty `outputs` Array, a wrong number of inputs,
 *     or a missing named input.
 */
LayersModel2.prototype.execute = function(inputs, outputs) {
  var outputsIsList = Array.isArray(outputs);
  if (outputsIsList && outputs.length === 0) {
    throw new ValueError("`outputs` is an empty Array, which is not allowed.");
  }
  var requested = this.retrieveSymbolicTensors(outputsIsList ? outputs : [outputs]);
  var feeds = new FeedDict();
  var given = inputs instanceof tfc.Tensor ? [inputs] : inputs;
  if (Array.isArray(given)) {
    if (given.length !== this.inputs.length) {
      throw new ValueError("The number of inputs provided (" + given.length + ") does not match the number of inputs of this model " + ("(" + this.inputs.length + ")."));
    }
    for (var k = 0; k < this.inputs.length; ++k) {
      feeds.add(this.inputs[k], given[k]);
    }
  } else {
    // Object form: look each model input up by its name.
    var modelInputs = this.inputs;
    for (var m = 0; m < modelInputs.length; m++) {
      var symbolic = modelInputs[m];
      var value = given[symbolic.name];
      if (value == null) {
        throw new ValueError("No value is provided for the model's input " + symbolic.name);
      }
      feeds.add(symbolic, value);
    }
  }
  // `execute` here is the module-level graph executor, not this method.
  var results = execute(requested, feeds);
  return outputsIsList ? results : results[0];
};
/**
 * Finds, for each requested name, the layer output tensor with that name.
 * Scans `this.layers` in order and stops as soon as every name is resolved.
 * @param symbolicTensorNames Names of the wanted output tensors.
 * @returns Symbolic tensors aligned with `symbolicTensorNames`.
 * @throws ValueError listing any names that could not be found.
 */
LayersModel2.prototype.retrieveSymbolicTensors = function(symbolicTensorNames) {
  var wanted = symbolicTensorNames.length;
  var resolved = pyListRepeat(null, wanted);
  var remaining = wanted;
  var layerList = this.layers;
  scan:
  for (var li = 0; li < layerList.length; li++) {
    var layer = layerList[li];
    var produced = Array.isArray(layer.output) ? layer.output : [layer.output];
    var producedNames = produced.map(function(t) {
      return t.name;
    });
    for (var ni = 0; ni < symbolicTensorNames.length; ++ni) {
      var hit = producedNames.indexOf(symbolicTensorNames[ni]);
      if (hit !== -1) {
        resolved[ni] = produced[hit];
        remaining--;
      }
      if (remaining === 0) {
        break scan;
      }
    }
    if (remaining === 0) {
      break;
    }
  }
  if (remaining > 0) {
    var missing = [];
    resolved.forEach(function(t, idx) {
      if (t == null) {
        missing.push(symbolicTensorNames[idx]);
      }
    });
    throw new ValueError("Cannot find SymbolicTensors for output name(s): " + ("" + JSON.stringify(missing)));
  }
  return resolved;
};
/**
 * Runs inference batch by batch and concatenates each output's batch results
 * along axis 0, all inside a tfc.tidy scope.
 * @param ins Input tensor or Array of tensors (one per model input).
 * @param batchSize Samples per batch (default 32).
 * @param verbose Must be falsy (default false); verbose mode throws.
 * @returns A single tensor or an Array of tensors (via singletonOrArray).
 * @throws NotImplementedError when `verbose` is truthy.
 */
LayersModel2.prototype.predictLoop = function(ins, batchSize, verbose) {
  var model = this;
  if (batchSize === void 0) {
    batchSize = 32;
  }
  if (verbose === void 0) {
    verbose = false;
  }
  return tfc.tidy(() => {
    var numSamples = model.checkNumSamples(ins);
    if (verbose) {
      throw new NotImplementedError("Verbose predictLoop() is not implemented yet.");
    }
    var batches = makeBatches(numSamples, batchSize);
    // One list of per-batch results for each model output.
    var perOutput = model.outputs.map(() => []);
    for (var b = 0; b < batches.length; ++b) {
      var range = batches[b];
      var batchOuts = tfc.tidy(() => {
        var insBatch = sliceArrays(ins, range[0], range[1]);
        var feeds = [];
        if (Array.isArray(insBatch)) {
          for (var k = 0; k < insBatch.length; ++k) {
            feeds.push({key: model.inputs[k], value: insBatch[k]});
          }
        } else {
          feeds.push({key: model.inputs[0], value: insBatch});
        }
        // `execute` is the module-level graph executor.
        return execute(model.outputs, new FeedDict(feeds));
      });
      batchOuts.forEach((out, idx) => perOutput[idx].push(out));
    }
    return singletonOrArray(perOutput.map((chunks) => tfc.concat(chunks, 0)));
  });
};
/**
 * Generates predictions for in-memory inputs.
 * @param x Input tensor or Array of tensors. Rank-1 tensors are expanded to
 *     rank 2 first.
 * @param args Optional config. `batchSize` defaults to 32.
 * @returns Output tensor(s) from predictLoop.
 * Tensors created by the rank expansion are disposed before returning.
 */
LayersModel2.prototype.predict = function(x, args) {
  var options = args === void 0 ? {} : args;
  var xs = ensureTensorsRank2OrHigher(x);
  // The batch axis is not compared against the expected input shapes.
  checkInputData(xs, this.inputNames, this.feedInputShapes, false);
  try {
    var batchSize = options.batchSize == null ? 32 : options.batchSize;
    checkBatchSize(batchSize);
    return this.predictLoop(xs, batchSize);
  } finally {
    disposeNewTensors(xs, x);
  }
};
/**
 * Predicts on a single batch. The whole input is processed as one batch
 * whose size is the leading dimension of the (first) input tensor.
 * @param x Input tensor or Array of tensors. Shapes are checked including
 *     the batch axis.
 * @returns Output tensor(s) from predictLoop.
 */
LayersModel2.prototype.predictOnBatch = function(x) {
  checkInputData(x, this.inputNames, this.feedInputShapes, true);
  var first = Array.isArray(x) ? x[0] : x;
  return this.predictLoop(x, first.shape[0]);
};
/**
 * Standardizes and validates inputs `x` and targets `y` for training/testing.
 * @param x Input data (tensor, Array, or object keyed by input name).
 * @param y Target data (tensor, Array, or object keyed by output name).
 * @param checkBatchAxis Accepted but not used: both standardizeInputData
 *     calls below pass `false`.
 * @param batchSize For stateful models, the sample count must be divisible
 *     by it when it is positive.
 * @returns `[xs, ys]` as Arrays of tensors.
 * @throws RuntimeError if the model has not been compiled.
 * @throws ValueError for stateful models whose sample count is not a
 *     multiple of batchSize.
 */
LayersModel2.prototype.standardizeUserDataXY = function(x, y, checkBatchAxis, batchSize) {
  if (this.optimizer_ == null) {
    throw new RuntimeError("You must compile a model before training/testing. Use LayersModel.compile(modelCompileArgs).");
  }
  var lossFns = this.feedLossFns;
  // With sparseCategoricalCrossentropy, the expected target shape is the
  // output shape with its last dimension replaced by 1.
  var targetShapes = this.feedOutputShapes.map(function(shape, idx) {
    return lossFns[idx] === sparseCategoricalCrossentropy ? shape.slice(0, shape.length - 1).concat([1]) : shape;
  });
  var xs = standardizeInputData(x, this.feedInputNames, this.feedInputShapes, false, "input");
  var ys = standardizeInputData(y, this.feedOutputNames, targetShapes, false, "target");
  checkArrayLengths(xs, ys);
  checkLossAndTargetCompatibility(ys, this.feedLossFns, this.feedOutputShapes);
  if (this.stateful && batchSize != null && batchSize > 0 && xs[0].shape[0] % batchSize !== 0) {
    throw new ValueError("In a stateful network, you should only pass inputs with a number of samples that is divisible by the batch size " + (batchSize + ". Found: " + xs[0].shape[0] + " sample(s)."));
  }
  return [xs, ys];
};
/**
 * Async variant of standardizeUserDataXY that also builds per-output sample
 * weights from `classWeight`.
 *
 * This is down-leveled async code (`__generator` state machine). Labels 1-3
 * loop over the standardized class weights, awaiting standardizeWeights for
 * each one.
 *
 * @param x Input data, forwarded to standardizeUserDataXY.
 * @param y Target data, forwarded to standardizeUserDataXY.
 * @param sampleWeight Must be null/undefined; per-sample weights are not supported.
 * @param classWeight Optional. When given, it is normalized by
 *     standardizeClassWeights against `this.outputNames`, and each entry is
 *     combined with the matching standardized target.
 * @param checkBatchAxis Defaults to true; forwarded to standardizeUserDataXY.
 * @param batchSize Forwarded to standardizeUserDataXY.
 * @returns Promise of `[xs, ys, sampleWeights]`. `sampleWeights` is null when
 *     `classWeight` is not provided.
 * @throws Error if `sampleWeight` is provided.
 */
LayersModel2.prototype.standardizeUserData = function(x, y, sampleWeight, classWeight, checkBatchAxis, batchSize) {
if (checkBatchAxis === void 0) {
checkBatchAxis = true;
}
return __awaiter(this, void 0, void 0, function() {
var _a, standardXs, standardYs, standardSampleWeights, classWeights, i, _b, _c;
return __generator(this, function(_d) {
switch (_d.label) {
case 0:
_a = this.standardizeUserDataXY(x, y, checkBatchAxis, batchSize), standardXs = _a[0], standardYs = _a[1];
if (sampleWeight != null) {
throw new Error("sample weight is not supported yet.");
}
standardSampleWeights = null;
if (!(classWeight != null))
return [3, 4];
classWeights = standardizeClassWeights(classWeight, this.outputNames);
standardSampleWeights = [];
i = 0;
_d.label = 1;
case 1:
if (!(i < classWeights.length))
return [3, 4];
// Equivalent to: standardSampleWeights.push(await standardizeWeights(...)).
_c = (_b = standardSampleWeights).push;
return [4, standardizeWeights(standardYs[i], null, classWeights[i])];
case 2:
_c.apply(_b, [_d.sent()]);
_d.label = 3;
case 3:
++i;
return [3, 1];
case 4:
return [2, [standardXs, standardYs, standardSampleWeights]];
}
});
});
};
/**
 * Evaluates `f` over `ins` in batches of `batchSize` and returns the
 * sample-weighted average of each scalar output of `f`.
 *
 * Each batch's outputs are multiplied by the batch size, summed across batches,
 * then divided by the total sample count. All intermediates live inside
 * tfc.tidy(). `verbose > 0` and a non-null `steps` are not implemented and throw.
 *
 * @returns {Scalar[]} One averaged value per output of `f` (empty if no batches).
 */
LayersModel2.prototype.testLoop = function(f, ins, batchSize, verbose, steps) {
  var model = this;
  var verbosity = verbose === void 0 ? 0 : verbose;
  return tfc.tidy(function() {
    var numSamples = model.checkNumSamples(ins, batchSize, steps, "steps");
    if (verbosity > 0) {
      throw new NotImplementedError("Verbose mode is not implemented yet.");
    }
    if (steps != null) {
      throw new NotImplementedError("steps mode in testLoop() is not implemented yet");
    }
    var sums = [];
    var bounds = makeBatches(numSamples, batchSize);
    var sampleIndices = tfc.tensor1d(range(0, numSamples));
    for (var b = 0; b < bounds.length; ++b) {
      var start = bounds[b][0];
      var size = bounds[b][1] - start;
      var ids = sliceAlongFirstAxis(sampleIndices, start, size);
      var batchResults = f(sliceArraysByIndices(ins, ids));
      if (b === 0) {
        // Lazily size the accumulators from the first batch's output count.
        for (var k = 0; k < batchResults.length; ++k) {
          sums.push(tfc.scalar(0));
        }
      }
      for (var j = 0; j < batchResults.length; ++j) {
        sums[j] = tfc.add(sums[j], tfc.mul(size, batchResults[j]));
      }
    }
    for (var idx = 0; idx < sums.length; ++idx) {
      sums[idx] = tfc.div(sums[idx], numSamples);
    }
    return sums;
  });
};
/**
 * Returns this.metricsNames with duplicates disambiguated: a name that occurs
 * more than once gets "_<k>" appended, where k is how many times it appeared
 * earlier in the list (0-based). Unique names are returned unchanged.
 */
LayersModel2.prototype.getDedupedMetricsNames = function() {
  var labels = this.metricsNames;
  return labels.map(function(label, position) {
    if (count(labels, label) <= 1) {
      return label;
    }
    return label + "_" + count(labels.slice(0, position), label);
  });
};
// Builds the per-batch training step.
// Returns a function `data => Scalar[]` where `data` is laid out as
// [inputs..., targets..., sampleWeights...] (sliced by this.inputs.length and
// this.outputs.length below). Calling it runs one optimizer_.minimize() step and
// returns [totalLossValue, ...metricsValues].
LayersModel2.prototype.makeTrainFunction = function() {
var _this = this;
return function(data) {
var lossValues = [];
var inputs = data.slice(0, _this.inputs.length);
var targets = data.slice(_this.inputs.length, _this.inputs.length + _this.outputs.length);
var sampleWeights = data.slice(_this.inputs.length + _this.outputs.length, _this.inputs.length + _this.outputs.length * 2);
var metricsValues = [];
// Loss closure handed to optimizer_.minimize(): forward pass in training mode,
// sum of per-output losses (optionally sample-weighted), mean-reduced, plus
// every regularizer loss from calculateLosses(). Also fills lossValues and
// metricsValues as side effects.
var totalLossFunction = function() {
var feeds = [];
for (var i = 0; i < _this.inputs.length; ++i) {
feeds.push({key: _this.inputs[i], value: inputs[i]});
}
var feedDict = new FeedDict(feeds);
var outputs = execute(_this.outputs, feedDict, {training: true});
var totalLoss;
for (var i = 0; i < _this.lossFunctions.length; ++i) {
var lossFunction = _this.lossFunctions[i];
var loss = lossFunction(targets[i], outputs[i]);
if (sampleWeights[i] != null) {
loss = computeWeightedLoss(loss, sampleWeights[i]);
}
var meanLoss = tfc.mean(loss);
lossValues.push(meanLoss);
if (i === 0) {
totalLoss = loss;
} else {
totalLoss = tfc.add(totalLoss, loss);
}
}
for (var i = 0; i < _this.metricsTensors.length; ++i) {
var weightedMetric = void 0;
// Multi-output models: the first outputs.length metric slots reuse the
// per-output mean losses computed above instead of re-running the metric.
if (_this.outputs.length > 1 && i < _this.outputs.length) {
weightedMetric = lossValues[i];
} else {
var metric = _this.metricsTensors[i][0];
var outputIndex = _this.metricsTensors[i][1];
weightedMetric = tfc.mean(metric(targets[outputIndex], outputs[outputIndex]));
}
// keep() so the metric survives tidy scopes (presumably the one used
// inside optimizer_.minimize()) and can be returned to the caller.
tfc.keep(weightedMetric);
metricsValues.push(weightedMetric);
}
totalLoss = tfc.mean(totalLoss);
_this.calculateLosses().forEach(function(regularizerLoss) {
totalLoss = tfc.add(totalLoss, regularizerLoss);
});
return totalLoss;
};
var variables = _this.collectedTrainableWeights.map(function(param) {
return param.read();
});
// returnCost = true: minimize() returns the loss value rather than null.
var returnCost = true;
var totalLossValue = _this.optimizer_.minimize(totalLossFunction, returnCost, variables);
return [totalLossValue].concat(metricsValues);
};
};
/**
 * Installs this.testFunction, the per-batch evaluation step used by testLoop().
 *
 * testFunction(data) takes `data` laid out as [inputs..., targets...], runs an
 * inference pass, and returns (inside tfc.tidy) the scalars
 *   [totalLoss, ...one mean value per entry of this.metricsTensors].
 * This matches the [totalLoss, ...metrics] layout returned by the training
 * step. For multi-output models, metricsTensors already carries the per-output
 * losses.
 *
 * The total loss is pushed once, after the loop. Previously the running total
 * was pushed once per loss function, which produced [l0, l0+l1, ...] for
 * multi-output models and misaligned every following value with its metric name.
 */
LayersModel2.prototype.makeTestFunction = function() {
  var _this = this;
  this.testFunction = function(data) {
    return tfc.tidy(function() {
      var valOutputs = [];
      var totalLoss;
      var inputs = data.slice(0, _this.inputs.length);
      var targets = data.slice(_this.inputs.length, _this.inputs.length + _this.outputs.length);
      var feeds = [];
      for (var i = 0; i < _this.inputs.length; ++i) {
        feeds.push({key: _this.inputs[i], value: inputs[i]});
      }
      var feedDict = new FeedDict(feeds);
      var outputs = execute(_this.outputs, feedDict);
      for (var i = 0; i < _this.lossFunctions.length; ++i) {
        var lossFunction = _this.lossFunctions[i];
        var loss = tfc.mean(lossFunction(targets[i], outputs[i]));
        totalLoss = i === 0 ? loss : tfc.add(totalLoss, loss);
      }
      // Exactly one total-loss entry (none if there are no loss functions).
      if (totalLoss != null) {
        valOutputs.push(totalLoss);
      }
      for (var i = 0; i < _this.metricsTensors.length; ++i) {
        var metric = _this.metricsTensors[i][0];
        var outputIndex = _this.metricsTensors[i][1];
        var meanMetric = tfc.mean(metric(targets[outputIndex], outputs[outputIndex]));
        valOutputs.push(meanMetric);
      }
      return valOutputs;
    });
  };
};
/**
 * Trains the model on tensors `x`/`y`; delegates to fitTensors().
 * @param {Object} [args] Fit options; defaults to {}.
 * @returns {Promise} Resolves with fitTensors()'s result.
 */
LayersModel2.prototype.fit = function(x, y, args) {
  var fitArgs = args === void 0 ? {} : args;
  var model = this;
  return __awaiter(model, void 0, void 0, function() {
    return __generator(model, function() {
      // Generator "return" opcode: the outer promise adopts fitTensors' result.
      return [2, fitTensors(model, x, y, fitArgs)];
    });
  });
};
/**
 * Trains the model from a dataset; delegates to the module-level fitDataset().
 * @returns {Promise} Resolves with fitDataset()'s result.
 */
LayersModel2.prototype.fitDataset = function(dataset, args) {
  var model = this;
  return __awaiter(model, void 0, void 0, function() {
    return __generator(model, function() {
      return [2, fitDataset(model, dataset, args)];
    });
  });
};
// Runs a single training step on one batch.
// Standardizes x/y via standardizeUserData(), builds a fresh train function with
// makeTrainFunction(), calls it with inputs followed by targets (no sample
// weights are passed), then awaits loss.data() for each returned scalar and keeps
// its first element. Resolves to singletonOrArray(lossValues): the total loss
// followed by metric values.
// NOTE(review): if an awaited loss.data() rejects, the tfc.dispose(losses) in
// case 5 is never reached.
// Generator opcodes (assumed tslib convention): [4, p] awaits p, [3, L] jumps to
// case L, [2, v] returns v.
LayersModel2.prototype.trainOnBatch = function(x, y) {
return __awaiter(this, void 0, void 0, function() {
var standardizeOut, inputs, targets, trainFunction, losses, lossValues, _i, losses_1, loss, v;
return __generator(this, function(_a) {
switch (_a.label) {
case 0:
return [4, this.standardizeUserData(x, y)];
case 1:
standardizeOut = _a.sent();
inputs = standardizeOut[0];
targets = standardizeOut[1];
trainFunction = this.makeTrainFunction();
losses = trainFunction(inputs.concat(targets));
lossValues = [];
_i = 0, losses_1 = losses;
_a.label = 2;
case 2:
// Loop head: read each returned scalar's value in turn.
if (!(_i < losses_1.length))
return [3, 5];
loss = losses_1[_i];
return [4, loss.data()];
case 3:
v = _a.sent();
lossValues.push(v[0]);
_a.label = 4;
case 4:
_i++;
return [3, 2];
case 5:
// All values downloaded: release the scalars returned by the train function.
tfc.dispose(losses);
return [2, singletonOrArray(lossValues)];
}
});
});
};
/**
 * Pairs each weight's originalName with its current value.
 * @param {{trainableOnly?: boolean}} [config] When `trainableOnly` is truthy,
 *   only this.trainableWeights are reported.
 * @returns {{name: string, tensor: Tensor}[]}
 */
LayersModel2.prototype.getNamedWeights = function(config) {
  var trainableOnly = config != null && config.trainableOnly;
  var variables = trainableOnly ? this.trainableWeights : this.weights;
  var values = this.getWeights(trainableOnly);
  var result = [];
  for (var idx = 0; idx < variables.length; ++idx) {
    var variable = variables[idx];
    // Defensive filter: frozen weights are skipped in trainable-only mode.
    if (!trainableOnly || variable.trainable) {
      result.push({name: variable.originalName, tensor: values[idx]});
    }
  }
  return result;
};
// `stopTraining` accessor: a plain read/write view of the backing field
// `stopTraining_`.
Object.defineProperty(LayersModel2.prototype, "stopTraining", {
  configurable: true,
  enumerable: true,
  get: function() {
    return this.stopTraining_;
  },
  set: function(stop) {
    this.stopTraining_ = stop;
  }
});
// `optimizer` accessor backed by `optimizer_`. Assigning a different optimizer
// marks it as not owned by the model (so dispose() will not dispose it);
// re-assigning the same instance is a no-op.
Object.defineProperty(LayersModel2.prototype, "optimizer", {
  configurable: true,
  enumerable: true,
  get: function() {
    return this.optimizer_;
  },
  set: function(optimizer) {
    if (this.optimizer_ === optimizer) {
      return;
    }
    this.optimizer_ = optimizer;
    this.isOptimizerOwned = false;
  }
});
/**
 * Disposes the model via the parent class. If that dropped the reference count
 * to 0 and the model owns a non-null optimizer, it also disposes the optimizer.
 * The number of tensors the optimizer freed (measured with tfc.memory()) is
 * added to result.numDisposedVariables.
 * @returns The parent's dispose result object (possibly updated).
 */
LayersModel2.prototype.dispose = function() {
  var outcome = _super.prototype.dispose.call(this);
  var shouldDisposeOptimizer = outcome.refCountAfterDispose === 0 && this.optimizer != null && this.isOptimizerOwned;
  if (!shouldDisposeOptimizer) {
    return outcome;
  }
  var tensorsBefore = tfc.memory().numTensors;
  this.optimizer_.dispose();
  outcome.numDisposedVariables += tensorsBefore - tfc.memory().numTensors;
  return outcome;
};
/**
 * Converts this.loss into snake_case identifiers for serialization, keeping
 * its shape: a string yields a string, an array yields an array, and an
 * object keyed by output name yields an object with the same keys.
 * @throws Error if any loss entry is not a string.
 */
LayersModel2.prototype.getLossIdentifiers = function() {
  var loss = this.loss;
  var unsupported = "Serialization of non-string loss is not supported.";
  if (typeof loss === "string") {
    return toSnakeCase(loss);
  }
  if (Array.isArray(loss)) {
    // Validate every entry before converting any of them.
    for (var idx = 0; idx < loss.length; idx++) {
      if (typeof loss[idx] !== "string") {
        throw new Error(unsupported);
      }
    }
    return loss.map(function(name) {
      return toSnakeCase(name);
    });
  }
  var byOutput = {};
  Object.keys(loss).forEach(function(outputName) {
    var entry = loss[outputName];
    if (typeof entry !== "string") {
      throw new Error(unsupported);
    }
    byOutput[outputName] = toSnakeCase(entry);
  });
  return byOutput;
};
/**
 * Converts this.metrics into snake_case identifiers via getLossOrMetricName().
 * A single string or function yields a one-element array, an array yields an
 * array, and any other value is walked with for-in into a keyed object.
 */
LayersModel2.prototype.getMetricIdentifiers = function() {
  var metrics = this.metrics;
  var identify = function(metric) {
    return toSnakeCase(getLossOrMetricName(metric));
  };
  var kind = typeof metrics;
  if (kind === "string" || kind === "function") {
    return [identify(metrics)];
  }
  if (Array.isArray(metrics)) {
    return metrics.map(function(metric) {
      return identify(metric);
    });
  }
  var byOutput = {};
  for (var outputName in metrics) {
    byOutput[outputName] = identify(metrics[outputName]);
  }
  return byOutput;
};
/**
 * Builds the Python-style training config saved alongside the model:
 * {loss, metrics, optimizer_config: {class_name, config}}.
 */
LayersModel2.prototype.getTrainingConfig = function() {
  var loss = this.getLossIdentifiers();
  var metrics = this.getMetricIdentifiers();
  var optimizer = this.optimizer;
  var optimizerConfig = {
    class_name: optimizer.getClassName(),
    config: optimizer.getConfig()
  };
  return {loss: loss, metrics: metrics, optimizer_config: optimizerConfig};
};
/**
 * Re-compiles the model from a serialized (Python-style) training config.
 *
 * The optimizer is rebuilt from `optimizer_config`. `loss` and `metrics`
 * entries are converted from snake_case to camelCase, keeping their
 * string/array/object shape. Note that a string `metrics` value is handled
 * by the object branch (for-in over its indices).
 *
 * @throws Error if `weighted_metrics`, `loss_weights` or `sample_weight_mode`
 *   is present (not supported yet). The first message now names the key that
 *   is actually checked (`weighted_metrics`, not `weight_metrics`).
 */
LayersModel2.prototype.loadTrainingConfig = function(trainingConfig) {
  if (trainingConfig.weighted_metrics != null) {
    throw new Error("Loading weighted_metrics is not supported yet.");
  }
  if (trainingConfig.loss_weights != null) {
    throw new Error("Loading loss_weights is not supported yet.");
  }
  if (trainingConfig.sample_weight_mode != null) {
    throw new Error("Loading sample_weight_mode is not supported yet.");
  }
  var tsConfig = convertPythonicToTs(trainingConfig.optimizer_config);
  var optimizer = deserialize(tsConfig);
  var loss;
  if (typeof trainingConfig.loss === "string") {
    loss = toCamelCase(trainingConfig.loss);
  } else if (Array.isArray(trainingConfig.loss)) {
    loss = trainingConfig.loss.map(function(lossEntry) {
      return toCamelCase(lossEntry);
    });
  } else if (trainingConfig.loss != null) {
    loss = {};
    for (var lossKey in trainingConfig.loss) {
      loss[lossKey] = toCamelCase(trainingConfig.loss[lossKey]);
    }
  }
  var metrics;
  if (Array.isArray(trainingConfig.metrics)) {
    metrics = trainingConfig.metrics.map(function(metric) {
      return toCamelCase(metric);
    });
  } else if (trainingConfig.metrics != null) {
    metrics = {};
    for (var metricKey in trainingConfig.metrics) {
      metrics[metricKey] = toCamelCase(trainingConfig.metrics[metricKey]);
    }
  }
  this.compile({loss, metrics, optimizer});
};
// Saves topology + weights through an IOHandler (or a URL string resolved to
// exactly one save handler via tfc.io.getSaveHandlers()).
// Steps: encode getNamedWeights(config) -> build modelArtifacts (topology from
// toJSON, format, generatedBy) -> if config.includeOptimizer and an optimizer is
// set, attach getTrainingConfig() and append the optimizer's weights, encoded
// with group "optimizer", to the weight specs/data -> attach validated
// userDefinedMetadata -> resolve with handlerOrURL.save(modelArtifacts).
// Throws ValueError when a URL matches zero or several save handlers, or when
// the handler has no `save` method.
// Generator opcodes (assumed tslib convention): [4, p] awaits p, [3, L] jumps to
// case L, [2, v] returns v.
LayersModel2.prototype.save = function(handlerOrURL, config) {
return __awaiter(this, void 0, void 0, function() {
var handlers, weightDataAndSpecs, returnString, unusedArg, modelConfig, modelArtifacts, includeOptimizer, weightType, _a, optimizerWeightData, optimizerWeightSpecs, _b, _c, checkSize;
var _d;
return __generator(this, function(_e) {
switch (_e.label) {
case 0:
if (typeof handlerOrURL === "string") {
handlers = tfc.io.getSaveHandlers(handlerOrURL);
if (handlers.length === 0) {
throw new ValueError("Cannot find any save handlers for URL '" + handlerOrURL + "'");
} else if (handlers.length > 1) {
throw new ValueError("Found more than one (" + handlers.length + ") save handlers for " + ("URL '" + handlerOrURL + "'"));
}
handlerOrURL = handlers[0];
}
if (handlerOrURL.save == null) {
throw new ValueError("LayersModel.save() cannot proceed because the IOHandler provided does not have the `save` attribute defined.");
}
return [4, tfc.io.encodeWeights(this.getNamedWeights(config))];
case 1:
weightDataAndSpecs = _e.sent();
// toJSON(null, false) returns the topology as an object, not a string.
returnString = false;
unusedArg = null;
modelConfig = this.toJSON(unusedArg, returnString);
modelArtifacts = {
modelTopology: modelConfig,
format: LAYERS_MODEL_FORMAT_NAME,
generatedBy: "TensorFlow.js tfjs-layers v" + version,
convertedBy: null
};
includeOptimizer = config == null ? false : config.includeOptimizer;
// Without an optimizer to include, skip to metadata handling (case 4).
if (!(includeOptimizer && this.optimizer != null))
return [3, 4];
modelArtifacts.trainingConfig = this.getTrainingConfig();
weightType = "optimizer";
_c = (_b = tfc.io).encodeWeights;
return [4, this.optimizer.getWeights()];
case 2:
// Encode the awaited optimizer weights under the "optimizer" group.
return [4, _c.apply(_b, [_e.sent(), weightType])];
case 3:
_a = _e.sent(), optimizerWeightData = _a.data, optimizerWeightSpecs = _a.specs;
// Optimizer specs and bytes are appended after the model's own weights.
(_d = weightDataAndSpecs.specs).push.apply(_d, optimizerWeightSpecs);
weightDataAndSpecs.data = tfc.io.concatenateArrayBuffers([weightDataAndSpecs.data, optimizerWeightData]);
_e.label = 4;
case 4:
if (this.userDefinedMetadata != null) {
checkSize = true;
checkUserDefinedMetadata(this.userDefinedMetadata, this.name, checkSize);
modelArtifacts.userDefinedMetadata = this.userDefinedMetadata;
}
modelArtifacts.weightData = weightDataAndSpecs.data;
modelArtifacts.weightSpecs = weightDataAndSpecs.specs;
return [2, handlerOrURL.save(modelArtifacts)];
}
});
});
};
/**
 * Stores user metadata on the model after validating it with
 * checkUserDefinedMetadata(). Nothing is stored if validation throws.
 */
LayersModel2.prototype.setUserDefinedMetadata = function(userDefinedMetadata) {
  var modelName = this.name;
  checkUserDefinedMetadata(userDefinedMetadata, modelName);
  this.userDefinedMetadata = userDefinedMetadata;
};
/** Returns whatever metadata was last stored via setUserDefinedMetadata(). */
LayersModel2.prototype.getUserDefinedMetadata = function() {
  var metadata = this.userDefinedMetadata;
  return metadata;
};
LayersModel2.className = "Model";
return LayersModel2;
}(Container);
// Register LayersModel with tfc's serialization registry.
tfc.serialization.registerClass(LayersModel);
/**
 * Functional: a LayersModel subclass that adds nothing except a distinct
 * static className ("Functional"). Registering it presumably lets configs
 * tagged with that class name resolve to a LayersModel-backed class.
 */
var Functional = function(Base) {
  function Functional2() {
    return Base !== null && Base.apply(this, arguments) || this;
  }
  __extends(Functional2, Base);
  Functional2.className = "Functional";
  return Functional2;
}(LayersModel);
tfc.serialization.registerClass(Functional);
/**
* @license
* Copyright 2018 Google LLC
*
* Use of this source code is governed by an MIT-style
* license that can be found in the LICENSE file or at
* https://opensource.org/licenses/MIT.
* =============================================================================
*/
// Builds a model from JSON and optionally loads its weights.
// `modelAndWeightsConfig` may be a bare topology object or an object with
// `modelTopology` plus optional `weightsManifest`/`pathPrefix`. A nested
// `model_config` inside the topology is unwrapped. The topology is converted
// with convertPythonicToTs() and deserialized, with `customObjects` forwarded.
// When a weightsManifest is present, weights are fetched by each model weight's
// originalName, passed to model2.loadWeights() (one entry per distinct
// originalName), and the fetched tensors are disposed afterwards.
// Resolves to the constructed model.
function modelFromJSON(modelAndWeightsConfig, customObjects) {
return __awaiter(this, void 0, void 0, function() {
var modelTopology, tsConfig, model2, weightValues, uniqueWeightValues, _i, _a, weight;
return __generator(this, function(_b) {
switch (_b.label) {
case 0:
// Normalize a bare topology into the {modelTopology} wrapper form.
if (!("modelTopology" in modelAndWeightsConfig)) {
modelAndWeightsConfig = {modelTopology: modelAndWeightsConfig};
}
// No-op self-assignment (likely a leftover of a TypeScript type cast).
modelAndWeightsConfig = modelAndWeightsConfig;
modelTopology = modelAndWeightsConfig.modelTopology;
if (modelTopology["model_config"] != null) {
modelTopology = modelTopology["model_config"];
}
tsConfig = convertPythonicToTs(modelTopology);
model2 = deserialize(tsConfig, customObjects);
// No manifest: return the model without weights (case 2).
if (!(modelAndWeightsConfig.weightsManifest != null))
return [3, 2];
return [4, tfc.io.loadWeights(modelAndWeightsConfig.weightsManifest, modelAndWeightsConfig.pathPrefix, model2.weights.map(function(weight2) {
return weight2.originalName;
}))];
case 1:
weightValues = _b.sent();
uniqueWeightValues = {};
for (_i = 0, _a = model2.weights; _i < _a.length; _i++) {
weight = _a[_i];
uniqueWeightValues[weight.originalName] = weightValues[weight.originalName];
}
model2.loadWeights(uniqueWeightValues);
// Presumably loadWeights() copies the values, so the fetched tensors can go.
tfc.dispose(weightValues);
_b.label = 2;
case 2:
return [2, model2];
}
});
});
}
/**
 * Resolves `pathOrIOHandler` to a single IOHandler and loads the model from it.
 * A string is matched against tfc.io.getLoadHandlers(). If no handler
 * matches, tfc.io.browserHTTPRequest() is used; if several match, a
 * ValueError is raised. `options` defaults to {}.
 * @returns {Promise} Resolves with loadLayersModelFromIOHandler()'s result.
 */
function loadLayersModelInternal(pathOrIOHandler, options) {
  return __awaiter(this, void 0, void 0, function() {
    return __generator(this, function() {
      var loadOptions = options == null ? {} : options;
      var handler = pathOrIOHandler;
      if (typeof handler === "string") {
        var candidates = tfc.io.getLoadHandlers(handler, loadOptions);
        if (candidates.length > 1) {
          throw new ValueError("Found more than one (" + candidates.length + ") load handlers for " + ("URL '" + handler + "'"));
        }
        handler = candidates.length === 0 ? tfc.io.browserHTTPRequest(pathOrIOHandler, loadOptions) : candidates[0];
      }
      return [2, loadLayersModelFromIOHandler(handler, void 0, loadOptions)];
    });
  });
}
// Loads model artifacts through `handler.load()` and reconstructs the model.
// options.strict defaults to true and is passed to loadWeights(). Fast weight
// initialization is requested from deserialize() only when both weightData and
// weightSpecs are present and strict mode is on.
// Also applies any trainingConfig (via loadTrainingConfig) and
// userDefinedMetadata. Decodes model vs. optimizer weights, loads them (optimizer
// weights only if the model has an optimizer and such weights exist), then
// disposes the decoded tensors.
// Throws ValueError if the handler lacks `load`, or if weightData is present
// without weightSpecs. Resolves to the model.
function loadLayersModelFromIOHandler(handler, customObjects, options) {
return __awaiter(this, void 0, void 0, function() {
var artifacts, modelTopology, strict, fastWeightInit, model2, trainingConfig, _a, modelWeights, optimizerWeights;
return __generator(this, function(_b) {
switch (_b.label) {
case 0:
if (options == null) {
options = {};
}
if (handler.load == null) {
throw new ValueError("Cannot proceed with model loading because the IOHandler provided does not have the `load` method implemented.");
}
return [4, handler.load()];
case 1:
artifacts = _b.sent();
modelTopology = artifacts.modelTopology;
// Unwrap Keras-style {model_config: ...} topologies.
if (modelTopology["model_config"] != null) {
modelTopology = modelTopology["model_config"];
}
strict = options.strict == null ? true : options.strict;
fastWeightInit = artifacts.weightData != null && artifacts.weightSpecs != null && strict;
model2 = deserialize(convertPythonicToTs(modelTopology), customObjects, fastWeightInit);
trainingConfig = artifacts.trainingConfig;
if (trainingConfig != null) {
model2.loadTrainingConfig(trainingConfig);
}
if (artifacts.userDefinedMetadata != null) {
model2.setUserDefinedMetadata(artifacts.userDefinedMetadata);
}
// No weight data: return the freshly built model (case 4).
if (!(artifacts.weightData != null))
return [3, 4];
if (artifacts.weightSpecs == null) {
throw new ValueError("LayersModel artifacts contains weight data, but not weight specs. Therefore loading of weights cannot proceed.");
}
_a = decodeModelAndOptimizerWeights(artifacts.weightData, artifacts.weightSpecs), modelWeights = _a.modelWeights, optimizerWeights = _a.optimizerWeights;
model2.loadWeights(modelWeights, strict);
// Skip optimizer restoration (to case 3) when there is nothing to restore.
if (!(model2.optimizer != null && optimizerWeights.length > 0))
return [3, 3];
return [4, model2.optimizer.setWeights(optimizerWeights)];
case 2:
_b.sent();
_b.label = 3;
case 3:
// Release the decoded tensors once they have been handed to the model/optimizer.
tfc.dispose(modelWeights);
tfc.dispose(optimizerWeights.map(function(w) {
return w.tensor;
}));
_b.label = 4;
case 4:
return [2, model2];
}
});
});
}
/**
 * Decodes `buffer` with tfc.io.decodeWeights() and splits the result by each
 * spec's `group`. Specs in group "optimizer" go to `optimizerWeights` (ordered
 * {name, tensor} list); everything else goes to `modelWeights` (name -> tensor).
 * @returns {{modelWeights: Object, optimizerWeights: Array}}
 */
function decodeModelAndOptimizerWeights(buffer, specs) {
  var tensorsByName = tfc.io.decodeWeights(buffer, specs);
  return specs.reduce(function(split, spec) {
    var tensor = tensorsByName[spec.name];
    if (spec.group === "optimizer") {
      split.optimizerWeights.push({name: spec.name, tensor: tensor});
    } else {
      split.modelWeights[spec.name] = tensor;
    }
    return split;
  }, {modelWeights: {}, optimizerWeights: []});
}
/**
 * Sequential: a LayersModel whose layers form one linear stack, each layer
 * consuming the single output tensor of the previous one. Most training and
 * inference APIs delegate to an inner LayersModel (`this.model`) created by
 * build().
 */
var Sequential = function(_super) {
  __extends(Sequential2, _super);
  /**
   * @param {Object} [args] Optional config. `args.name` names the model (a
   *   "sequential_" uid is generated when absent). Each entry of `args.layers`
   *   is passed to add() in order.
   */
  function Sequential2(args) {
    var _this = _super.call(this, {inputs: [], outputs: []}) || this;
    args = args || {};
    _this.trainable = true;
    _this.built = false;
    _this.name = args.name != null ? args.name : getUid("sequential_");
    if (args.layers != null) {
      for (var _i = 0, _a = args.layers; _i < _a.length; _i++) {
        var layer = _a[_i];
        _this.add(layer);
      }
    }
    return _this;
  }
  /**
   * Throws a ValueError if the first output tensor of `layer`'s first inbound
   * node has any negative dimension.
   */
  Sequential2.prototype.checkShape = function(layer) {
    var shape = layer.inboundNodes[0].outputTensors[0].shape;
    if (shape.some(function(x) {
      return x < 0;
    })) {
      throw new ValueError("Negative dimension size caused by adding layer " + (layer.name + " with input shape [") + (layer.inboundNodes[0].inputTensors[0].shape + "]"));
    }
  };
  /**
   * Appends `layer` (a Layer, or a single-input/single-output Sequential or
   * LayersModel) to the stack.
   *
   * For the first layer, a symbolic Input is created from its batchInputShape
   * when it has no inbound nodes, and this model's own inbound Node is
   * (re)created. Later layers are applied to the current output. Adding a
   * layer marks the model as not built.
   * @throws ValueError / TypeError when the layer does not have exactly one
   *   input/output or lacks an input shape as first layer.
   */
  Sequential2.prototype.add = function(layer) {
    var isLayerModelInstance = layer instanceof Sequential2 || layer instanceof LayersModel;
    var modelLayer;
    if (isLayerModelInstance) {
      modelLayer = layer;
      if (modelLayer.outputs.length !== 1) {
        throw new ValueError("All layers in a Sequential model should have a single output tensor. For multi-output layers, use the functional API.");
      }
      if (modelLayer.inputs.length !== 1) {
        throw new ValueError("All layers in a Sequential model should have a single input tensor. For multi-input layers, use the functional API.");
      }
    }
    if (this.outputs.length === 0) {
      if (layer.inboundNodes.length === 0) {
        if (layer.batchInputShape == null) {
          throw new ValueError("The first layer in a Sequential model must get an `inputShape` or `batchInputShape` argument.");
        }
        var x = Input({
          batchShape: layer.batchInputShape,
          dtype: layer.dtype,
          name: layer.name + "_input"
        });
        layer.apply(x);
      }
      if (isLayerModelInstance) {
        this.outputs = modelLayer.outputs;
        this.inputs = modelLayer.inputs;
      } else {
        if (layer.inboundNodes.length !== 1) {
          throw new ValueError("A layer added to a Sequential model must not already be " + ("connected somewhere else. LayersModel received layer " + layer.name + " ") + ("which has " + layer.inboundNodes.length + " pre-existing inbound ") + "connections.");
        }
        if (layer.inboundNodes[0].outputTensors.length !== 1) {
          throw new ValueError("All layers in a Sequential model should have a single output tensor. For multi-output layers, use the functional API.");
        }
        this.checkShape(layer);
        this.outputs = [layer.inboundNodes[0].outputTensors[0]];
        this.inputs = getSourceInputs(this.outputs[0]);
      }
      this.inboundNodes = [];
      // Constructed for its side effect of wiring itself into this.inboundNodes.
      new Node({
        outboundLayer: this,
        inboundLayers: [],
        nodeIndices: [],
        tensorIndices: [],
        inputTensors: this.inputs,
        outputTensors: this.outputs,
        inputMasks: pyListRepeat(null, this.inputs.length),
        outputMasks: [null],
        inputShapes: this.inputs.map(function(x2) {
          return x2.shape;
        }),
        outputShapes: this.outputs[0].shape
      });
    } else {
      var outputTensor = layer.apply(this.outputs[0]);
      if (Array.isArray(outputTensor)) {
        throw new TypeError("All layers in a Sequential model should have a single output tensor. For multi-output layers, use the functional API.");
      }
      this.checkShape(layer);
      this.outputs = [outputTensor];
      this.inboundNodes[0].outputTensors = this.outputs;
      this.inboundNodes[0].outputShapes = [this.outputs[0].shape];
    }
    this.layers.push(layer);
    this.built = false;
  };
  /**
   * Removes the last layer. When the stack becomes empty, outputs and the
   * node lists are cleared (inputs are left as they were); otherwise the
   * previous layer's output becomes the model output.
   * @throws TypeError when there are no layers.
   */
  Sequential2.prototype.pop = function() {
    if (this.layers.length === 0) {
      throw new TypeError("There are no layers in the model.");
    }
    this.layers.pop();
    if (this.layers.length === 0) {
      this.outputs = [];
      this.inboundNodes = [];
      this.outboundNodes = [];
    } else {
      var lastLayerIndex = this.layers.length - 1;
      this.layers[lastLayerIndex].outboundNodes = [];
      this.outputs = [this.layers[lastLayerIndex].output];
      this.inboundNodes[0].outputTensors = this.outputs;
      this.inboundNodes[0].outputShapes = [this.outputs[0].shape];
    }
  };
  /** Calls the inner model, building it first if necessary. */
  Sequential2.prototype.call = function(inputs, kwargs) {
    if (this.model == null) {
      this.build();
    }
    return this.model.call(inputs, kwargs);
  };
  /**
   * Creates the inner LayersModel from the current inputs/outputs and mirrors
   * its graph bookkeeping onto this object. `inputShape` is only passed through
   * getExactlyOneShape() (its result is not used).
   * @throws TypeError when the model has no layers.
   */
  Sequential2.prototype.build = function(inputShape) {
    getExactlyOneShape(inputShape);
    if (this.inputs.length === 0 || this.outputs.length === 0) {
      throw new TypeError("Sequential model cannot be built: model is empty. Add some layers first.");
    }
    this.model = new LayersModel({
      inputs: this.inputs,
      outputs: this.outputs[0],
      name: this.name + "_model"
    });
    this.model.trainable = this.trainable;
    this.supportsMasking = this.model.supportsMasking;
    this.inputLayers = this.model.inputLayers;
    this.inputLayersNodeIndices = this.model.inputLayersNodeIndices;
    this.inputLayersTensorIndices = this.model.inputLayersTensorIndices;
    this.outputLayers = this.model.outputLayers;
    this.outputLayersNodeIndices = this.model.outputLayersNodeIndices;
    this.outputLayersTensorIndices = this.model.outputLayersTensorIndices;
    this.nodesByDepth = this.model.nodesByDepth;
    this.containerNodes = this.model.containerNodes;
    this.outputNames = this.model.outputNames;
    this.inputNames = this.model.inputNames;
    this.built = true;
  };
  /** Builds if needed, then defers to the parent's countParams(). */
  Sequential2.prototype.countParams = function() {
    if (!this.built) {
      this.build();
    }
    return _super.prototype.countParams.call(this);
  };
  /** Builds if needed, then prints the parent summary (printFn defaults to console.log). */
  Sequential2.prototype.summary = function(lineLength, positions, printFn) {
    if (printFn === void 0) {
      printFn = console.log;
    }
    if (!this.built) {
      this.build();
    }
    _super.prototype.summary.call(this, lineLength, positions, printFn);
  };
  /** Sets weights on the inner model, building it first if necessary. */
  Sequential2.prototype.setWeights = function(weights) {
    if (this.model == null) {
      this.build();
    }
    this.model.setWeights(weights);
  };
  /** Delegates to the inner model; throws RuntimeError when not built. */
  Sequential2.prototype.evaluate = function(x, y, args) {
    if (args === void 0) {
      args = {};
    }
    if (!this.built) {
      throw new RuntimeError("The model needs to be compiled before being used.");
    }
    return this.model.evaluate(x, y, args);
  };
  /** Delegates to the inner model; rejects with RuntimeError when not built. */
  Sequential2.prototype.evaluateDataset = function(dataset, args) {
    return __awaiter(this, void 0, void 0, function() {
      return __generator(this, function(_a) {
        if (!this.built) {
          throw new RuntimeError("The model needs to be compiled before being used.");
        }
        return [2, this.model.evaluateDataset(dataset, args)];
      });
    });
  };
  /** Predicts via the inner model, building it first if necessary. */
  Sequential2.prototype.predict = function(x, args) {
    if (args === void 0) {
      args = {};
    }
    if (this.model == null) {
      this.build();
    }
    return this.model.predict(x, args);
  };
  /** Single-batch prediction via the inner model, building it if necessary. */
  Sequential2.prototype.predictOnBatch = function(x) {
    if (this.model == null) {
      this.build();
    }
    return this.model.predictOnBatch(x);
  };
  /**
   * (Re)builds the inner model, compiles it with `args`, and mirrors the
   * compiled optimizer/loss/metrics state onto this object.
   */
  Sequential2.prototype.compile = function(args) {
    this.build();
    this.model.compile(args);
    this.optimizer_ = this.model.optimizer;
    this.isOptimizerOwned = this.model.isOptimizerOwned;
    this.loss = this.model.loss;
    this.metrics = this.model.metrics;
    this.metricsTensors = this.model.metricsTensors;
    this.metricsNames = this.model.metricsNames;
  };
  // `optimizer` proxies the inner model's optimizer. Reads yield undefined
  // before build(); writes require this.model to exist.
  Object.defineProperty(Sequential2.prototype, "optimizer", {
    get: function() {
      return this.model == null ? void 0 : this.model.optimizer;
    },
    set: function(optimizer) {
      this.model.optimizer = optimizer;
    },
    enumerable: true,
    configurable: true
  });
  /** Trains via the inner model; rejects with RuntimeError when not built. */
  Sequential2.prototype.fit = function(x, y, args) {
    if (args === void 0) {
      args = {};
    }
    return __awaiter(this, void 0, void 0, function() {
      return __generator(this, function(_a) {
        if (!this.built) {
          throw new RuntimeError("The model needs to be compiled before being used.");
        }
        return [2, this.model.fit(x, y, args)];
      });
    });
  };
  /** Dataset training via the inner model; rejects with RuntimeError when not built. */
  Sequential2.prototype.fitDataset = function(dataset, args) {
    return __awaiter(this, void 0, void 0, function() {
      return __generator(this, function(_a) {
        if (!this.built) {
          throw new RuntimeError("The model needs to be compiled before being used.");
        }
        return [2, this.model.fitDataset(dataset, args)];
      });
    });
  };
  /** Single-batch training via the inner model (no built check is performed). */
  Sequential2.prototype.trainOnBatch = function(x, y) {
    return __awaiter(this, void 0, void 0, function() {
      return __generator(this, function(_a) {
        return [2, this.model.trainOnBatch(x, y)];
      });
    });
  };
  /**
   * Reconstructs a Sequential model from its config.
   *
   * `config` is either an array of layer configs, or an object holding them
   * under `layers` plus extra model fields (such as `name`) that are passed to
   * the constructor. `customObjects` is forwarded to each layer's
   * deserialize() call; it was previously shadowed by a local `undefined`, so
   * caller-supplied custom classes were ignored. The caller's config object is
   * no longer mutated.
   * @throws ValueError for the legacy/Merge array format.
   * @throws NotImplementedError if `cls` does not produce a Sequential.
   */
  Sequential2.fromConfig = function(cls, config, customObjects, fastWeightInit) {
    if (fastWeightInit === void 0) {
      fastWeightInit = false;
    }
    var configArray;
    var extraModelConfig = {};
    if (config instanceof Array) {
      if (!(config[0].className != null) || config[0]["className"] === "Merge") {
        throw new ValueError("Legacy serialization format not supported yet.");
      }
      configArray = config;
    } else {
      tfc.util.assert(config["layers"] != null, function() {
        return "When the config data for a Sequential model is not an Array, it must be an Object that contains the 'layers' field.";
      });
      configArray = config["layers"];
      // Copy every field except `layers` instead of deleting it from the
      // caller's object; the constructor must not see raw layer configs.
      for (var configKey in config) {
        if (configKey !== "layers") {
          extraModelConfig[configKey] = config[configKey];
        }
      }
    }
    var model2 = new cls(extraModelConfig);
    if (!(model2 instanceof Sequential2)) {
      throw new NotImplementedError("Sequential.fromConfig called on non-Sequential input: " + model2);
    }
    for (var _i = 0, configArray_1 = configArray; _i < configArray_1.length; _i++) {
      var conf = configArray_1[_i];
      var layer = deserialize(conf, customObjects, fastWeightInit);
      if (fastWeightInit) {
        layer.setFastWeightInitDuringBuild(true);
      }
      model2.add(layer);
    }
    return model2;
  };
  // `stopTraining` proxies the inner model; both directions throw ValueError
  // before the model exists (i.e. before compile()/build()).
  Object.defineProperty(Sequential2.prototype, "stopTraining", {
    get: function() {
      if (this.model == null) {
        throw new ValueError("Cannot get the stopTraining property of a sequential model before it is compiled.");
      }
      return this.model.stopTraining;
    },
    set: function(stop) {
      if (this.model == null) {
        throw new ValueError("Cannot set the stopTraining property of a sequential model before it is compiled.");
      }
      this.model.stopTraining = stop;
    },
    enumerable: true,
    configurable: true
  });
  /** Serializes as {name, layers: [{className, config}, ...]}. */
  Sequential2.prototype.getConfig = function() {
    var layers = [];
    for (var _i = 0, _a = this.layers; _i < _a.length; _i++) {
      var layer = _a[_i];
      var dict = {};
      dict["className"] = layer.getClassName();
      dict["config"] = layer.getConfig();
      layers.push(dict);
    }
    return {name: this.name, layers};
  };
  Sequential2.className = "Sequential";
  return Sequential2;
}(LayersModel);
// Register the Sequential class with tfc's serialization registry
// (presumably keyed by its static className, "Sequential").
tfc.serialization.registerClass(Sequential);
/**
* @license
* Copyright 2018 Google LLC
*
* Use of this source code is governed by an MIT-style
* license that can be found in the LICENSE file or at
* https://opensource.org/licenses/MIT.
* =============================================================================
*/
/**
 * Factory for a functional-API LayersModel.
 * @param {Object} args Constructor arguments, forwarded unchanged.
 * @returns {LayersModel}
 */
function model(args) {
  var created = new LayersModel(args);
  return created;
}
/**
 * Factory for a Sequential model.
 * @param {Object} [config] Forwarded unchanged to the Sequential constructor.
 * @returns {Sequential}
 */
function sequential(config) {
  var stack = new Sequential(config);
  return stack;
}
/**
 * Public entry point for loading a LayersModel from a URL/path or IOHandler.
 * A null/undefined `options` is replaced with {}.
 * @returns {Promise} Result of loadLayersModelInternal().
 */
function loadLayersModel(pathOrIOHandler, options) {
  var resolvedOptions = options == null ? {} : options;
  return loadLayersModelInternal(pathOrIOHandler, resolvedOptions);
}
/**
 * Creates a symbolic input via Input().
 * @param {Object} config Forwarded unchanged to Input().
 */
function input(config) {
  var symbolicInput = Input(config);
  return symbolicInput;
}
/**
 * Registers `callbackConstructor` for `verbosityLevel` with
 * CallbackConstructorRegistry. Returns nothing.
 */
function registerCallbackConstructor(verbosityLevel, callbackConstructor) {
  var registry = CallbackConstructorRegistry;
  registry.registerCallbackConstructor(verbosityLevel, callbackConstructor);
}
/**
* @license
* Copyright 2018 Google LLC
*
* Use of this source code is governed by an MIT-style
* license that can be found in the LICENSE file or at
* https://opensource.org/licenses/MIT.
* =============================================================================
*/
/**
 * Base class for activation functions; extends tfc.serialization.Serializable.
 * Activations have no configurable state, so getConfig() returns {}.
 */
var Activation = function(Base) {
  function Activation2() {
    return Base !== null && Base.apply(this, arguments) || this;
  }
  __extends(Activation2, Base);
  Activation2.prototype.getConfig = function() {
    var emptyConfig = {};
    return emptyConfig;
  };
  return Activation2;
}(tfc.serialization.Serializable);
/** ELU activation ("elu"); `alpha` defaults to 1 and is passed to elu(). */
var Elu = function(Base) {
  function Elu2() {
    return Base !== null && Base.apply(this, arguments) || this;
  }
  __extends(Elu2, Base);
  Elu2.prototype.apply = function(x, alpha) {
    var scale = alpha === void 0 ? 1 : alpha;
    return elu(x, scale);
  };
  Elu2.className = "elu";
  return Elu2;
}(Activation);
tfc.serialization.registerClass(Elu);
/** SELU activation ("selu"), delegating to tfc.selu. */
var Selu = function(Base) {
  function Selu2() {
    return Base !== null && Base.apply(this, arguments) || this;
  }
  __extends(Selu2, Base);
  Selu2.prototype.apply = function(x) {
    var activated = tfc.selu(x);
    return activated;
  };
  Selu2.className = "selu";
  return Selu2;
}(Activation);
tfc.serialization.registerClass(Selu);
/** ReLU activation ("relu"), delegating to tfc.relu. */
var Relu = function(Base) {
  function Relu2() {
    return Base !== null && Base.apply(this, arguments) || this;
  }
  __extends(Relu2, Base);
  Relu2.prototype.apply = function(x) {
    var activated = tfc.relu(x);
    return activated;
  };
  Relu2.className = "relu";
  return Relu2;
}(Activation);
tfc.serialization.registerClass(Relu);
/** ReLU6 activation ("relu6"): min(6, relu(x)), computed inside tfc.tidy. */
var Relu6 = function(Base) {
  function Relu62() {
    return Base !== null && Base.apply(this, arguments) || this;
  }
  __extends(Relu62, Base);
  Relu62.prototype.apply = function(x) {
    return tfc.tidy(function() {
      var rectified = tfc.relu(x);
      return tfc.minimum(6, rectified);
    });
  };
  Relu62.className = "relu6";
  return Relu62;
}(Activation);
tfc.serialization.registerClass(Relu6);
/** Identity activation ("linear"): returns its input unchanged. */
var Linear = function(Base) {
  function Linear2() {
    return Base !== null && Base.apply(this, arguments) || this;
  }
  __extends(Linear2, Base);
  Linear2.prototype.apply = function(x) {
    var unchanged = x;
    return unchanged;
  };
  Linear2.className = "linear";
  return Linear2;
}(Activation);
tfc.serialization.registerClass(Linear);
/** Logistic sigmoid activation ("sigmoid"), delegating to tfc.sigmoid. */
var Sigmoid = function(Base) {
  function Sigmoid2() {
    return Base !== null && Base.apply(this, arguments) || this;
  }
  __extends(Sigmoid2, Base);
  Sigmoid2.prototype.apply = function(x) {
    var activated = tfc.sigmoid(x);
    return activated;
  };
  Sigmoid2.className = "sigmoid";
  return Sigmoid2;
}(Activation);
tfc.serialization.registerClass(Sigmoid);
/** Hard sigmoid activation ("hardSigmoid"), delegating to hardSigmoid(). */
var HardSigmoid = function(Base) {
  function HardSigmoid2() {
    return Base !== null && Base.apply(this, arguments) || this;
  }
  __extends(HardSigmoid2, Base);
  HardSigmoid2.prototype.apply = function(x) {
    var activated = hardSigmoid(x);
    return activated;
  };
  HardSigmoid2.className = "hardSigmoid";
  return HardSigmoid2;
}(Activation);
tfc.serialization.registerClass(HardSigmoid);
/** Softplus activation ("softplus"), delegating to tfc.softplus. */
var Softplus = function(Base) {
  function Softplus2() {
    return Base !== null && Base.apply(this, arguments) || this;
  }
  __extends(Softplus2, Base);
  Softplus2.prototype.apply = function(x) {
    var activated = tfc.softplus(x);
    return activated;
  };
  Softplus2.className = "softplus";
  return Softplus2;
}(Activation);
tfc.serialization.registerClass(Softplus);
/**
 * Softsign activation: delegates to the file-level `softsign` helper.
 * Serialized as "softsign".
 */
var Softsign = function(_super) {
  __extends(Softsign2, _super);
  function Softsign2() {
    var instance = _super !== null ? _super.apply(this, arguments) : null;
    return instance || this;
  }
  Softsign2.className = "softsign";
  Softsign2.prototype.apply = function(input) {
    return softsign(input);
  };
  return Softsign2;
}(Activation);
tfc.serialization.registerClass(Softsign);
/**
 * Hyperbolic tangent activation: delegates to `tfc.tanh`.
 * Serialized as "tanh".
 */
var Tanh = function(_super) {
  __extends(Tanh2, _super);
  function Tanh2() {
    var instance = _super !== null ? _super.apply(this, arguments) : null;
    return instance || this;
  }
  Tanh2.className = "tanh";
  Tanh2.prototype.apply = function(input) {
    return tfc.tanh(input);
  };
  return Tanh2;
}(Activation);
tfc.serialization.registerClass(Tanh);
/**
 * Softmax activation along `axis`, which defaults to the last axis (-1)
 * when omitted. Serialized as "softmax".
 */
var Softmax = function(_super) {
  __extends(Softmax2, _super);
  function Softmax2() {
    var instance = _super !== null ? _super.apply(this, arguments) : null;
    return instance || this;
  }
  Softmax2.className = "softmax";
  Softmax2.prototype.apply = function(x, axis) {
    var dim = axis === void 0 ? -1 : axis;
    return tfc.softmax(x, dim);
  };
  return Softmax2;
}(Activation);
tfc.serialization.registerClass(Softmax);
/**
 * Log-softmax activation along `axis`, which defaults to the last axis (-1)
 * when omitted. Serialized as "logSoftmax".
 */
var LogSoftmax = function(_super) {
  __extends(LogSoftmax2, _super);
  function LogSoftmax2() {
    var instance = _super !== null ? _super.apply(this, arguments) : null;
    return instance || this;
  }
  LogSoftmax2.className = "logSoftmax";
  LogSoftmax2.prototype.apply = function(x, axis) {
    var dim = axis === void 0 ? -1 : axis;
    return tfc.logSoftmax(x, dim);
  };
  return LogSoftmax2;
}(Activation);
tfc.serialization.registerClass(LogSoftmax);
/**
 * Swish activation: computes `sigmoid(x * alpha) * x`, with `alpha`
 * defaulting to 1. Intermediates are released by the tidy scope.
 * Serialized as "swish".
 */
var Swish = function(_super) {
  __extends(Swish2, _super);
  function Swish2() {
    var instance = _super !== null ? _super.apply(this, arguments) : null;
    return instance || this;
  }
  Swish2.className = "swish";
  Swish2.prototype.apply = function(x, alpha) {
    var scale = alpha === void 0 ? 1 : alpha;
    return tfc.tidy(function() {
      var gate = tfc.sigmoid(x.mul(scale));
      return gate.mul(x);
    });
  };
  return Swish2;
}(Activation);
tfc.serialization.registerClass(Swish);
/**
 * Serializes an activation to its registered class name string.
 * @param activation2 An Activation instance.
 * @returns The value of `activation2.getClassName()`.
 */
function serializeActivation(activation2) {
  var registeredName = activation2.getClassName();
  return registeredName;
}
/**
 * Builds an activation from a Keras-style `{className, config}` object by
 * looking it up in the global serialization class-name registry.
 * @param config Serialized activation config.
 * @param customObjects Optional extra name-to-class mapping (defaults to `{}`).
 */
function deserializeActivation(config, customObjects) {
  var extraObjects = customObjects === void 0 ? {} : customObjects;
  var registry = tfc.serialization.SerializationMap.getMap().classNameMap;
  return deserializeKerasObject(config, registry, extraObjects, "activation");
}
/**
 * Resolves an activation identifier to an Activation instance.
 *  - null/undefined -> the "linear" activation
 *  - string         -> the activation registered under that class name
 *  - Activation     -> returned as-is
 *  - anything else  -> treated as a serialized `{className, config}` object
 */
function getActivation(identifier) {
  if (identifier == null || typeof identifier === "string") {
    var name = identifier == null ? "linear" : identifier;
    return deserializeActivation({className: name, config: {}});
  }
  if (identifier instanceof Activation) {
    return identifier;
  }
  return deserializeActivation(identifier);
}
/**
* @license
* Copyright 2018 Google LLC
*
* Use of this source code is governed by an MIT-style
* license that can be found in the LICENSE file or at
* https://opensource.org/licenses/MIT.
* =============================================================================
*/
/**
 * Validates the argument passed to the L1L2 regularizer constructor.
 * Accepts null/undefined or any value whose `typeof` is "object"; otherwise
 * throws an Error describing the received value.
 */
function assertObjectArgs(args) {
  var isAbsent = args == null;
  if (isAbsent || typeof args === "object") {
    return;
  }
  throw new Error("Argument to L1L2 regularizer's constructor is expected to be an " + ("object, but received: " + args));
}
/**
 * Abstract base for weight regularizers. It adds nothing to
 * `tfc.serialization.Serializable`; subclasses provide `apply`.
 */
var Regularizer = function(_super) {
  __extends(Regularizer2, _super);
  function Regularizer2() {
    var instance = _super !== null ? _super.apply(this, arguments) : null;
    return instance || this;
  }
  return Regularizer2;
}(tfc.serialization.Serializable);
/**
 * Combined L1/L2 weight regularizer.
 * Penalty = sum(l1 * |x|) + sum(l2 * square(x)); each term is skipped when its
 * coefficient is exactly 0. Both coefficients default to 0.01 when omitted
 * (null/undefined).
 */
var L1L2 = function(_super) {
  __extends(L1L22, _super);
  function L1L22(args) {
    var self = _super.call(this) || this;
    assertObjectArgs(args);
    var hasArgs = args != null;
    self.l1 = hasArgs && args.l1 != null ? args.l1 : 0.01;
    self.l2 = hasArgs && args.l2 != null ? args.l2 : 0.01;
    self.hasL1 = self.l1 !== 0;
    self.hasL2 = self.l2 !== 0;
    return self;
  }
  L1L22.className = "L1L2";
  /** Returns the scalar penalty for tensor `x`; intermediates are tidied. */
  L1L22.prototype.apply = function(x) {
    var self = this;
    return tfc.tidy(function() {
      var total = tfc.zeros([1]);
      if (self.hasL1) {
        total = tfc.add(total, tfc.sum(tfc.mul(self.l1, tfc.abs(x))));
      }
      if (self.hasL2) {
        total = tfc.add(total, tfc.sum(tfc.mul(self.l2, square(x))));
      }
      return total.asScalar();
    });
  };
  L1L22.prototype.getConfig = function() {
    return {l1: this.l1, l2: this.l2};
  };
  L1L22.fromConfig = function(cls, config) {
    var coefficients = {l1: config["l1"], l2: config["l2"]};
    return new cls(coefficients);
  };
  return L1L22;
}(Regularizer);
tfc.serialization.registerClass(L1L2);
/**
 * Creates an L1-only regularizer (l2 fixed at 0). `args.l1` is forwarded;
 * when it is missing, L1L2 applies its own default.
 */
function l1(args) {
  assertObjectArgs(args);
  var coefficient = args == null ? null : args.l1;
  return new L1L2({l1: coefficient, l2: 0});
}
/**
 * Creates an L2-only regularizer (l1 fixed at 0). `args.l2` is forwarded;
 * when it is missing, L1L2 applies its own default.
 */
function l2(args) {
  assertObjectArgs(args);
  var coefficient = args == null ? null : args.l2;
  return new L1L2({l2: coefficient, l1: 0});
}
// Short string aliases accepted by getRegularizer, mapped to registered class names.
var REGULARIZER_IDENTIFIER_REGISTRY_SYMBOL_MAP = {"l1l2": "L1L2"};
/**
 * Serializes a regularizer via the generic Keras object serializer.
 * @param regularizer Regularizer instance (or null).
 */
function serializeRegularizer(regularizer) {
  var serialized = serializeKerasObject(regularizer);
  return serialized;
}
/**
 * Builds a regularizer from a Keras-style `{className, config}` object using
 * the global serialization class-name registry.
 * @param config Serialized regularizer config.
 * @param customObjects Optional extra name-to-class mapping (defaults to `{}`).
 */
function deserializeRegularizer(config, customObjects) {
  var extraObjects = customObjects === void 0 ? {} : customObjects;
  var registry = tfc.serialization.SerializationMap.getMap().classNameMap;
  return deserializeKerasObject(config, registry, extraObjects, "regularizer");
}
/**
 * Resolves a regularizer identifier.
 *  - null/undefined -> null (no regularization)
 *  - string         -> alias from REGULARIZER_IDENTIFIER_REGISTRY_SYMBOL_MAP,
 *                      or the string itself as a registered class name
 *  - Regularizer    -> returned as-is
 *  - anything else  -> treated as a serialized `{className, config}` object
 */
function getRegularizer(identifier) {
  if (identifier == null) {
    return null;
  }
  if (typeof identifier === "string") {
    var aliases = REGULARIZER_IDENTIFIER_REGISTRY_SYMBOL_MAP;
    // Own-property check: the `in` operator also matches inherited keys such
    // as "toString" or "constructor", which would yield a function, not a name.
    var className = Object.prototype.hasOwnProperty.call(aliases, identifier) ? aliases[identifier] : identifier;
    return deserializeRegularizer({className, config: {}});
  }
  if (identifier instanceof Regularizer) {
    return identifier;
  }
  return deserializeRegularizer(identifier);
}
/**
* @license
* Copyright 2018 Google LLC
*
* Use of this source code is governed by an MIT-style
* license that can be found in the LICENSE file or at
* https://opensource.org/licenses/MIT.
* =============================================================================
*/
/**
 * ReLU layer, optionally capped at `maxValue` (output clipped to
 * [0, maxValue] when `maxValue` is set). Supports masking.
 */
var ReLU = function(_super) {
  __extends(ReLU2, _super);
  function ReLU2(args) {
    var _this = _super.call(this, args == null ? {} : args) || this;
    _this.supportsMasking = true;
    if (args != null) {
      _this.maxValue = args.maxValue;
    }
    return _this;
  }
  ReLU2.prototype.call = function(inputs, kwargs) {
    var _this = this;
    // tidy disposes the intermediate relu tensor when the clip is applied,
    // matching how Conv.call scopes its work.
    return tfc.tidy(function() {
      var x = getExactlyOneTensor(inputs);
      var output = tfc.relu(x);
      if (_this.maxValue != null) {
        output = tfc.clipByValue(output, 0, _this.maxValue);
      }
      return output;
    });
  };
  ReLU2.prototype.computeOutputShape = function(inputShape) {
    return inputShape;
  };
  ReLU2.prototype.getConfig = function() {
    var config = {maxValue: this.maxValue};
    var baseConfig = _super.prototype.getConfig.call(this);
    Object.assign(config, baseConfig);
    return config;
  };
  ReLU2.className = "ReLU";
  return ReLU2;
}(Layer);
tfc.serialization.registerClass(ReLU);
/**
 * Leaky ReLU layer with negative-slope coefficient `alpha`
 * (DEFAULT_ALPHA = 0.3 when not provided).
 */
var LeakyReLU = function(_super) {
  __extends(LeakyReLU2, _super);
  function LeakyReLU2(args) {
    var self = _super.call(this, args == null ? {} : args) || this;
    self.DEFAULT_ALPHA = 0.3;
    var options = args == null ? {} : args;
    self.alpha = options.alpha != null ? options.alpha : self.DEFAULT_ALPHA;
    return self;
  }
  LeakyReLU2.className = "LeakyReLU";
  LeakyReLU2.prototype.call = function(inputs, kwargs) {
    return tfc.leakyRelu(getExactlyOneTensor(inputs), this.alpha);
  };
  LeakyReLU2.prototype.computeOutputShape = function(inputShape) {
    return inputShape;
  };
  LeakyReLU2.prototype.getConfig = function() {
    return Object.assign({alpha: this.alpha}, _super.prototype.getConfig.call(this));
  };
  return LeakyReLU2;
}(Layer);
tfc.serialization.registerClass(LeakyReLU);
/**
 * Parametric ReLU layer with a learned `alpha` weight.
 * `sharedAxes` (a number or an array of 1-based axis indices) marks axes along
 * which alpha is shared, i.e. stored with size 1 and broadcast.
 */
var PReLU = function(_super) {
  __extends(PReLU2, _super);
  function PReLU2(args) {
    var _this = _super.call(this, args == null ? {} : args) || this;
    _this.DEFAULT_ALPHA_INITIALIZER = "zeros";
    if (args == null) {
      args = {};
    }
    _this.supportsMasking = true;
    _this.alphaInitializer = getInitializer(args.alphaInitializer || _this.DEFAULT_ALPHA_INITIALIZER);
    _this.alphaRegularizer = getRegularizer(args.alphaRegularizer);
    _this.alphaConstraint = getConstraint(args.alphaConstraint);
    if (args.sharedAxes == null) {
      _this.sharedAxes = null;
    } else if (Array.isArray(args.sharedAxes)) {
      _this.sharedAxes = args.sharedAxes;
    } else if (typeof args.sharedAxes === "number") {
      _this.sharedAxes = [args.sharedAxes];
    } else {
      throw new ValueError("Expected sharedAxes to be a number or an array of numbers, " + ("but got " + args.sharedAxes));
    }
    return _this;
  }
  /**
   * Creates `alpha` with the input shape minus the batch axis, using size 1 on
   * every shared axis, and sets the input spec.
   */
  PReLU2.prototype.build = function(inputShape) {
    inputShape = getExactlyOneShape(inputShape);
    var sharedAxes = this.sharedAxes;
    var paramShape = inputShape.slice(1);
    if (sharedAxes != null) {
      for (var s = 0; s < sharedAxes.length; ++s) {
        // Axis indices count the batch axis, paramShape does not.
        paramShape[sharedAxes[s] - 1] = 1;
      }
    }
    this.alpha = this.addWeight("alpha", paramShape, "float32", this.alphaInitializer, this.alphaRegularizer, true, this.alphaConstraint);
    var axes = {};
    if (sharedAxes != null) {
      for (var axis = 1; axis < inputShape.length; ++axis) {
        // Shared axes broadcast alpha, so their extent need not be pinned
        // (mirrors Keras PReLU.build); only non-shared axes must match.
        if (sharedAxes.indexOf(axis) === -1) {
          axes[axis] = inputShape[axis];
        }
      }
    }
    this.inputSpec = [new InputSpec({
      ndim: inputShape.length,
      axes
    })];
    this.built = true;
  };
  PReLU2.prototype.call = function(inputs, kwargs) {
    inputs = getExactlyOneTensor(inputs);
    return tfc.prelu(inputs, this.alpha.read());
  };
  PReLU2.prototype.getConfig = function() {
    var config = {
      alphaInitializer: serializeInitializer(this.alphaInitializer),
      alphaRegularizer: serializeRegularizer(this.alphaRegularizer),
      alphaConstraint: serializeConstraint(this.alphaConstraint),
      sharedAxes: this.sharedAxes
    };
    var baseConfig = _super.prototype.getConfig.call(this);
    Object.assign(config, baseConfig);
    return config;
  };
  PReLU2.className = "PReLU";
  return PReLU2;
}(Layer);
tfc.serialization.registerClass(PReLU);
/**
 * ELU layer. Only the default alpha (DEFAULT_ALPHA = 1) is supported; any
 * other alpha makes the constructor throw NotImplementedError. `call`
 * delegates to `tfc.elu` without passing alpha.
 */
var ELU = function(_super) {
  __extends(ELU2, _super);
  function ELU2(args) {
    var self = _super.call(this, args == null ? {} : args) || this;
    self.DEFAULT_ALPHA = 1;
    var options = args == null ? {} : args;
    var requested = options.alpha;
    if (requested != null && requested !== self.DEFAULT_ALPHA) {
      throw new NotImplementedError("Non-default alpha value (" + requested + ") is not supported by the ELU layer yet.");
    }
    self.alpha = requested != null ? requested : self.DEFAULT_ALPHA;
    return self;
  }
  ELU2.className = "ELU";
  ELU2.prototype.call = function(inputs, kwargs) {
    return tfc.elu(getExactlyOneTensor(inputs));
  };
  ELU2.prototype.computeOutputShape = function(inputShape) {
    return inputShape;
  };
  ELU2.prototype.getConfig = function() {
    return Object.assign({alpha: this.alpha}, _super.prototype.getConfig.call(this));
  };
  return ELU2;
}(Layer);
tfc.serialization.registerClass(ELU);
/**
 * Thresholded ReLU layer: outputs x where x > theta and 0 elsewhere
 * (DEFAULT_THETA = 1 when not provided).
 */
var ThresholdedReLU = function(_super) {
  __extends(ThresholdedReLU2, _super);
  function ThresholdedReLU2(args) {
    var _this = _super.call(this, args == null ? {} : args) || this;
    _this.DEFAULT_THETA = 1;
    if (args == null) {
      args = {};
    }
    _this.theta = args.theta == null ? _this.DEFAULT_THETA : args.theta;
    return _this;
  }
  ThresholdedReLU2.prototype.call = function(inputs, kwargs) {
    var theta = this.theta;
    // tidy disposes the intermediate boolean mask and its float32 cast.
    return tfc.tidy(function() {
      var x = getExactlyOneTensor(inputs);
      return x.mul(cast(x.greater(theta), "float32"));
    });
  };
  ThresholdedReLU2.prototype.computeOutputShape = function(inputShape) {
    return inputShape;
  };
  ThresholdedReLU2.prototype.getConfig = function() {
    var config = {theta: this.theta};
    var baseConfig = _super.prototype.getConfig.call(this);
    Object.assign(config, baseConfig);
    return config;
  };
  ThresholdedReLU2.className = "ThresholdedReLU";
  return ThresholdedReLU2;
}(Layer);
tfc.serialization.registerClass(ThresholdedReLU);
/**
 * Softmax layer (registered as "Softmax"). It reuses the Softmax activation's
 * `apply` along `axis` (DEFAULT_AXIS = 1 when not provided).
 */
var Softmax$1 = function(_super) {
  __extends(Softmax$12, _super);
  function Softmax$12(args) {
    var self = _super.call(this, args == null ? {} : args) || this;
    self.DEFAULT_AXIS = 1;
    var options = args == null ? {} : args;
    var activationInstance = new Softmax();
    // Stored unbound; Softmax's apply does not read `this`.
    self.softmax = activationInstance.apply;
    self.axis = options.axis != null ? options.axis : self.DEFAULT_AXIS;
    return self;
  }
  Softmax$12.className = "Softmax";
  Softmax$12.prototype.call = function(inputs, kwargs) {
    return this.softmax(getExactlyOneTensor(inputs), this.axis);
  };
  Softmax$12.prototype.computeOutputShape = function(inputShape) {
    return inputShape;
  };
  Softmax$12.prototype.getConfig = function() {
    return Object.assign({axis: this.axis}, _super.prototype.getConfig.call(this));
  };
  return Softmax$12;
}(Layer);
tfc.serialization.registerClass(Softmax$1);
/**
* @license
* Copyright 2018 Google LLC
*
* Use of this source code is governed by an MIT-style
* license that can be found in the LICENSE file or at
* https://opensource.org/licenses/MIT.
* =============================================================================
*/
/**
 * Normalizes a scalar-or-tuple argument to a per-dimension array.
 * A number is expanded with pyListRepeat(value, n). Otherwise `value` must
 * have exactly `n` entries, each accepted by isInteger, and is returned as-is.
 * @throws ValueError on a length mismatch or a non-integer entry.
 */
function normalizeArray(value, n, name) {
  if (typeof value === "number") {
    return pyListRepeat(value, n);
  }
  if (value.length !== n) {
    throw new ValueError("The " + name + " argument must be an integer or tuple of " + n + " integers." + (" Received: " + value.length + " elements."));
  }
  for (var idx = 0; idx < n; ++idx) {
    var entry = value[idx];
    if (isInteger(entry)) {
      continue;
    }
    throw new ValueError("The " + name + " argument must be an integer or tuple of " + n + (" integers. Received: " + JSON.stringify(value) + " including a") + (" non-integer number " + entry));
  }
  return value;
}
/**
 * Computes the output length of a convolution along one spatial dimension.
 * @param inputLength Input length; null/undefined is returned unchanged.
 * @param filterSize Kernel size along this dimension.
 * @param padding "same" / "causal" keep the input length, "full" grows it by
 *   (dilatedFilterSize - 1), anything else is treated as "valid".
 * @param stride Stride along this dimension.
 * @param dilation Dilation rate (default 1).
 * @returns ceil(paddedLength / stride), or the null/undefined input.
 */
function convOutputLength(inputLength, filterSize, padding, stride, dilation) {
  if (dilation === void 0) {
    dilation = 1;
  }
  if (inputLength == null) {
    return inputLength;
  }
  var dilatedFilterSize = filterSize + (filterSize - 1) * (dilation - 1);
  var outputLength;
  if (padding === "same" || padding === "causal") {
    // Causal padding left-pads so the length is preserved, as in Keras.
    outputLength = inputLength;
  } else if (padding === "full") {
    outputLength = inputLength + dilatedFilterSize - 1;
  } else {
    outputLength = inputLength - dilatedFilterSize + 1;
  }
  // Integer ceil division.
  return Math.floor((outputLength + stride - 1) / stride);
}
/**
 * Computes the output length of a transposed convolution along one dimension.
 * @param dimSize Input length; null/undefined yields null.
 * @param strideSize Stride along this dimension.
 * @param kernelSize Kernel size along this dimension.
 * @param padding "valid" or "same".
 * @returns dimSize * stride, plus max(kernel - stride, 0) for "valid".
 * @throws ValueError for any other padding mode.
 */
function deconvLength(dimSize, strideSize, kernelSize, padding) {
  if (dimSize == null) {
    return null;
  }
  if (padding === "valid") {
    return dimSize * strideSize + Math.max(kernelSize - strideSize, 0);
  }
  if (padding === "same") {
    return dimSize * strideSize;
  }
  throw new ValueError("Unsupported padding mode: " + padding + ".");
}
/**
* @license
* Copyright 2018 Google LLC
*
* Use of this source code is governed by an MIT-style
* license that can be found in the LICENSE file or at
* https://opensource.org/licenses/MIT.
* =============================================================================
*/
/**
 * Converts a rank-4 input to NHWC layout: "channelsFirst" input is transposed
 * with [0, 2, 3, 1], and any other valid format is returned unchanged.
 */
function preprocessConv2DInput(x, dataFormat) {
  return tfc.tidy(function() {
    checkDataFormat(dataFormat);
    var needsTranspose = dataFormat === "channelsFirst";
    return needsTranspose ? tfc.transpose(x, [0, 2, 3, 1]) : x;
  });
}
/**
 * Converts a rank-5 input to NDHWC layout: "channelsFirst" input is transposed
 * with [0, 2, 3, 4, 1], and any other valid format is returned unchanged.
 */
function preprocessConv3DInput(x, dataFormat) {
  return tfc.tidy(function() {
    checkDataFormat(dataFormat);
    var needsTranspose = dataFormat === "channelsFirst";
    return needsTranspose ? tfc.transpose(x, [0, 2, 3, 4, 1]) : x;
  });
}
/**
 * 1D convolution followed by an optional bias add.
 * @param x Rank-3 input; transposed with [0, 2, 1] when dataFormat is
 *   "channelsFirst".
 * @param kernel Rank-3 kernel.
 * @param bias Optional rank-1 bias (null/undefined to skip).
 * @param strides Stride (default 1).
 * @param padding "valid" (default) or "same"; "causal" throws.
 * @param dataFormat Defaults to imageDataFormat() when null/undefined.
 * @param dilationRate Dilation (default 1).
 * @returns The convolution output in NWC layout; the result is not transposed back.
 * @throws ValueError on rank mismatches, NotImplementedError for "causal".
 */
function conv1dWithBias(x, kernel, bias, strides, padding, dataFormat, dilationRate) {
  if (strides === void 0) {
    strides = 1;
  }
  if (padding === void 0) {
    padding = "valid";
  }
  if (dilationRate === void 0) {
    dilationRate = 1;
  }
  return tfc.tidy(function() {
    if (dataFormat == null) {
      dataFormat = imageDataFormat();
    }
    checkDataFormat(dataFormat);
    if (x.shape.length !== 3) {
      throw new ValueError("The input of a conv1dWithBias operation should be 3, but is " + (x.shape.length + " instead."));
    }
    if (kernel.shape.length !== 3) {
      throw new ValueError("The kernel for a conv1dWithBias operation should be 3, but is " + (kernel.shape.length + " instead"));
    }
    if (bias != null && bias.shape.length !== 1) {
      // Report the bias rank; the kernel rank was already validated above.
      throw new ValueError("The bias for a conv1dWithBias operation should be 1, but is " + (bias.shape.length + " instead"));
    }
    if (dataFormat === "channelsFirst") {
      x = tfc.transpose(x, [0, 2, 1]);
    }
    if (padding === "causal") {
      throw new NotImplementedError("The support for CAUSAL padding mode in conv1dWithBias is not implemented yet.");
    }
    var y = tfc.conv1d(x, kernel, strides, padding === "same" ? "same" : "valid", "NWC", dilationRate);
    if (bias != null) {
      y = biasAdd(y, bias);
    }
    return y;
  });
}
/**
 * 2D convolution with bias and activation fused through `tfc.fused.conv2d`.
 * @param x Rank-3 or rank-4 input; converted to NHWC via preprocessConv2DInput.
 * @param kernel Rank-3 or rank-4 kernel.
 * @param bias Optional bias, forwarded to the fused op.
 * @param strides Strides (default [1, 1]).
 * @param padding "valid" (default) or "same"; "causal" throws.
 * @param dataFormat Defaults to imageDataFormat() when null/undefined.
 * @param dilationRate Forwarded as `dilations`.
 * @param activation2 Fused activation name or null (default null).
 * @returns The output; transposed back with [0, 3, 1, 2] for "channelsFirst".
 * @throws ValueError on rank mismatches, NotImplementedError for "causal".
 */
function conv2dWithBiasActivation(x, kernel, bias, strides, padding, dataFormat, dilationRate, activation2) {
  if (strides === void 0) {
    strides = [1, 1];
  }
  if (padding === void 0) {
    padding = "valid";
  }
  if (activation2 === void 0) {
    activation2 = null;
  }
  return tfc.tidy(function() {
    if (dataFormat == null) {
      dataFormat = imageDataFormat();
    }
    checkDataFormat(dataFormat);
    if (x.rank !== 3 && x.rank !== 4) {
      throw new ValueError("conv2dWithBiasActivation expects input to be of rank 3 or 4, " + ("but received " + x.rank + "."));
    }
    if (kernel.rank !== 3 && kernel.rank !== 4) {
      // Report the kernel's rank, not the input's.
      throw new ValueError("conv2dWithBiasActivation expects kernel to be of rank 3 or 4, " + ("but received " + kernel.rank + "."));
    }
    // Reject unsupported padding before doing any layout work.
    if (padding === "causal") {
      throw new NotImplementedError("The support for CAUSAL padding mode in conv2dWithBiasActivation is not implemented yet.");
    }
    var y = preprocessConv2DInput(x, dataFormat);
    y = tfc.fused.conv2d({
      x: y,
      filter: kernel,
      strides,
      pad: padding === "same" ? "same" : "valid",
      dilations: dilationRate,
      dataFormat: "NHWC",
      bias,
      activation: activation2
    });
    if (dataFormat === "channelsFirst") {
      y = tfc.transpose(y, [0, 3, 1, 2]);
    }
    return y;
  });
}
/**
 * 3D convolution followed by an optional bias add.
 * @param x Rank-4 or rank-5 input; converted to NDHWC via preprocessConv3DInput.
 * @param kernel Rank-4 or rank-5 kernel.
 * @param bias Optional bias (null/undefined to skip).
 * @param strides Strides (default [1, 1, 1]).
 * @param padding "valid" (default) or "same"; "causal" throws.
 * @param dataFormat Defaults to imageDataFormat() when null/undefined.
 * @param dilationRate Forwarded to tfc.conv3d.
 * @returns The output; transposed back with [0, 4, 1, 2, 3] for "channelsFirst".
 * @throws ValueError on rank mismatches, NotImplementedError for "causal".
 */
function conv3dWithBias(x, kernel, bias, strides, padding, dataFormat, dilationRate) {
  if (strides === void 0) {
    strides = [1, 1, 1];
  }
  if (padding === void 0) {
    padding = "valid";
  }
  return tfc.tidy(function() {
    if (dataFormat == null) {
      dataFormat = imageDataFormat();
    }
    checkDataFormat(dataFormat);
    if (x.rank !== 4 && x.rank !== 5) {
      throw new ValueError("conv3dWithBias expects input to be of rank 4 or 5, but received " + (x.rank + "."));
    }
    if (kernel.rank !== 4 && kernel.rank !== 5) {
      // Report the kernel's rank, not the input's.
      throw new ValueError("conv3dWithBias expects kernel to be of rank 4 or 5, but received " + (kernel.rank + "."));
    }
    var y = preprocessConv3DInput(x, dataFormat);
    if (padding === "causal") {
      throw new NotImplementedError("The support for CAUSAL padding mode in conv3dWithBias is not implemented yet.");
    }
    y = tfc.conv3d(y, kernel, strides, padding === "same" ? "same" : "valid", "NDHWC", dilationRate);
    if (bias != null) {
      y = biasAdd(y, bias);
    }
    if (dataFormat === "channelsFirst") {
      y = tfc.transpose(y, [0, 4, 1, 2, 3]);
    }
    return y;
  });
}
/**
 * Abstract base layer for rank-1/2/3 convolutions.
 * Parses and validates the arguments shared by the convolution layers:
 * kernelSize, strides, padding, dataFormat, dilationRate, activation, useBias
 * and the bias initializer/regularizer/constraint plus activityRegularizer.
 * It defines DEFAULT_KERNEL_INITIALIZER but creates no kernel itself; the
 * kernel-owning subclass (Conv, below in this file) reads that default.
 */
var BaseConv = function(_super) {
__extends(BaseConv2, _super);
/**
 * @param rank Number of spatial dimensions; only 1, 2 or 3 is accepted.
 * @param args Layer config; must contain `kernelSize` (checked by verifyArgs).
 */
function BaseConv2(rank, args) {
var _this = _super.call(this, args) || this;
_this.bias = null;
_this.DEFAULT_KERNEL_INITIALIZER = "glorotNormal";
_this.DEFAULT_BIAS_INITIALIZER = "zeros";
BaseConv2.verifyArgs(args);
_this.rank = rank;
assertPositiveInteger(_this.rank, "rank");
if (_this.rank !== 1 && _this.rank !== 2 && _this.rank !== 3) {
throw new NotImplementedError("Convolution layer for rank other than 1, 2, or 3 (" + _this.rank + ") is not implemented yet.");
}
// normalizeArray expands a scalar via pyListRepeat, or checks that an array has exactly `rank` integer entries.
_this.kernelSize = normalizeArray(args.kernelSize, rank, "kernelSize");
_this.strides = normalizeArray(args.strides == null ? 1 : args.strides, rank, "strides");
_this.padding = args.padding == null ? "valid" : args.padding;
checkPaddingMode(_this.padding);
_this.dataFormat = args.dataFormat == null ? "channelsLast" : args.dataFormat;
checkDataFormat(_this.dataFormat);
// A null/undefined activation resolves to "linear" inside getActivation.
_this.activation = getActivation(args.activation);
_this.useBias = args.useBias == null ? true : args.useBias;
_this.biasInitializer = getInitializer(args.biasInitializer || _this.DEFAULT_BIAS_INITIALIZER);
_this.biasConstraint = getConstraint(args.biasConstraint);
_this.biasRegularizer = getRegularizer(args.biasRegularizer);
_this.activityRegularizer = getRegularizer(args.activityRegularizer);
_this.dilationRate = normalizeArray(args.dilationRate == null ? 1 : args.dilationRate, rank, "dilationRate");
// NOTE(review): normalizeArray already rejects arrays whose length differs from `rank`, and the
// number branches only matter if pyListRepeat can return a non-array; these checks look defensive.
if (_this.rank === 1 && (Array.isArray(_this.dilationRate) && _this.dilationRate.length !== 1)) {
throw new ValueError("dilationRate must be a number or an array of a single number for 1D convolution, but received " + ("" + JSON.stringify(_this.dilationRate)));
} else if (_this.rank === 2) {
if (typeof _this.dilationRate === "number") {
_this.dilationRate = [_this.dilationRate, _this.dilationRate];
} else if (_this.dilationRate.length !== 2) {
throw new ValueError("dilationRate must be a number or array of two numbers for 2D " + ("convolution, but received " + JSON.stringify(_this.dilationRate)));
}
} else if (_this.rank === 3) {
if (typeof _this.dilationRate === "number") {
_this.dilationRate = [_this.dilationRate, _this.dilationRate, _this.dilationRate];
} else if (_this.dilationRate.length !== 3) {
throw new ValueError("dilationRate must be a number or array of three numbers for 3D " + ("convolution, but received " + JSON.stringify(_this.dilationRate)));
}
}
return _this;
}
// Requires `kernelSize` to be a number or a number[] of length 1..3.
BaseConv2.verifyArgs = function(args) {
assert("kernelSize" in args, "required key 'kernelSize' not in config");
if (typeof args.kernelSize !== "number" && !checkArrayTypeAndLength(args.kernelSize, "number", 1, 3)) {
throw new ValueError("BaseConv expects config.kernelSize to be number or number[] with " + ("length 1, 2, or 3, but received " + JSON.stringify(args.kernelSize) + "."));
}
};
// Serializes the shared convolution settings, then merges in the base Layer config.
BaseConv2.prototype.getConfig = function() {
var config = {
kernelSize: this.kernelSize,
strides: this.strides,
padding: this.padding,
dataFormat: this.dataFormat,
dilationRate: this.dilationRate,
activation: serializeActivation(this.activation),
useBias: this.useBias,
biasInitializer: serializeInitializer(this.biasInitializer),
biasRegularizer: serializeRegularizer(this.biasRegularizer),
activityRegularizer: serializeRegularizer(this.activityRegularizer),
biasConstraint: serializeConstraint(this.biasConstraint)
};
var baseConfig = _super.prototype.getConfig.call(this);
Object.assign(config, baseConfig);
return config;
};
return BaseConv2;
}(Layer);
/**
 * Generic N-D convolution layer (rank 1, 2 or 3) built on BaseConv.
 * Adds the `filters` count and the kernel initializer/regularizer/constraint,
 * creates the kernel (and optional bias) in `build`, and dispatches `call` to
 * the rank-specific convolution helpers.
 */
var Conv = function(_super) {
__extends(Conv2, _super);
/**
 * @param rank Number of spatial dimensions (forwarded to BaseConv).
 * @param args Layer config; `filters` must be a number >= 1.
 */
function Conv2(rank, args) {
var _this = _super.call(this, rank, args) || this;
_this.kernel = null;
Conv2.verifyArgs(args);
_this.filters = args.filters;
assertPositiveInteger(_this.filters, "filters");
_this.kernelInitializer = getInitializer(args.kernelInitializer || _this.DEFAULT_KERNEL_INITIALIZER);
_this.kernelConstraint = getConstraint(args.kernelConstraint);
_this.kernelRegularizer = getRegularizer(args.kernelRegularizer);
return _this;
}
// Creates the kernel with shape [...kernelSize, inputChannels, filters] and, when useBias is set, a [filters] bias.
Conv2.prototype.build = function(inputShape) {
var _a;
inputShape = getExactlyOneShape(inputShape);
var channelAxis = this.dataFormat === "channelsFirst" ? 1 : inputShape.length - 1;
if (inputShape[channelAxis] == null) {
throw new ValueError("The channel dimension of the input should be defined. " + ("Found " + inputShape[channelAxis]));
}
var inputDim = inputShape[channelAxis];
var kernelShape = this.kernelSize.concat([inputDim, this.filters]);
this.kernel = this.addWeight("kernel", kernelShape, null, this.kernelInitializer, this.kernelRegularizer, true, this.kernelConstraint);
if (this.useBias) {
this.bias = this.addWeight("bias", [this.filters], null, this.biasInitializer, this.biasRegularizer, true, this.biasConstraint);
}
// Input must have rank + 2 dims, and its channel axis is pinned to the size seen at build time.
this.inputSpec = [{ndim: this.rank + 2, axes: (_a = {}, _a[channelAxis] = inputDim, _a)}];
this.built = true;
};
Conv2.prototype.call = function(inputs, kwargs) {
var _this = this;
return tfc.tidy(function() {
inputs = getExactlyOneTensor(inputs);
var outputs;
var biasValue = _this.bias == null ? null : _this.bias.read();
// NOTE(review): presumably returns null when the activation has no fused-kernel equivalent; confirm in mapActivationToFusedKernel.
var fusedActivationName = mapActivationToFusedKernel(_this.activation.getClassName());
if (fusedActivationName != null && _this.rank === 2) {
// Fused path: conv, bias and activation in a single op; the activation is not applied again below.
outputs = conv2dWithBiasActivation(inputs, _this.kernel.read(), biasValue, _this.strides, _this.padding, _this.dataFormat, _this.dilationRate, fusedActivationName);
} else {
if (_this.rank === 1) {
outputs = conv1dWithBias(inputs, _this.kernel.read(), biasValue, _this.strides[0], _this.padding, _this.dataFormat, _this.dilationRate[0]);
} else if (_this.rank === 2) {
outputs = conv2dWithBiasActivation(inputs, _this.kernel.read(), biasValue, _this.strides, _this.padding, _this.dataFormat, _this.dilationRate);
} else if (_this.rank === 3) {
outputs = conv3dWithBias(inputs, _this.kernel.read(), biasValue, _this.strides, _this.padding, _this.dataFormat, _this.dilationRate);
} else {
throw new NotImplementedError("convolutions greater than 3D are not implemented yet.");
}
if (_this.activation != null) {
outputs = _this.activation.apply(outputs);
}
}
return outputs;
});
};
// Output shape: batch dim, then per-axis convOutputLength of the spatial dims, with `filters` placed per dataFormat.
Conv2.prototype.computeOutputShape = function(inputShape) {
inputShape = getExactlyOneShape(inputShape);
var newSpace = [];
var space = this.dataFormat === "channelsLast" ? inputShape.slice(1, inputShape.length - 1) : inputShape.slice(2);
for (var i = 0; i < space.length; ++i) {
var newDim = convOutputLength(space[i], this.kernelSize[i], this.padding, this.strides[i], typeof this.dilationRate === "number" ? this.dilationRate : this.dilationRate[i]);
newSpace.push(newDim);
}
var outputShape = [inputShape[0]];
if (this.dataFormat === "channelsLast") {
outputShape = outputShape.concat(newSpace);
outputShape.push(this.filters);
} else {
outputShape.push(this.filters);
outputShape = outputShape.concat(newSpace);
}
return outputShape;
};
Conv2.prototype.getConfig = function() {
var config = {
filters: this.filters,
kernelInitializer: serializeInitializer(this.kernelInitializer),
kernelRegularizer: serializeRegularizer(this.kernelRegularizer),
kernelConstraint: serializeConstraint(this.kernelConstraint)
};
var baseConfig = _super.prototype.getConfig.call(this);
Object.assign(config, baseConfig);
return config;
};
// Requires `filters` to be present and a number >= 1.
Conv2.verifyArgs = function(args) {
if (!("filters" in args) || typeof args.filters !== "number" || args.filters < 1) {
throw new ValueError("Convolution layer expected config.filters to be a 'number' > 0 " + ("but got " + JSON.stringify(args.filters)));
}
};
return Conv2;
}(BaseConv);
/**
 * 2D convolution layer: Conv with rank fixed to 2. kernelSize must be a
 * number or a number[] of length 1 or 2.
 */
var Conv2D = function(_super) {
  __extends(Conv2D2, _super);
  function Conv2D2(args) {
    var self = _super.call(this, 2, args) || this;
    Conv2D2.verifyArgs(args);
    return self;
  }
  Conv2D2.className = "Conv2D";
  Conv2D2.verifyArgs = function(args) {
    var ks = args.kernelSize;
    var acceptable = typeof ks === "number" || checkArrayTypeAndLength(ks, "number", 1, 2);
    if (!acceptable) {
      throw new ValueError("Conv2D expects config.kernelSize to be number or number[] with " + ("length 1 or 2, but received " + JSON.stringify(args.kernelSize) + "."));
    }
  };
  // Drops `rank`, which the config of a fixed-rank layer does not need.
  Conv2D2.prototype.getConfig = function() {
    var cfg = _super.prototype.getConfig.call(this);
    delete cfg["rank"];
    return cfg;
  };
  return Conv2D2;
}(Conv);
tfc.serialization.registerClass(Conv2D);
/**
 * 3D convolution layer: Conv with rank fixed to 3. kernelSize must be a
 * number or an array of length 1 or 3.
 */
var Conv3D = function(_super) {
  __extends(Conv3D2, _super);
  function Conv3D2(args) {
    var self = _super.call(this, 3, args) || this;
    Conv3D2.verifyArgs(args);
    return self;
  }
  Conv3D2.className = "Conv3D";
  Conv3D2.verifyArgs = function(args) {
    var ks = args.kernelSize;
    if (typeof ks === "number") {
      return;
    }
    var validArray = Array.isArray(ks) && (ks.length === 1 || ks.length === 3);
    if (!validArray) {
      throw new ValueError("Conv3D expects config.kernelSize to be number or" + (" [number, number, number], but received " + JSON.stringify(args.kernelSize) + "."));
    }
  };
  // Drops `rank`, which the config of a fixed-rank layer does not need.
  Conv3D2.prototype.getConfig = function() {
    var cfg = _super.prototype.getConfig.call(this);
    delete cfg["rank"];
    return cfg;
  };
  return Conv3D2;
}(Conv);
tfc.serialization.registerClass(Conv3D);
/**
 * Transposed 2D convolution layer, built on Conv2D.
 * Only "same" and "valid" padding are accepted. Inputs must be rank 4.
 * The kernel is stored as [...kernelSize, filters, inputChannels], the reverse
 * channel order of Conv. dilationRate is not used by `call` and is omitted
 * from the serialized config.
 */
var Conv2DTranspose = function(_super) {
__extends(Conv2DTranspose2, _super);
function Conv2DTranspose2(args) {
var _this = _super.call(this, args) || this;
_this.inputSpec = [new InputSpec({ndim: 4})];
if (_this.padding !== "same" && _this.padding !== "valid") {
throw new ValueError("Conv2DTranspose currently supports only padding modes 'same' " + ("and 'valid', but received padding mode " + _this.padding));
}
return _this;
}
// Creates the [...kernelSize, filters, inputDim] kernel and optional [filters] bias, and pins the channel axis in inputSpec.
Conv2DTranspose2.prototype.build = function(inputShape) {
var _a;
inputShape = getExactlyOneShape(inputShape);
if (inputShape.length !== 4) {
throw new ValueError("Input should have rank 4; Received input shape: " + JSON.stringify(inputShape));
}
var channelAxis = this.dataFormat === "channelsFirst" ? 1 : inputShape.length - 1;
if (inputShape[channelAxis] == null) {
throw new ValueError("The channel dimension of the inputs should be defined. Found `None`.");
}
var inputDim = inputShape[channelAxis];
var kernelShape = this.kernelSize.concat([this.filters, inputDim]);
this.kernel = this.addWeight("kernel", kernelShape, "float32", this.kernelInitializer, this.kernelRegularizer, true, this.kernelConstraint);
if (this.useBias) {
this.bias = this.addWeight("bias", [this.filters], "float32", this.biasInitializer, this.biasRegularizer, true, this.biasConstraint);
}
this.inputSpec = [new InputSpec({ndim: 4, axes: (_a = {}, _a[channelAxis] = inputDim, _a)})];
this.built = true;
};
Conv2DTranspose2.prototype.call = function(inputs, kwargs) {
var _this = this;
return tfc.tidy(function() {
var input2 = getExactlyOneTensor(inputs);
if (input2.shape.length !== 4) {
throw new ValueError("Conv2DTranspose.call() expects input tensor to be rank-4, but " + ("received a tensor of rank-" + input2.shape.length));
}
var inputShape = input2.shape;
var batchSize = inputShape[0];
var hAxis;
var wAxis;
if (_this.dataFormat === "channelsFirst") {
hAxis = 2;
wAxis = 3;
} else {
hAxis = 1;
wAxis = 2;
}
var height = inputShape[hAxis];
var width = inputShape[wAxis];
var kernelH = _this.kernelSize[0];
var kernelW = _this.kernelSize[1];
var strideH = _this.strides[0];
var strideW = _this.strides[1];
var outHeight = deconvLength(height, strideH, kernelH, _this.padding);
var outWidth = deconvLength(width, strideW, kernelW, _this.padding);
// The op always runs in NHWC, so the requested output shape is NHWC regardless of dataFormat.
var outputShape = [batchSize, outHeight, outWidth, _this.filters];
if (_this.dataFormat !== "channelsLast") {
input2 = tfc.transpose(input2, [0, 2, 3, 1]);
}
var outputs = tfc.conv2dTranspose(input2, _this.kernel.read(), outputShape, _this.strides, _this.padding);
if (_this.dataFormat !== "channelsLast") {
outputs = tfc.transpose(outputs, [0, 3, 1, 2]);
}
// Bias is added after the transpose back, so biasAdd receives the layer's own dataFormat.
if (_this.bias != null) {
outputs = biasAdd(outputs, _this.bias.read(), _this.dataFormat);
}
if (_this.activation != null) {
outputs = _this.activation.apply(outputs);
}
return outputs;
});
};
// Replaces the channel axis with `filters` and applies deconvLength to the height and width axes.
Conv2DTranspose2.prototype.computeOutputShape = function(inputShape) {
inputShape = getExactlyOneShape(inputShape);
var outputShape = inputShape.slice();
var channelAxis;
var heightAxis;
var widthAxis;
if (this.dataFormat === "channelsFirst") {
channelAxis = 1;
heightAxis = 2;
widthAxis = 3;
} else {
channelAxis = 3;
heightAxis = 1;
widthAxis = 2;
}
var kernelH = this.kernelSize[0];
var kernelW = this.kernelSize[1];
var strideH = this.strides[0];
var strideW = this.strides[1];
outputShape[channelAxis] = this.filters;
outputShape[heightAxis] = deconvLength(outputShape[heightAxis], strideH, kernelH, this.padding);
outputShape[widthAxis] = deconvLength(outputShape[widthAxis], strideW, kernelW, this.padding);
return outputShape;
};
// dilationRate is not used by call(), so it is dropped from the serialized config.
Conv2DTranspose2.prototype.getConfig = function() {
var config = _super.prototype.getConfig.call(this);
delete config["dilationRate"];
return config;
};
Conv2DTranspose2.className = "Conv2DTranspose";
return Conv2DTranspose2;
}(Conv2D);
tfc.serialization.registerClass(Conv2DTranspose);
/**
 * Abstract base layer for depthwise-separable convolution.
 *
 * The layer runs a per-channel (depthwise) convolution, then a 1x1 (pointwise)
 * convolution that mixes channels into `filters` outputs. Only rank 2 is
 * implemented in `call`; rank 1 throws NotImplementedError.
 */
var SeparableConv = function(_super) {
  __extends(SeparableConv2, _super);
  /**
   * @param {number} rank Number of spatial dimensions.
   * @param {Object} config Layer config.
   *   - `filters` is required.
   *   - kernelInitializer, kernelRegularizer and kernelConstraint are rejected;
   *     use the depthwise* and pointwise* variants instead.
   *   - depthMultiplier defaults to 1.
   * @throws {ValueError} If `filters` is missing, a kernel* field is given, or
   *   `padding` is set to something other than 'same' / 'valid'.
   */
  function SeparableConv2(rank, config) {
    var _this = _super.call(this, rank, config) || this;
    _this.DEFAULT_DEPTHWISE_INITIALIZER = "glorotUniform";
    _this.DEFAULT_POINTWISE_INITIALIZER = "glorotUniform";
    _this.depthwiseKernel = null;
    _this.pointwiseKernel = null;
    if (config.filters == null) {
      throw new ValueError("The `filters` configuration field is required by SeparableConv, but is unspecified.");
    }
    if (config.kernelInitializer != null || config.kernelRegularizer != null || config.kernelConstraint != null) {
      throw new ValueError("Fields kernelInitializer, kernelRegularizer and kernelConstraint are invalid for SeparableConv2D. Use depthwiseInitializer, depthwiseRegularizer, depthwiseConstraint, pointwiseInitializer, pointwiseRegularizer and pointwiseConstraint instead.");
    }
    if (config.padding != null && config.padding !== "same" && config.padding !== "valid") {
      throw new ValueError("SeparableConv" + _this.rank + "D supports only padding modes: " + ("'same' and 'valid', but received " + JSON.stringify(config.padding)));
    }
    _this.depthMultiplier = config.depthMultiplier == null ? 1 : config.depthMultiplier;
    _this.depthwiseInitializer = getInitializer(config.depthwiseInitializer || _this.DEFAULT_DEPTHWISE_INITIALIZER);
    _this.depthwiseRegularizer = getRegularizer(config.depthwiseRegularizer);
    _this.depthwiseConstraint = getConstraint(config.depthwiseConstraint);
    // Must read config.pointwiseInitializer (not depthwiseInitializer).
    // Otherwise a user-supplied pointwise initializer is ignored and the value
    // written by getConfig() does not survive a round-trip.
    _this.pointwiseInitializer = getInitializer(config.pointwiseInitializer || _this.DEFAULT_POINTWISE_INITIALIZER);
    _this.pointwiseRegularizer = getRegularizer(config.pointwiseRegularizer);
    _this.pointwiseConstraint = getConstraint(config.pointwiseConstraint);
    return _this;
  }

  /**
   * Create the layer weights:
   *   - depthwise kernel: kernelSize + [inputChannels, depthMultiplier]
   *   - pointwise kernel: [1, ..., 1] + [inputChannels * depthMultiplier, filters]
   *   - optional bias: [filters]
   * @throws {ValueError} If the input rank is below rank + 2 or the channel
   *   dimension is undefined or negative.
   */
  SeparableConv2.prototype.build = function(inputShape) {
    var _a;
    inputShape = getExactlyOneShape(inputShape);
    if (inputShape.length < this.rank + 2) {
      throw new ValueError("Inputs to SeparableConv" + this.rank + "D should have rank " + (this.rank + 2 + ", but received input shape: ") + ("" + JSON.stringify(inputShape)));
    }
    var channelAxis = this.dataFormat === "channelsFirst" ? 1 : inputShape.length - 1;
    if (inputShape[channelAxis] == null || inputShape[channelAxis] < 0) {
      throw new ValueError("The channel dimension of the inputs should be defined, " + ("but found " + JSON.stringify(inputShape[channelAxis])));
    }
    var inputDim = inputShape[channelAxis];
    var depthwiseKernelShape = this.kernelSize.concat([inputDim, this.depthMultiplier]);
    // The pointwise kernel is a 1x...x1 convolution over the depthwise output channels.
    var pointwiseKernelShape = [];
    for (var i = 0; i < this.rank; ++i) {
      pointwiseKernelShape.push(1);
    }
    pointwiseKernelShape.push(inputDim * this.depthMultiplier, this.filters);
    var trainable = true;
    this.depthwiseKernel = this.addWeight("depthwise_kernel", depthwiseKernelShape, "float32", this.depthwiseInitializer, this.depthwiseRegularizer, trainable, this.depthwiseConstraint);
    this.pointwiseKernel = this.addWeight("pointwise_kernel", pointwiseKernelShape, "float32", this.pointwiseInitializer, this.pointwiseRegularizer, trainable, this.pointwiseConstraint);
    if (this.useBias) {
      this.bias = this.addWeight("bias", [this.filters], "float32", this.biasInitializer, this.biasRegularizer, trainable, this.biasConstraint);
    } else {
      this.bias = null;
    }
    // Pin the channel dimension so later inputs must match what the kernels were built for.
    this.inputSpec = [new InputSpec({ndim: this.rank + 2, axes: (_a = {}, _a[channelAxis] = inputDim, _a)})];
    this.built = true;
  };

  /**
   * Forward pass (rank 2 only).
   *
   * The convolution always runs in NHWC. channelsFirst inputs are transposed in
   * and the result is transposed back to NCHW at the very end.
   * @throws {NotImplementedError} For rank 1.
   */
  SeparableConv2.prototype.call = function(inputs, kwargs) {
    var _this = this;
    return tfc.tidy(function() {
      inputs = getExactlyOneTensor(inputs);
      var output;
      if (_this.rank === 1) {
        throw new NotImplementedError("1D separable convolution is not implemented yet.");
      } else if (_this.rank === 2) {
        if (_this.dataFormat === "channelsFirst") {
          inputs = tfc.transpose(inputs, [0, 2, 3, 1]);
        }
        output = tfc.separableConv2d(inputs, _this.depthwiseKernel.read(), _this.pointwiseKernel.read(), _this.strides, _this.padding, _this.dilationRate, "NHWC");
      }
      if (_this.useBias) {
        // `output` is NHWC here whatever this.dataFormat is, so the bias must be
        // broadcast along the last axis. Passing "channelsFirst" would broadcast
        // it along axis 1 (height) instead of the channel axis.
        output = biasAdd(output, _this.bias.read(), "channelsLast");
      }
      if (_this.activation != null) {
        output = _this.activation.apply(output);
      }
      if (_this.dataFormat === "channelsFirst") {
        output = tfc.transpose(output, [0, 3, 1, 2]);
      }
      return output;
    });
  };

  /**
   * Serialize the layer.
   *
   * The generic kernel* fields and rank are dropped; the depthwise and
   * pointwise variants are written in their place.
   */
  SeparableConv2.prototype.getConfig = function() {
    var config = _super.prototype.getConfig.call(this);
    delete config["rank"];
    delete config["kernelInitializer"];
    delete config["kernelRegularizer"];
    delete config["kernelConstraint"];
    config["depthwiseInitializer"] = serializeInitializer(this.depthwiseInitializer);
    config["pointwiseInitializer"] = serializeInitializer(this.pointwiseInitializer);
    config["depthwiseRegularizer"] = serializeRegularizer(this.depthwiseRegularizer);
    config["pointwiseRegularizer"] = serializeRegularizer(this.pointwiseRegularizer);
    config["depthwiseConstraint"] = serializeConstraint(this.depthwiseConstraint);
    config["pointwiseConstraint"] = serializeConstraint(this.pointwiseConstraint);
    return config;
  };
  SeparableConv2.className = "SeparableConv";
  return SeparableConv2;
}(Conv);
/**
 * 2-D depthwise-separable convolution layer: SeparableConv with its spatial
 * rank fixed to 2.
 */
var SeparableConv2D = function(_super) {
  __extends(SeparableConv2D2, _super);
  /** @param {Object} args Layer config, forwarded to SeparableConv together with rank 2. */
  function SeparableConv2D2(args) {
    var SPATIAL_RANK = 2;
    var self = _super.call(this, SPATIAL_RANK, args);
    return self || this;
  }
  SeparableConv2D2.className = "SeparableConv2D";
  return SeparableConv2D2;
}(SeparableConv);
// Make SeparableConv2D resolvable by class name during model deserialization.
tfc.serialization.registerClass(SeparableConv2D);
/**
 * 1-D convolution layer (a rank-1 Conv).
 *
 * `args.kernelSize` must be a number or a length-1 number array. Inputs must
 * be 3-D.
 */
var Conv1D = function(_super) {
  __extends(Conv1D2, _super);
  function Conv1D2(args) {
    var layer = _super.call(this, 1, args) || this;
    // Arguments are validated once the base constructor has run.
    Conv1D2.verifyArgs(args);
    layer.inputSpec = [{ndim: 3}];
    return layer;
  }
  /** Serialize the layer; `rank` and `dataFormat` are implied by the class and are dropped. */
  Conv1D2.prototype.getConfig = function() {
    var config = _super.prototype.getConfig.call(this);
    ["rank", "dataFormat"].forEach(function(key) {
      delete config[key];
    });
    return config;
  };
  /**
   * @throws {ValueError} Unless args.kernelSize is a number or a number[] of length 1.
   */
  Conv1D2.verifyArgs = function(args) {
    var kernelSize = args.kernelSize;
    var isScalar = typeof kernelSize === "number";
    if (isScalar || checkArrayTypeAndLength(kernelSize, "number", 1, 1)) {
      return;
    }
    throw new ValueError("Conv1D expects config.kernelSize to be number or number[] with " + ("length 1, but received " + JSON.stringify(kernelSize) + "."));
  };
  Conv1D2.className = "Conv1D";
  return Conv1D2;
}(Conv);
// Make Conv1D resolvable by class name during model deserialization.
tfc.serialization.registerClass(Conv1D);
/**
 * Crops rows and columns off the spatial dimensions of a 4-D input.
 *
 * `args.cropping` may take three forms:
 *   - a number: crop that amount from every side;
 *   - [rows, cols]: crop `rows` from top and bottom, `cols` from left and right;
 *   - [[top, bottom], [left, right]]: used as-is.
 * dataFormat defaults to "channelsLast".
 */
var Cropping2D = function(_super) {
  __extends(Cropping2D2, _super);
  function Cropping2D2(args) {
    var layer = _super.call(this, args) || this;
    var cropping = args.cropping;
    if (typeof cropping === "number") {
      layer.cropping = [[cropping, cropping], [cropping, cropping]];
    } else if (typeof cropping[0] === "number") {
      var rows = cropping[0];
      var cols = cropping[1];
      layer.cropping = [[rows, rows], [cols, cols]];
    } else {
      layer.cropping = cropping;
    }
    layer.dataFormat = args.dataFormat === void 0 ? "channelsLast" : args.dataFormat;
    layer.inputSpec = [{ndim: 4}];
    return layer;
  }
  /** Output shape: the height and width entries are reduced by the total cropping on each axis. */
  Cropping2D2.prototype.computeOutputShape = function(inputShape) {
    var top = this.cropping[0][0];
    var bottom = this.cropping[0][1];
    var left = this.cropping[1][0];
    var right = this.cropping[1][1];
    if (this.dataFormat === "channelsFirst") {
      return [inputShape[0], inputShape[1], inputShape[2] - top - bottom, inputShape[3] - left - right];
    }
    return [inputShape[0], inputShape[1] - top - bottom, inputShape[2] - left - right, inputShape[3]];
  };
  /** Slice out the kept window: first along height, then along width. */
  Cropping2D2.prototype.call = function(inputs, kwargs) {
    var self = this;
    return tfc.tidy(function() {
      var x = getExactlyOneTensor(inputs);
      // 0-based index of the height dimension in x.shape; width follows it.
      var heightIdx = self.dataFormat === "channelsLast" ? 1 : 2;
      var top = self.cropping[0][0];
      var bottom = self.cropping[0][1];
      var left = self.cropping[1][0];
      var right = self.cropping[1][1];
      // The axis argument passed to sliceAlongAxis is one greater than the
      // shape index (sliceAlongAxis appears to use 1-based axes).
      var hSliced = sliceAlongAxis(x, top, x.shape[heightIdx] - top - bottom, heightIdx + 1);
      return sliceAlongAxis(hSliced, left, x.shape[heightIdx + 1] - right - left, heightIdx + 2);
    });
  };
  /** Serialize `cropping` and `dataFormat`, overlaid with the base Layer config. */
  Cropping2D2.prototype.getConfig = function() {
    var ownConfig = {cropping: this.cropping, dataFormat: this.dataFormat};
    return Object.assign(ownConfig, _super.prototype.getConfig.call(this));
  };
  Cropping2D2.className = "Cropping2D";
  return Cropping2D2;
}(Layer);
// Make Cropping2D resolvable by class name during model deserialization.
tfc.serialization.registerClass(Cropping2D);
/**
 * Nearest-neighbour upsampling of the spatial dimensions of a 4-D input.
 *
 * `size` is [heightFactor, widthFactor] and defaults to [2, 2]. dataFormat
 * defaults to "channelsLast".
 */
var UpSampling2D = function(_super) {
  __extends(UpSampling2D2, _super);
  function UpSampling2D2(args) {
    var layer = _super.call(this, args) || this;
    layer.DEFAULT_SIZE = [2, 2];
    layer.inputSpec = [{ndim: 4}];
    layer.size = args.size != null ? args.size : layer.DEFAULT_SIZE;
    layer.dataFormat = args.dataFormat != null ? args.dataFormat : "channelsLast";
    return layer;
  }
  /** Output shape: height and width are multiplied by `size`; unknown (null) dims stay null. */
  UpSampling2D2.prototype.computeOutputShape = function(inputShape) {
    var size = this.size;
    var scale = function(dim, factor) {
      return dim == null ? null : factor * dim;
    };
    if (this.dataFormat === "channelsFirst") {
      return [inputShape[0], inputShape[1], scale(inputShape[2], size[0]), scale(inputShape[3], size[1])];
    }
    return [inputShape[0], scale(inputShape[1], size[0]), scale(inputShape[2], size[1]), inputShape[3]];
  };
  /** Resize with nearest-neighbour sampling, going through channels-last when needed. */
  UpSampling2D2.prototype.call = function(inputs, kwargs) {
    var self = this;
    return tfc.tidy(function() {
      var x = getExactlyOneTensor(inputs);
      var shape = x.shape;
      var channelsFirst = self.dataFormat === "channelsFirst";
      var heightIdx = channelsFirst ? 2 : 1;
      var target = [self.size[0] * shape[heightIdx], self.size[1] * shape[heightIdx + 1]];
      if (!channelsFirst) {
        return x.resizeNearestNeighbor(target);
      }
      // resizeNearestNeighbor expects channels-last images, so transpose in and back out.
      var resized = tfc.transpose(x, [0, 2, 3, 1]).resizeNearestNeighbor(target);
      return tfc.transpose(resized, [0, 3, 1, 2]);
    });
  };
  /** Serialize `size` and `dataFormat`, overlaid with the base Layer config. */
  UpSampling2D2.prototype.getConfig = function() {
    var ownConfig = {size: this.size, dataFormat: this.dataFormat};
    return Object.assign(ownConfig, _super.prototype.getConfig.call(this));
  };
  UpSampling2D2.className = "UpSampling2D";
  return UpSampling2D2;
}(Layer);
// Make UpSampling2D resolvable by class name during model deserialization.
tfc.serialization.registerClass(UpSampling2D);
/**
* @license
* Copyright 2018 Google LLC
*
* Use of this source code is governed by an MIT-style
* license that can be found in the LICENSE file or at
* https://opensource.org/licenses/MIT.
* =============================================================================
*/
/**
 * Depthwise 2-D convolution with Keras-style data-format handling.
 *
 * The convolution itself always runs in "NHWC". For channelsFirst, `x` goes
 * through preprocessConv2DInput (presumably converting it to channels-last)
 * and the result is transposed back to NCHW.
 *
 * @param x Input tensor; must be 4-D.
 * @param depthwiseKernel Kernel tensor; must be 4-D.
 * @param strides Defaults to [1, 1].
 * @param padding Defaults to "valid"; anything other than "same" is treated as "valid".
 * @param dataFormat Defaults to imageDataFormat() when null/undefined.
 * @param dilationRate Forwarded unchanged to tfc.depthwiseConv2d.
 * @throws {ValueError} If `x` or `depthwiseKernel` is not 4-D.
 */
function depthwiseConv2d(x, depthwiseKernel, strides, padding, dataFormat, dilationRate) {
  strides = strides === void 0 ? [1, 1] : strides;
  padding = padding === void 0 ? "valid" : padding;
  return tfc.tidy(function() {
    var format = dataFormat == null ? imageDataFormat() : dataFormat;
    checkDataFormat(format);
    var y = preprocessConv2DInput(x, format);
    if (x.rank !== 4) {
      throw new ValueError("Input for depthwiseConv2d is required to be 4-D, but is instead " + (x.rank + "-D"));
    }
    if (depthwiseKernel.rank !== 4) {
      throw new ValueError("depthwiseKernel is required to be 4-D, but is instead " + (depthwiseKernel.rank + "-D"));
    }
    var corePadding = padding === "same" ? "same" : "valid";
    y = tfc.depthwiseConv2d(y, depthwiseKernel, strides, corePadding, "NHWC", dilationRate);
    return format === "channelsFirst" ? tfc.transpose(y, [0, 3, 1, 2]) : y;
  });
}
/**
 * Depthwise 2-D convolution layer.
 *
 * Each input channel is convolved separately with `depthMultiplier` filters,
 * giving inputChannels * depthMultiplier output channels.
 */
var DepthwiseConv2D = function(_super) {
  __extends(DepthwiseConv2D2, _super);
  /**
   * @param {Object} args Layer config. This constructor reads:
   *   - depthMultiplier (default 1);
   *   - depthwiseInitializer (default: the base class's DEFAULT_KERNEL_INITIALIZER);
   *   - depthwiseConstraint and depthwiseRegularizer.
   *   The remaining fields are handled by BaseConv, called with rank 2.
   */
  function DepthwiseConv2D2(args) {
    var _this = _super.call(this, 2, args) || this;
    _this.depthwiseKernel = null;
    _this.depthMultiplier = args.depthMultiplier == null ? 1 : args.depthMultiplier;
    _this.depthwiseInitializer = getInitializer(args.depthwiseInitializer || _this.DEFAULT_KERNEL_INITIALIZER);
    _this.depthwiseConstraint = getConstraint(args.depthwiseConstraint);
    _this.depthwiseRegularizer = getRegularizer(args.depthwiseRegularizer);
    return _this;
  }

  /**
   * Create the layer weights:
   *   - depthwise kernel: [kernelH, kernelW, inputChannels, depthMultiplier]
   *   - optional bias: length inputChannels * depthMultiplier
   * @throws {ValueError} If the input rank is below 4 or the channel dimension
   *   is undefined or negative.
   */
  DepthwiseConv2D2.prototype.build = function(inputShape) {
    inputShape = getExactlyOneShape(inputShape);
    if (inputShape.length < 4) {
      throw new ValueError("Inputs to DepthwiseConv2D should have rank 4. " + ("Received input shape: " + JSON.stringify(inputShape) + "."));
    }
    var channelAxis = this.dataFormat === "channelsFirst" ? 1 : 3;
    if (inputShape[channelAxis] == null || inputShape[channelAxis] < 0) {
      throw new ValueError("The channel dimension of the inputs to DepthwiseConv2D should " + ("be defined, but is not (" + inputShape[channelAxis] + ")."));
    }
    var inputDim = inputShape[channelAxis];
    var depthwiseKernelShape = [
      this.kernelSize[0],
      this.kernelSize[1],
      inputDim,
      this.depthMultiplier
    ];
    this.depthwiseKernel = this.addWeight("depthwise_kernel", depthwiseKernelShape, null, this.depthwiseInitializer, this.depthwiseRegularizer, true, this.depthwiseConstraint);
    if (this.useBias) {
      this.bias = this.addWeight("bias", [inputDim * this.depthMultiplier], null, this.biasInitializer, this.biasRegularizer, true, this.biasConstraint);
    } else {
      this.bias = null;
    }
    this.built = true;
  };

  /**
   * Forward pass: depthwise convolution, then bias and activation.
   *
   * depthwiseConv2d() returns the result in this.dataFormat's layout, which is
   * why biasAdd receives the same format.
   */
  DepthwiseConv2D2.prototype.call = function(inputs, kwargs) {
    var _this = this;
    return tfc.tidy(function() {
      inputs = getExactlyOneTensor(inputs);
      var outputs = depthwiseConv2d(inputs, _this.depthwiseKernel.read(), _this.strides, _this.padding, _this.dataFormat, null);
      if (_this.useBias) {
        outputs = biasAdd(outputs, _this.bias.read(), _this.dataFormat);
      }
      if (_this.activation != null) {
        outputs = _this.activation.apply(outputs);
      }
      return outputs;
    });
  };

  /**
   * Output shape: spatial dims go through convOutputLength and the channel
   * dim becomes inputChannels * depthMultiplier.
   */
  DepthwiseConv2D2.prototype.computeOutputShape = function(inputShape) {
    inputShape = getExactlyOneShape(inputShape);
    var rows = this.dataFormat === "channelsFirst" ? inputShape[2] : inputShape[1];
    var cols = this.dataFormat === "channelsFirst" ? inputShape[3] : inputShape[2];
    var outFilters = this.dataFormat === "channelsFirst" ? inputShape[1] * this.depthMultiplier : inputShape[3] * this.depthMultiplier;
    var outRows = convOutputLength(rows, this.kernelSize[0], this.padding, this.strides[0]);
    var outCols = convOutputLength(cols, this.kernelSize[1], this.padding, this.strides[1]);
    if (this.dataFormat === "channelsFirst") {
      return [inputShape[0], outFilters, outRows, outCols];
    } else {
      return [inputShape[0], outRows, outCols, outFilters];
    }
  };

  /** Serialize the layer, including the depthwise* settings. */
  DepthwiseConv2D2.prototype.getConfig = function() {
    var config = _super.prototype.getConfig.call(this);
    config["depthMultiplier"] = this.depthMultiplier;
    config["depthwiseInitializer"] = serializeInitializer(this.depthwiseInitializer);
    config["depthwiseRegularizer"] = serializeRegularizer(this.depthwiseRegularizer);
    // Serialize the constraint itself. Writing the regularizer into this field
    // loses the constraint when the model is saved and reloaded.
    config["depthwiseConstraint"] = serializeConstraint(this.depthwiseConstraint);
    return config;
  };
  DepthwiseConv2D2.className = "DepthwiseConv2D";
  return DepthwiseConv2D2;
}(BaseConv);
// Make DepthwiseConv2D resolvable by class name during model deserialization.
tfc.serialization.registerClass(DepthwiseConv2D);
/**
* @license
* Copyright 2018 Google LLC
*
* Use of this source code is governed by an MIT-style
* license that can be found in the LICENSE file or at
* https://opensource.org/licenses/MIT.
* =============================================================================
*/
/**
 * Normalize the (inputs, initialState, constants) triple passed to an RNN.
 *
 * When `inputs` is an array, it is split as follows:
 *   - the last `numConstants` entries become `constants` (if numConstants is set);
 *   - the first remaining entry becomes `inputs`;
 *   - any further entries become `initialState`.
 * In that case initialState and constants must not also be passed separately.
 *
 * Finally initialState and constants are wrapped in an array unless they are
 * already arrays or null/undefined.
 *
 * @returns {{inputs, initialState, constants}}
 * @throws {ValueError} If `inputs` is an array and initialState or constants is also given.
 */
function standardizeArgs(inputs, initialState, constants, numConstants) {
  var asList = function(value) {
    return value == null || Array.isArray(value) ? value : [value];
  };
  if (!Array.isArray(inputs)) {
    return {inputs, initialState: asList(initialState), constants: asList(constants)};
  }
  if (initialState != null || constants != null) {
    throw new ValueError("When inputs is an array, neither initialState or constants should be provided");
  }
  var tensors = inputs;
  if (numConstants != null) {
    var split = tensors.length - numConstants;
    constants = tensors.slice(split, tensors.length);
    tensors = tensors.slice(0, split);
  }
  if (tensors.length > 1) {
    initialState = tensors.slice(1, tensors.length);
  }
  return {inputs: tensors[0], initialState: asList(initialState), constants: asList(constants)};
}
/**
 * Imperative RNN loop: calls `stepFunction` once per time step and threads the
 * states from one step to the next.
 *
 * NOTE(review): the transpose below swaps axes 0 and 1 to bring the time axis
 * first, so `inputs` (and `mask`) are presumably [batch, time, ...]. Confirm
 * against callers.
 *
 * @param stepFunction (input_t, states) => [output_t, newStates].
 * @param inputs Tensor of rank >= 3.
 * @param initialStates Array of state tensors used for the first step.
 * @param goBackwards If true, iterate the time axis in reverse (default false).
 * @param mask Optional mask. It is cast to bool then float32, and gets a
 *   trailing axis when its rank is one below the input rank.
 * @param constants Unsupported; must be null/undefined.
 * @param unroll Only triggers a console warning (default false).
 * @param needPerStepOutputs If true, every step's output is stacked along
 *   axis 1 (default false).
 * @returns [lastOutput, outputs, states]. `outputs` is undefined unless
 *   needPerStepOutputs is true.
 * @throws {ValueError} If the inputs have rank < 3.
 * @throws {NotImplementedError} If constants are provided.
 */
function rnn(stepFunction, inputs, initialStates, goBackwards, mask, constants, unroll, needPerStepOutputs) {
// Defaults for the optional boolean arguments.
if (goBackwards === void 0) {
goBackwards = false;
}
if (unroll === void 0) {
unroll = false;
}
if (needPerStepOutputs === void 0) {
needPerStepOutputs = false;
}
return tfc.tidy(function() {
var ndim = inputs.shape.length;
if (ndim < 3) {
throw new ValueError("Input should be at least 3D, but is " + ndim + "D.");
}
// Swap axes 0 and 1 so the time axis comes first and can be unstacked below.
var axes = [1, 0].concat(range(2, ndim));
inputs = tfc.transpose(inputs, axes);
if (constants != null) {
throw new NotImplementedError("The rnn() functoin of the deeplearn.js backend does not support constants yet.");
}
if (unroll) {
console.warn("Backend rnn(): the unroll = true option is not applicable to the imperative deeplearn.js backend.");
}
if (mask != null) {
// Normalize the mask to 0/1 float32 values.
mask = mask.asType("bool").asType("float32");
// Add a trailing axis when the mask lacks the feature dimension, so each
// per-step mask can broadcast in the multiplications of the masked branch below.
if (mask.rank === ndim - 1) {
mask = tfc.expandDims(mask, -1);
}
mask = tfc.transpose(mask, axes);
}
if (goBackwards) {
// Axis 0 is now the time axis.
inputs = tfc.reverse(inputs, 0);
if (mask != null) {
mask = tfc.reverse(mask, 0);
}
}
var perStepOutputs = [];
var lastOutput;
var states = initialStates;
var timeSteps = inputs.shape[0];
var perStepInputs = tfc.unstack(inputs);
var perStepMasks;
if (mask != null) {
perStepMasks = tfc.unstack(mask);
}
// Body of the time loop, factored into a function so the closures inside
// capture this iteration's `t2`.
var _loop_1 = function(t2) {
var currentInput = perStepInputs[t2];
// Run each step in its own tidy scope so intermediate tensors are released.
var stepOutputs = tfc.tidy(function() {
return stepFunction(currentInput, states);
});
if (mask == null) {
lastOutput = stepOutputs[0];
states = stepOutputs[1];
} else {
// Where the mask is 0:
//   - the step output falls back to the previous first state (states[0]);
//   - every state keeps its previous value.
// Where the mask is 1, the new values are used.
var maskedOutputs = tfc.tidy(function() {
var stepMask = perStepMasks[t2];
var negStepMask = tfc.onesLike(stepMask).sub(stepMask);
var output = stepOutputs[0].mul(stepMask).add(states[0].mul(negStepMask));
var newStates = states.map(function(state, i) {
return stepOutputs[1][i].mul(stepMask).add(state.mul(negStepMask));
});
return {output, newStates};
});
lastOutput = maskedOutputs.output;
states = maskedOutputs.newStates;
}
if (needPerStepOutputs) {
perStepOutputs.push(lastOutput);
}
};
for (var t = 0; t < timeSteps; ++t) {
_loop_1(t);
}
var outputs;
if (needPerStepOutputs) {
// Stack along axis 1, undoing the time-first transpose for the sequence output.
var axis = 1;
outputs = tfc.stack(perStepOutputs, axis);
}
return [lastOutput, outputs, states];
});
}
/**
 * Base recurrent layer: runs `cell` over the time axis of a 3-D input using rnn().
 *
 * Constructor args read here:
 *   - cell (required; an array of cells is wrapped in StackedRNNCells);
 *   - returnSequences, returnState, goBackwards, stateful, unroll
 *     (each defaults to false).
 *
 * @throws {ValueError} If `cell` is missing or has no `stateSize`.
 */
var RNN = function(_super) {
__extends(RNN2, _super);
function RNN2(args) {
var _this = _super.call(this, args) || this;
var cell;
if (args.cell == null) {
throw new ValueError("cell property is missing for the constructor of RNN.");
} else if (Array.isArray(args.cell)) {
cell = new StackedRNNCells({cells: args.cell});
} else {
cell = args.cell;
}
if (cell.stateSize == null) {
throw new ValueError("The RNN cell should have an attribute `stateSize` (tuple of integers, one integer per RNN state).");
}
_this.cell = cell;
_this.returnSequences = args.returnSequences == null ? false : args.returnSequences;
_this.returnState = args.returnState == null ? false : args.returnState;
_this.goBackwards = args.goBackwards == null ? false : args.goBackwards;
_this._stateful = args.stateful == null ? false : args.stateful;
_this.unroll = args.unroll == null ? false : args.unroll;
_this.supportsMasking = true;
_this.inputSpec = [new InputSpec({ndim: 3})];
_this.stateSpec = null;
_this.states_ = null;
_this.numConstants = null;
// Old state tensors that resetStates() keeps instead of disposing when called with training === true.
_this.keptStates = [];
return _this;
}
// Current state tensors. When none are set, one null placeholder per state
// (stateSize may be a number or an array).
RNN2.prototype.getStates = function() {
if (this.states_ == null) {
var numStates = Array.isArray(this.cell.stateSize) ? this.cell.stateSize.length : 1;
return range(0, numStates).map(function(x) {
return null;
});
} else {
return this.states_;
}
};
RNN2.prototype.setStates = function(states) {
this.states_ = states;
};
// Output shape:
//   - [batch, time, outputDim] when returnSequences, else [batch, outputDim];
//   - when returnState, followed by one [batch, dim] shape per state.
// outputDim is the first entry of the cell's stateSize.
RNN2.prototype.computeOutputShape = function(inputShape) {
if (isArrayOfShapes(inputShape)) {
inputShape = inputShape[0];
}
// No-op; presumably a leftover from a TypeScript type cast.
inputShape = inputShape;
var stateSize = this.cell.stateSize;
if (!Array.isArray(stateSize)) {
stateSize = [stateSize];
}
var outputDim = stateSize[0];
var outputShape;
if (this.returnSequences) {
outputShape = [inputShape[0], inputShape[1], outputDim];
} else {
outputShape = [inputShape[0], outputDim];
}
if (this.returnState) {
var stateShape = [];
for (var _i = 0, stateSize_1 = stateSize; _i < stateSize_1.length; _i++) {
var dim = stateSize_1[_i];
stateShape.push([inputShape[0], dim]);
}
return [outputShape].concat(stateShape);
} else {
return outputShape;
}
};
// The input mask passes through only when returnSequences is set; the
// returned states get null masks.
RNN2.prototype.computeMask = function(inputs, mask) {
var _this = this;
return tfc.tidy(function() {
if (Array.isArray(mask)) {
mask = mask[0];
}
var outputMask = _this.returnSequences ? mask : null;
if (_this.returnState) {
var stateMask = _this.states.map(function(s) {
return null;
});
return [outputMask].concat(stateMask);
} else {
return outputMask;
}
});
};
// `states` accessor: mirrors getStates()/setStates() on the states_ field.
Object.defineProperty(RNN2.prototype, "states", {
get: function() {
if (this.states_ == null) {
var numStates = Array.isArray(this.cell.stateSize) ? this.cell.stateSize.length : 1;
var output = [];
for (var i = 0; i < numStates; ++i) {
output.push(null);
}
return output;
} else {
return this.states_;
}
},
set: function(s) {
this.states_ = s;
},
enumerable: true,
configurable: true
});
// Build the layer:
//   - pin the feature dims (and, when stateful, the batch size) in inputSpec[0];
//   - build the cell on the per-step shape [batch, ...features];
//   - validate an existing stateSpec, or derive one from cell.stateSize;
//   - allocate the states when stateful.
RNN2.prototype.build = function(inputShape) {
if (this.numConstants != null) {
throw new NotImplementedError("Constants support is not implemented in RNN yet.");
}
if (isArrayOfShapes(inputShape)) {
inputShape = inputShape[0];
}
// No-op; presumably a leftover from a TypeScript type cast.
inputShape = inputShape;
var batchSize = this.stateful ? inputShape[0] : null;
var inputDim = inputShape.slice(2);
this.inputSpec[0] = new InputSpec({shape: [batchSize, null].concat(inputDim)});
// Per-step input shape: the time axis (index 1) is dropped.
var stepInputShape = [inputShape[0]].concat(inputShape.slice(2));
{
this.cell.build(stepInputShape);
}
var stateSize;
if (Array.isArray(this.cell.stateSize)) {
stateSize = this.cell.stateSize;
} else {
stateSize = [this.cell.stateSize];
}
if (this.stateSpec != null) {
if (!tfc.util.arraysEqual(this.stateSpec.map(function(spec) {
return spec.shape[spec.shape.length - 1];
}), stateSize)) {
throw new ValueError("An initialState was passed that is not compatible with " + ("cell.stateSize. Received stateSpec=" + this.stateSpec + "; ") + ("However cell.stateSize is " + this.cell.stateSize));
}
} else {
this.stateSpec = stateSize.map(function(dim) {
return new InputSpec({shape: [null, dim]});
});
}
if (this.stateful) {
this.resetStates();
}
};
// Reset or replace the stored states (stateful layers only). There are three cases:
//   1. No states yet: allocate zeros of shape [batchSize, dim] per state.
//   2. `states` omitted: dispose the current and kept states, then re-zero.
//   3. `states` given: check the count and the [batchSize, dim] shapes, then
//      install them. With training === true the old states are pushed onto
//      keptStates instead of being disposed. NOTE(review): presumably because
//      gradient computation may still reference them; confirm.
// In every case the resulting states are cloned and tfc.keep()-ed so they
// outlive the surrounding tidy().
// @throws {AttributeError} If the layer is not stateful.
// @throws {ValueError} If the batch size is unknown, or `states` has the wrong count or shapes.
RNN2.prototype.resetStates = function(states, training) {
var _this = this;
if (training === void 0) {
training = false;
}
tfc.tidy(function() {
if (!_this.stateful) {
throw new AttributeError("Cannot call resetStates() on an RNN Layer that is not stateful.");
}
var batchSize = _this.inputSpec[0].shape[0];
if (batchSize == null) {
throw new ValueError("If an RNN is stateful, it needs to know its batch size. Specify the batch size of your input tensors: \n- If using a Sequential model, specify the batch size by passing a `batchInputShape` option to your first layer.\n- If using the functional API, specify the batch size by passing a `batchShape` option to your Input layer.");
}
if (_this.states_ == null) {
if (Array.isArray(_this.cell.stateSize)) {
_this.states_ = _this.cell.stateSize.map(function(dim2) {
return tfc.zeros([batchSize, dim2]);
});
} else {
_this.states_ = [tfc.zeros([batchSize, _this.cell.stateSize])];
}
} else if (states == null) {
tfc.dispose(_this.states_);
if (_this.keptStates != null) {
tfc.dispose(_this.keptStates);
_this.keptStates = [];
}
if (Array.isArray(_this.cell.stateSize)) {
_this.states_ = _this.cell.stateSize.map(function(dim2) {
return tfc.zeros([batchSize, dim2]);
});
} else {
_this.states_[0] = tfc.zeros([batchSize, _this.cell.stateSize]);
}
} else {
if (!Array.isArray(states)) {
states = [states];
}
if (states.length !== _this.states_.length) {
throw new ValueError("Layer " + _this.name + " expects " + _this.states_.length + " state(s), " + ("but it received " + states.length + " state value(s). Input ") + ("received: " + states));
}
if (training === true) {
_this.keptStates.push(_this.states_.slice());
} else {
tfc.dispose(_this.states_);
}
for (var index = 0; index < _this.states_.length; ++index) {
var value = states[index];
var dim = Array.isArray(_this.cell.stateSize) ? _this.cell.stateSize[index] : _this.cell.stateSize;
var expectedShape = [batchSize, dim];
if (!tfc.util.arraysEqual(value.shape, expectedShape)) {
throw new ValueError("State " + index + " is incompatible with layer " + _this.name + ": " + ("expected shape=" + expectedShape + ", received shape=" + value.shape));
}
_this.states_[index] = value;
}
}
_this.states_ = _this.states_.map(function(state) {
return tfc.keep(state.clone());
});
});
};
// Separate initialState/constants out of the inputs (via standardizeArgs) and
// pass them to the base apply() through kwargs.
// When the first extra input is a SymbolicTensor, the extra inputs are also
// appended to the layer inputs, and inputSpec is temporarily extended with
// their specs for the base apply() call, then restored.
RNN2.prototype.apply = function(inputs, kwargs) {
var initialState = kwargs == null ? null : kwargs["initialState"];
var constants = kwargs == null ? null : kwargs["constants"];
if (kwargs == null) {
kwargs = {};
}
var standardized = standardizeArgs(inputs, initialState, constants, this.numConstants);
inputs = standardized.inputs;
initialState = standardized.initialState;
constants = standardized.constants;
var additionalInputs = [];
var additionalSpecs = [];
if (initialState != null) {
kwargs["initialState"] = initialState;
additionalInputs = additionalInputs.concat(initialState);
this.stateSpec = [];
for (var _i = 0, initialState_1 = initialState; _i < initialState_1.length; _i++) {
var state = initialState_1[_i];
this.stateSpec.push(new InputSpec({shape: state.shape}));
}
additionalSpecs = additionalSpecs.concat(this.stateSpec);
}
if (constants != null) {
kwargs["constants"] = constants;
additionalInputs = additionalInputs.concat(constants);
this.numConstants = constants.length;
}
var isTensor = additionalInputs[0] instanceof SymbolicTensor;
if (isTensor) {
var fullInput = [inputs].concat(additionalInputs);
var fullInputSpec = this.inputSpec.concat(additionalSpecs);
var originalInputSpec = this.inputSpec;
this.inputSpec = fullInputSpec;
var output = _super.prototype.apply.call(this, fullInput, kwargs);
this.inputSpec = originalInputSpec;
return output;
} else {
return _super.prototype.apply.call(this, inputs, kwargs);
}
};
// Forward pass. Reads mask, training and initialState from kwargs.
// Initial states when none are given:
//   - stateful layers use the stored states_;
//   - other layers get zeros from getInitialState().
// Returns the full sequence (returnSequences) or the last output, followed by
// the final states when returnState is set. Stateful layers store the final
// states via resetStates(states, training).
// @throws {ValueError} If the number of initial states does not match the cell.
RNN2.prototype.call = function(inputs, kwargs) {
var _this = this;
return tfc.tidy(function() {
var mask = kwargs == null ? null : kwargs["mask"];
var training = kwargs == null ? null : kwargs["training"];
var initialState = kwargs == null ? null : kwargs["initialState"];
inputs = getExactlyOneTensor(inputs);
if (initialState == null) {
if (_this.stateful) {
initialState = _this.states_;
} else {
initialState = _this.getInitialState(inputs);
}
}
var numStates = Array.isArray(_this.cell.stateSize) ? _this.cell.stateSize.length : 1;
if (initialState.length !== numStates) {
throw new ValueError("RNN Layer has " + numStates + " state(s) but was passed " + (initialState.length + " initial state(s)."));
}
if (_this.unroll) {
console.warn("Ignoring unroll = true for RNN layer, due to imperative backend.");
}
var cellCallKwargs = {training};
// The cell receives [input_t, ...states] and returns [output, ...newStates].
var step = function(inputs2, states2) {
var outputs2 = _this.cell.call([inputs2].concat(states2), cellCallKwargs);
return [outputs2[0], outputs2.slice(1)];
};
var rnnOutputs = rnn(step, inputs, initialState, _this.goBackwards, mask, null, _this.unroll, _this.returnSequences);
var lastOutput = rnnOutputs[0];
var outputs = rnnOutputs[1];
var states = rnnOutputs[2];
if (_this.stateful) {
_this.resetStates(states, training);
}
var output = _this.returnSequences ? outputs : lastOutput;
if (_this.returnState) {
return [output].concat(states);
} else {
return output;
}
});
};
// Zero initial states:
//   1. sum a zeros tensor shaped like `inputs` over axes 1 and 2
//      (presumably giving [batch] for 3-D inputs);
//   2. expandDims (presumably to [batch, 1]);
//   3. tile to [batch, dim] for each state whose dim > 1.
RNN2.prototype.getInitialState = function(inputs) {
var _this = this;
return tfc.tidy(function() {
var initialState = tfc.zeros(inputs.shape);
initialState = tfc.sum(initialState, [1, 2]);
initialState = expandDims(initialState);
if (Array.isArray(_this.cell.stateSize)) {
return _this.cell.stateSize.map(function(dim) {
return dim > 1 ? tile(initialState, [1, dim]) : initialState;
});
} else {
return _this.cell.stateSize > 1 ? [tile(initialState, [1, _this.cell.stateSize])] : [initialState];
}
});
};
// Trainable weights come from the cell; there are none when the layer is frozen.
Object.defineProperty(RNN2.prototype, "trainableWeights", {
get: function() {
if (!this.trainable) {
return [];
}
return this.cell.trainableWeights;
},
enumerable: true,
configurable: true
});
// When the layer is frozen, every cell weight counts as non-trainable.
Object.defineProperty(RNN2.prototype, "nonTrainableWeights", {
get: function() {
if (!this.trainable) {
return this.cell.weights;
}
return this.cell.nonTrainableWeights;
},
enumerable: true,
configurable: true
});
// Propagate the fast-weight-init flag to the wrapped cell as well.
RNN2.prototype.setFastWeightInitDuringBuild = function(value) {
_super.prototype.setFastWeightInitDuringBuild.call(this, value);
if (this.cell != null) {
this.cell.setFastWeightInitDuringBuild(value);
}
};
// Serialize the layer. Merge precedence is cellConfig < baseConfig < RNN keys
// (later arguments to __assign win, presumably with Object.assign semantics).
// A nested `cell` entry is written only for a plain RNN; subclasses such as
// SimpleRNN rebuild their cell from the flat config instead.
RNN2.prototype.getConfig = function() {
var baseConfig = _super.prototype.getConfig.call(this);
var config = {
returnSequences: this.returnSequences,
returnState: this.returnState,
goBackwards: this.goBackwards,
stateful: this.stateful,
unroll: this.unroll
};
if (this.numConstants != null) {
config["numConstants"] = this.numConstants;
}
var cellConfig = this.cell.getConfig();
if (this.getClassName() === RNN2.className) {
config["cell"] = {
className: this.cell.getClassName(),
config: cellConfig
};
}
return __assign({}, cellConfig, baseConfig, config);
};
// Rebuild from config, deserializing the nested `cell` entry first.
// NOTE(review): Object.assign(config, ...) mutates the caller's config object
// by replacing its `cell` entry.
RNN2.fromConfig = function(cls, config, customObjects) {
if (customObjects === void 0) {
customObjects = {};
}
var cellConfig = config["cell"];
var cell = deserialize(cellConfig, customObjects);
return new cls(Object.assign(config, {cell}));
};
RNN2.className = "RNN";
return RNN2;
}(Layer);
// Make RNN resolvable by class name during model deserialization.
tfc.serialization.registerClass(RNN);
/**
 * Abstract base class for RNN cells.
 *
 * It adds no behavior of its own: construction is forwarded unchanged to the
 * parent Layer constructor. It is not registered for serialization.
 */
var RNNCell = function(_super) {
  __extends(RNNCell2, _super);
  function RNNCell2() {
    if (_super === null) {
      return this;
    }
    return _super.apply(this, arguments) || this;
  }
  return RNNCell2;
}(Layer);
/**
 * Fully-connected RNN cell:
 *   output = activation(input · kernel + bias + prevOutput · recurrentKernel)
 *
 * Its single state is the previous output, so stateSize equals `units`.
 */
var SimpleRNNCell = function(_super) {
  __extends(SimpleRNNCell2, _super);
  /**
   * @param {Object} args Cell config.
   *   - `units` must be a positive integer.
   *   - activation defaults to "tanh"; useBias defaults to true.
   *   - Initializer defaults: glorotNormal (kernel), orthogonal (recurrent),
   *     zeros (bias).
   *   - dropout and recurrentDropout default to 0.
   */
  function SimpleRNNCell2(args) {
    var cell = _super.call(this, args) || this;
    cell.DEFAULT_ACTIVATION = "tanh";
    cell.DEFAULT_KERNEL_INITIALIZER = "glorotNormal";
    cell.DEFAULT_RECURRENT_INITIALIZER = "orthogonal";
    cell.DEFAULT_BIAS_INITIALIZER = "zeros";
    cell.units = args.units;
    assertPositiveInteger(cell.units, "units");
    cell.activation = getActivation(args.activation != null ? args.activation : cell.DEFAULT_ACTIVATION);
    cell.useBias = args.useBias != null ? args.useBias : true;
    cell.kernelInitializer = getInitializer(args.kernelInitializer || cell.DEFAULT_KERNEL_INITIALIZER);
    cell.recurrentInitializer = getInitializer(args.recurrentInitializer || cell.DEFAULT_RECURRENT_INITIALIZER);
    cell.biasInitializer = getInitializer(args.biasInitializer || cell.DEFAULT_BIAS_INITIALIZER);
    cell.kernelRegularizer = getRegularizer(args.kernelRegularizer);
    cell.recurrentRegularizer = getRegularizer(args.recurrentRegularizer);
    cell.biasRegularizer = getRegularizer(args.biasRegularizer);
    cell.kernelConstraint = getConstraint(args.kernelConstraint);
    cell.recurrentConstraint = getConstraint(args.recurrentConstraint);
    cell.biasConstraint = getConstraint(args.biasConstraint);
    // min/max over [1, max(0, rate)] keeps each dropout rate within [0, 1].
    var clampRate = function(rate) {
      return min([1, max([0, rate == null ? 0 : rate])]);
    };
    cell.dropout = clampRate(args.dropout);
    cell.recurrentDropout = clampRate(args.recurrentDropout);
    cell.stateSize = cell.units;
    cell.dropoutMask = null;
    cell.recurrentDropoutMask = null;
    return cell;
  }
  /**
   * Create the weights:
   *   - kernel: [inputDim, units], where inputDim is the last input dimension;
   *   - recurrent kernel: [units, units];
   *   - optional bias: [units].
   */
  SimpleRNNCell2.prototype.build = function(inputShape) {
    var shape = getExactlyOneShape(inputShape);
    var inputDim = shape[shape.length - 1];
    this.kernel = this.addWeight("kernel", [inputDim, this.units], null, this.kernelInitializer, this.kernelRegularizer, true, this.kernelConstraint);
    this.recurrentKernel = this.addWeight("recurrent_kernel", [this.units, this.units], null, this.recurrentInitializer, this.recurrentRegularizer, true, this.recurrentConstraint);
    this.bias = this.useBias ? this.addWeight("bias", [this.units], null, this.biasInitializer, this.biasRegularizer, true, this.biasConstraint) : null;
    this.built = true;
  };
  /**
   * One time step.
   *
   * @param inputs [input_t, prevOutput].
   * @param kwargs Must be an object; only `training` is read (defaults to false).
   * @returns [output, output]: the output doubles as the new state.
   *
   * Dropout masks are created lazily, only when a rate is strictly between 0
   * and 1, and only if no mask is cached on the cell yet.
   * @throws {ValueError} If `inputs` does not contain exactly two tensors.
   */
  SimpleRNNCell2.prototype.call = function(inputs, kwargs) {
    var cell = this;
    return tfc.tidy(function() {
      if (inputs.length !== 2) {
        throw new ValueError("SimpleRNNCell expects 2 input Tensors, got " + inputs.length + ".");
      }
      var x = inputs[0];
      var prevOutput = inputs[1];
      var training = kwargs["training"] == null ? false : kwargs["training"];
      var isActiveRate = function(rate) {
        return 0 < rate && rate < 1;
      };
      if (isActiveRate(cell.dropout) && cell.dropoutMask == null) {
        cell.dropoutMask = generateDropoutMask({
          ones: function() {
            return tfc.onesLike(x);
          },
          rate: cell.dropout,
          training
        });
      }
      if (isActiveRate(cell.recurrentDropout) && cell.recurrentDropoutMask == null) {
        cell.recurrentDropoutMask = generateDropoutMask({
          ones: function() {
            return tfc.onesLike(prevOutput);
          },
          rate: cell.recurrentDropout,
          training
        });
      }
      var dpMask = cell.dropoutMask;
      var recDpMask = cell.recurrentDropoutMask;
      var projected = dot(dpMask != null ? tfc.mul(x, dpMask) : x, cell.kernel.read());
      if (cell.bias != null) {
        projected = biasAdd(projected, cell.bias.read());
      }
      if (recDpMask != null) {
        prevOutput = tfc.mul(prevOutput, recDpMask);
      }
      var output = tfc.add(projected, dot(prevOutput, cell.recurrentKernel.read()));
      if (cell.activation != null) {
        output = cell.activation.apply(output);
      }
      return [output, output];
    });
  };
  /** Serialize the cell: base Layer config overlaid with the cell-specific fields. */
  SimpleRNNCell2.prototype.getConfig = function() {
    var baseConfig = _super.prototype.getConfig.call(this);
    var ownConfig = {};
    ownConfig.units = this.units;
    ownConfig.activation = serializeActivation(this.activation);
    ownConfig.useBias = this.useBias;
    ownConfig.kernelInitializer = serializeInitializer(this.kernelInitializer);
    ownConfig.recurrentInitializer = serializeInitializer(this.recurrentInitializer);
    ownConfig.biasInitializer = serializeInitializer(this.biasInitializer);
    ownConfig.kernelRegularizer = serializeRegularizer(this.kernelRegularizer);
    ownConfig.recurrentRegularizer = serializeRegularizer(this.recurrentRegularizer);
    ownConfig.biasRegularizer = serializeRegularizer(this.biasRegularizer);
    ownConfig.activityRegularizer = serializeRegularizer(this.activityRegularizer);
    ownConfig.kernelConstraint = serializeConstraint(this.kernelConstraint);
    ownConfig.recurrentConstraint = serializeConstraint(this.recurrentConstraint);
    ownConfig.biasConstraint = serializeConstraint(this.biasConstraint);
    ownConfig.dropout = this.dropout;
    ownConfig.recurrentDropout = this.recurrentDropout;
    return __assign({}, baseConfig, ownConfig);
  };
  SimpleRNNCell2.className = "SimpleRNNCell";
  return SimpleRNNCell2;
}(RNNCell);
// Make SimpleRNNCell resolvable by class name during model deserialization.
tfc.serialization.registerClass(SimpleRNNCell);
/**
 * Fully-connected RNN layer: an RNN whose cell is a SimpleRNNCell built from
 * the same args.
 */
var SimpleRNN = function(_super) {
  __extends(SimpleRNN2, _super);
  /** @param {Object} args Layer and cell config. Note: `args.cell` is overwritten. */
  function SimpleRNN2(args) {
    // The RNN base constructor requires `cell`, so build it from the same args first.
    args.cell = new SimpleRNNCell(args);
    return _super.call(this, args) || this;
  }
  /**
   * Forward pass.
   *
   * Any dropout masks cached by a previous call are disposed first, so the
   * cell generates fresh ones. Only mask, training and initialState are then
   * forwarded to RNN.call.
   */
  SimpleRNN2.prototype.call = function(inputs, kwargs) {
    var layer = this;
    return tfc.tidy(function() {
      var cell = layer.cell;
      ["dropoutMask", "recurrentDropoutMask"].forEach(function(key) {
        if (cell[key] != null) {
          tfc.dispose(cell[key]);
          cell[key] = null;
        }
      });
      var forwarded = kwargs == null ? {mask: null, training: null, initialState: null} : {mask: kwargs["mask"], training: kwargs["training"], initialState: kwargs["initialState"]};
      return _super.prototype.call.call(layer, inputs, forwarded);
    });
  };
  /** The constructor rebuilds the SimpleRNNCell from the flat config, so no nested cell is deserialized. */
  SimpleRNN2.fromConfig = function(cls, config) {
    var instance = new cls(config);
    return instance;
  };
  SimpleRNN2.className = "SimpleRNN";
  return SimpleRNN2;
}(RNN);
// Make SimpleRNN resolvable by class name during model deserialization.
tfc.serialization.registerClass(SimpleRNN);
var GRUCell = function(_super) {
  __extends(GRUCell2, _super);
  /**
   * Cell class for the GRU layer.
   *
   * Only the "reset before" variant is implemented: the reset gate multiplies
   * the previous state before the recurrent matmul, so `resetAfter` is rejected.
   *
   * @param {Object} args `units` (positive integer) plus optional `activation`
   *   (default "tanh"), `recurrentActivation` (default "hardSigmoid"),
   *   `useBias` (default true), initializers (defaults "glorotNormal",
   *   "orthogonal", "zeros"), regularizers, constraints, `dropout` and
   *   `recurrentDropout` (both clamped to [0, 1]) and `implementation`
   *   (stored and serialized only).
   * @throws {ValueError} if `args.resetAfter` is truthy.
   */
  function GRUCell2(args) {
    var _this = _super.call(this, args) || this;
    _this.DEFAULT_ACTIVATION = "tanh";
    _this.DEFAULT_RECURRENT_ACTIVATION = "hardSigmoid";
    _this.DEFAULT_KERNEL_INITIALIZER = "glorotNormal";
    _this.DEFAULT_RECURRENT_INITIALIZER = "orthogonal";
    _this.DEFAULT_BIAS_INITIALIZER = "zeros";
    if (args.resetAfter) {
      throw new ValueError("GRUCell does not support reset_after parameter set to true.");
    }
    _this.units = args.units;
    assertPositiveInteger(_this.units, "units");
    _this.activation = getActivation(args.activation === void 0 ? _this.DEFAULT_ACTIVATION : args.activation);
    _this.recurrentActivation = getActivation(args.recurrentActivation === void 0 ? _this.DEFAULT_RECURRENT_ACTIVATION : args.recurrentActivation);
    _this.useBias = args.useBias == null ? true : args.useBias;
    _this.kernelInitializer = getInitializer(args.kernelInitializer || _this.DEFAULT_KERNEL_INITIALIZER);
    _this.recurrentInitializer = getInitializer(args.recurrentInitializer || _this.DEFAULT_RECURRENT_INITIALIZER);
    _this.biasInitializer = getInitializer(args.biasInitializer || _this.DEFAULT_BIAS_INITIALIZER);
    _this.kernelRegularizer = getRegularizer(args.kernelRegularizer);
    _this.recurrentRegularizer = getRegularizer(args.recurrentRegularizer);
    _this.biasRegularizer = getRegularizer(args.biasRegularizer);
    _this.kernelConstraint = getConstraint(args.kernelConstraint);
    _this.recurrentConstraint = getConstraint(args.recurrentConstraint);
    _this.biasConstraint = getConstraint(args.biasConstraint);
    _this.dropout = min([1, max([0, args.dropout == null ? 0 : args.dropout])]);
    _this.recurrentDropout = min([
      1,
      max([0, args.recurrentDropout == null ? 0 : args.recurrentDropout])
    ]);
    _this.implementation = args.implementation;
    _this.stateSize = _this.units;
    // Masks are created lazily on the first call and cleared by the GRU layer.
    _this.dropoutMask = null;
    _this.recurrentDropoutMask = null;
    return _this;
  }
  /**
   * Creates `kernel` [inputDim, 3 * units], `recurrent_kernel`
   * [units, 3 * units] and, when `useBias`, `bias` [3 * units]. Along the last
   * axis the three slices hold the z, r and candidate-h gates, in that order.
   */
  GRUCell2.prototype.build = function(inputShape) {
    inputShape = getExactlyOneShape(inputShape);
    var inputDim = inputShape[inputShape.length - 1];
    this.kernel = this.addWeight("kernel", [inputDim, this.units * 3], null, this.kernelInitializer, this.kernelRegularizer, true, this.kernelConstraint);
    this.recurrentKernel = this.addWeight("recurrent_kernel", [this.units, this.units * 3], null, this.recurrentInitializer, this.recurrentRegularizer, true, this.recurrentConstraint);
    if (this.useBias) {
      this.bias = this.addWeight("bias", [this.units * 3], null, this.biasInitializer, this.biasRegularizer, true, this.biasConstraint);
    } else {
      this.bias = null;
    }
    this.built = true;
  };
  /**
   * Runs one GRU step.
   *
   * @param {Tensor[]} inputs `[x, hTMinus1]`: step input and previous state.
   * @param {Object} [kwargs] `kwargs.training` (default false) selects whether
   *   dropout masks are random or all-ones.
   * @returns {Tensor[]} `[h, h]`: the new state, as both output and state.
   * @throws {ValueError} if `inputs` does not contain exactly two tensors.
   */
  GRUCell2.prototype.call = function(inputs, kwargs) {
    var _this = this;
    return tfc.tidy(function() {
      if (inputs.length !== 2) {
        throw new ValueError("GRUCell expects 2 input Tensors (inputs, h), got " + (inputs.length + "."));
      }
      // Tolerate a missing kwargs object instead of crashing on property access.
      var training = kwargs == null || kwargs["training"] == null ? false : kwargs["training"];
      var hTMinus1 = inputs[1];
      var x = inputs[0];
      if (0 < _this.dropout && _this.dropout < 1 && _this.dropoutMask == null) {
        _this.dropoutMask = generateDropoutMask({
          ones: function() {
            return tfc.onesLike(x);
          },
          rate: _this.dropout,
          training,
          count: 3
        });
      }
      if (0 < _this.recurrentDropout && _this.recurrentDropout < 1 && _this.recurrentDropoutMask == null) {
        _this.recurrentDropoutMask = generateDropoutMask({
          ones: function() {
            return tfc.onesLike(hTMinus1);
          },
          rate: _this.recurrentDropout,
          training,
          count: 3
        });
      }
      var dpMask = _this.dropoutMask;
      var recDpMask = _this.recurrentDropoutMask;
      // Only the first of the three generated masks is used, for all gates.
      if (0 < _this.dropout && _this.dropout < 1) {
        x = tfc.mul(x, dpMask[0]);
      }
      var matrixX = dot(x, _this.kernel.read());
      if (_this.useBias) {
        matrixX = biasAdd(matrixX, _this.bias.read());
      }
      if (0 < _this.recurrentDropout && _this.recurrentDropout < 1) {
        hTMinus1 = tfc.mul(hTMinus1, recDpMask[0]);
      }
      // Split the recurrent kernel into the z/r part (2 * units) and the
      // candidate part (units), so the reset gate can be applied first.
      var recurrentKernelValue = _this.recurrentKernel.read();
      var rkParts = tfc.split(recurrentKernelValue, [2 * _this.units, _this.units], recurrentKernelValue.rank - 1);
      var rk1 = rkParts[0];
      var rk2 = rkParts[1];
      var matrixInner = dot(hTMinus1, rk1);
      var xParts = tfc.split(matrixX, 3, matrixX.rank - 1);
      var xZ = xParts[0], xR = xParts[1], xH = xParts[2];
      var innerParts = tfc.split(matrixInner, 2, matrixInner.rank - 1);
      var recurrentZ = innerParts[0], recurrentR = innerParts[1];
      var z = _this.recurrentActivation.apply(tfc.add(xZ, recurrentZ));
      var r = _this.recurrentActivation.apply(tfc.add(xR, recurrentR));
      var recurrentH = dot(tfc.mul(r, hTMinus1), rk2);
      var hh = _this.activation.apply(tfc.add(xH, recurrentH));
      // h = z * hPrev + (1 - z) * candidate
      var h = tfc.add(tfc.mul(z, hTMinus1), tfc.mul(tfc.add(1, tfc.neg(z)), hh));
      return [h, h];
    });
  };
  /**
   * Serializable config: base RNNCell config overlaid with this cell's
   * hyperparameters; `resetAfter` is always reported as false.
   */
  GRUCell2.prototype.getConfig = function() {
    var baseConfig = _super.prototype.getConfig.call(this);
    var config = {
      units: this.units,
      activation: serializeActivation(this.activation),
      recurrentActivation: serializeActivation(this.recurrentActivation),
      useBias: this.useBias,
      kernelInitializer: serializeInitializer(this.kernelInitializer),
      recurrentInitializer: serializeInitializer(this.recurrentInitializer),
      biasInitializer: serializeInitializer(this.biasInitializer),
      kernelRegularizer: serializeRegularizer(this.kernelRegularizer),
      recurrentRegularizer: serializeRegularizer(this.recurrentRegularizer),
      biasRegularizer: serializeRegularizer(this.biasRegularizer),
      activityRegularizer: serializeRegularizer(this.activityRegularizer),
      kernelConstraint: serializeConstraint(this.kernelConstraint),
      recurrentConstraint: serializeConstraint(this.recurrentConstraint),
      biasConstraint: serializeConstraint(this.biasConstraint),
      dropout: this.dropout,
      recurrentDropout: this.recurrentDropout,
      implementation: this.implementation,
      resetAfter: false
    };
    return __assign({}, baseConfig, config);
  };
  GRUCell2.className = "GRUCell";
  return GRUCell2;
}(RNNCell);
tfc.serialization.registerClass(GRUCell);
var GRU = function(_super) {
  __extends(GRU2, _super);
  /**
   * Gated Recurrent Unit layer driven by a GRUCell built from `args`.
   * Warns when the deprecated `implementation: 0` is passed.
   * Note: the new cell is written onto `args.cell` (the caller's object).
   * @param {Object} args Layer + cell arguments.
   */
  function GRU2(args) {
    var _this = this;
    if (args.implementation === 0) {
      console.warn("`implementation=0` has been deprecated, and now defaults to `implementation=1`. Please update your layer call.");
    }
    args.cell = new GRUCell(args);
    _this = _super.call(this, args) || this;
    return _this;
  }
  /**
   * Runs the layer over a sequence: disposes dropout masks cached on the cell
   * by a previous call, then forwards only `mask`, `training` and
   * `initialState` (null when `kwargs` is null) to RNN.call.
   */
  GRU2.prototype.call = function(inputs, kwargs) {
    var _this = this;
    return tfc.tidy(function() {
      if (_this.cell.dropoutMask != null) {
        tfc.dispose(_this.cell.dropoutMask);
        _this.cell.dropoutMask = null;
      }
      if (_this.cell.recurrentDropoutMask != null) {
        tfc.dispose(_this.cell.recurrentDropoutMask);
        _this.cell.recurrentDropoutMask = null;
      }
      var mask = kwargs == null ? null : kwargs["mask"];
      var training = kwargs == null ? null : kwargs["training"];
      var initialState = kwargs == null ? null : kwargs["initialState"];
      return _super.prototype.call.call(_this, inputs, {mask, training, initialState});
    });
  };
  /**
   * Rebuilds a layer from its serialized config, upgrading the deprecated
   * `implementation: 0` to 1. The check previously looked only at the
   * misspelled key "implmentation", so real configs were never upgraded; both
   * spellings are honoured now for backward compatibility.
   */
  GRU2.fromConfig = function(cls, config) {
    if (config["implementation"] === 0 || config["implmentation"] === 0) {
      config["implementation"] = 1;
    }
    return new cls(config);
  };
  GRU2.className = "GRU";
  return GRU2;
}(RNN);
tfc.serialization.registerClass(GRU);
var LSTMCell = function(_super) {
  __extends(LSTMCell2, _super);
  /**
   * Cell class for the LSTM layer.
   *
   * Reads `units` (validated as a positive integer), `activation` (default
   * "tanh"), `recurrentActivation` (default "hardSigmoid"), `useBias`
   * (default true), initializers (defaults "glorotNormal", "orthogonal",
   * "zeros"), `unitForgetBias`, regularizers, constraints, `dropout` /
   * `recurrentDropout` (clamped to [0, 1]) and `implementation` from `args`.
   * The state is the pair [h, c], each of size `units`.
   */
  function LSTMCell2(args) {
    var self = _super.call(this, args) || this;
    self.DEFAULT_ACTIVATION = "tanh";
    self.DEFAULT_RECURRENT_ACTIVATION = "hardSigmoid";
    self.DEFAULT_KERNEL_INITIALIZER = "glorotNormal";
    self.DEFAULT_RECURRENT_INITIALIZER = "orthogonal";
    self.DEFAULT_BIAS_INITIALIZER = "zeros";
    self.units = args.units;
    assertPositiveInteger(self.units, "units");
    self.activation = getActivation(args.activation !== void 0 ? args.activation : self.DEFAULT_ACTIVATION);
    self.recurrentActivation = getActivation(args.recurrentActivation !== void 0 ? args.recurrentActivation : self.DEFAULT_RECURRENT_ACTIVATION);
    self.useBias = args.useBias != null ? args.useBias : true;
    self.kernelInitializer = getInitializer(args.kernelInitializer || self.DEFAULT_KERNEL_INITIALIZER);
    self.recurrentInitializer = getInitializer(args.recurrentInitializer || self.DEFAULT_RECURRENT_INITIALIZER);
    self.biasInitializer = getInitializer(args.biasInitializer || self.DEFAULT_BIAS_INITIALIZER);
    self.unitForgetBias = args.unitForgetBias;
    self.kernelRegularizer = getRegularizer(args.kernelRegularizer);
    self.recurrentRegularizer = getRegularizer(args.recurrentRegularizer);
    self.biasRegularizer = getRegularizer(args.biasRegularizer);
    self.kernelConstraint = getConstraint(args.kernelConstraint);
    self.recurrentConstraint = getConstraint(args.recurrentConstraint);
    self.biasConstraint = getConstraint(args.biasConstraint);
    var rawDropout = args.dropout == null ? 0 : args.dropout;
    self.dropout = min([1, max([0, rawDropout])]);
    var rawRecurrentDropout = args.recurrentDropout == null ? 0 : args.recurrentDropout;
    self.recurrentDropout = min([1, max([0, rawRecurrentDropout])]);
    self.implementation = args.implementation;
    self.stateSize = [self.units, self.units];
    // Dropout masks are generated lazily inside call().
    self.dropoutMask = null;
    self.recurrentDropoutMask = null;
    return self;
  }
  // Builds a one-off initializer whose output is
  // concat(baseInit([units]), ones([units]), baseInit([2 * units])),
  // i.e. the second (forget-gate) quarter of the bias starts at 1.
  function makeUnitForgetBiasInitializer(baseInit, units) {
    var UnitForgetBiasInit = function(_super2) {
      __extends(CustomInit, _super2);
      function CustomInit() {
        return _super2 !== null && _super2.apply(this, arguments) || this;
      }
      CustomInit.prototype.apply = function(shape, dtype) {
        var bI = baseInit.apply([units]);
        var bF = new Ones().apply([units]);
        var bCAndH = baseInit.apply([units * 2]);
        return concatAlongFirstAxis(concatAlongFirstAxis(bI, bF), bCAndH);
      };
      CustomInit.className = "CustomInit";
      return CustomInit;
    }(Initializer);
    return new UnitForgetBiasInit();
  }
  /**
   * Creates `kernel` [inputDim, 4 * units], `recurrent_kernel`
   * [units, 4 * units] and, when `useBias`, `bias` [4 * units]
   * (using the unit-forget-bias initializer when `unitForgetBias` is set).
   */
  LSTMCell2.prototype.build = function(inputShape) {
    var shape = getExactlyOneShape(inputShape);
    var inputDim = shape[shape.length - 1];
    this.kernel = this.addWeight("kernel", [inputDim, this.units * 4], null, this.kernelInitializer, this.kernelRegularizer, true, this.kernelConstraint);
    this.recurrentKernel = this.addWeight("recurrent_kernel", [this.units, this.units * 4], null, this.recurrentInitializer, this.recurrentRegularizer, true, this.recurrentConstraint);
    if (!this.useBias) {
      this.bias = null;
    } else {
      var biasInit = this.unitForgetBias ? makeUnitForgetBiasInitializer(this.biasInitializer, this.units) : this.biasInitializer;
      this.bias = this.addWeight("bias", [this.units * 4], null, biasInit, this.biasRegularizer, true, this.biasConstraint);
    }
    this.built = true;
  };
  /**
   * Runs one LSTM step.
   *
   * @param {Tensor[]} inputs `[x, hTMinus1, cTMinus1]`.
   * @param {Object} kwargs `kwargs.training` (default false) picks random vs.
   *   all-ones dropout masks.
   * @returns {Tensor[]} `[h, h, c]`: output, new hidden state, new cell state.
   * @throws {ValueError} if `inputs` does not contain exactly three tensors.
   */
  LSTMCell2.prototype.call = function(inputs, kwargs) {
    var cell = this;
    return tfc.tidy(function() {
      var training = kwargs["training"] != null ? kwargs["training"] : false;
      if (inputs.length !== 3) {
        throw new ValueError("LSTMCell expects 3 input Tensors (inputs, h, c), got " + (inputs.length + "."));
      }
      var hPrev = inputs[1];
      var cPrev = inputs[2];
      var x = inputs[0];
      var useDropout = 0 < cell.dropout && cell.dropout < 1;
      var useRecurrentDropout = 0 < cell.recurrentDropout && cell.recurrentDropout < 1;
      if (useDropout && cell.dropoutMask == null) {
        cell.dropoutMask = generateDropoutMask({
          ones: function() {
            return tfc.onesLike(x);
          },
          rate: cell.dropout,
          training,
          count: 4
        });
      }
      if (useRecurrentDropout && cell.recurrentDropoutMask == null) {
        cell.recurrentDropoutMask = generateDropoutMask({
          ones: function() {
            return tfc.onesLike(hPrev);
          },
          rate: cell.recurrentDropout,
          training,
          count: 4
        });
      }
      // Only the first of the four generated masks is applied.
      if (useDropout) {
        x = tfc.mul(x, cell.dropoutMask[0]);
      }
      var z = dot(x, cell.kernel.read());
      if (useRecurrentDropout) {
        hPrev = tfc.mul(hPrev, cell.recurrentDropoutMask[0]);
      }
      z = tfc.add(z, dot(hPrev, cell.recurrentKernel.read()));
      if (cell.useBias) {
        z = biasAdd(z, cell.bias.read());
      }
      // Gate pre-activations along the last axis: input, forget, cell, output.
      var gates = tfc.split(z, 4, z.rank - 1);
      var inputGate = cell.recurrentActivation.apply(gates[0]);
      var forgetGate = cell.recurrentActivation.apply(gates[1]);
      var newC = tfc.add(tfc.mul(forgetGate, cPrev), tfc.mul(inputGate, cell.activation.apply(gates[2])));
      var outputGate = cell.recurrentActivation.apply(gates[3]);
      var newH = tfc.mul(outputGate, cell.activation.apply(newC));
      return [newH, newH, newC];
    });
  };
  /**
   * Serializable config: base RNNCell config overlaid with this cell's
   * hyperparameters (cell keys win on collision).
   */
  LSTMCell2.prototype.getConfig = function() {
    var inherited = _super.prototype.getConfig.call(this);
    var own = {};
    own.units = this.units;
    own.activation = serializeActivation(this.activation);
    own.recurrentActivation = serializeActivation(this.recurrentActivation);
    own.useBias = this.useBias;
    own.kernelInitializer = serializeInitializer(this.kernelInitializer);
    own.recurrentInitializer = serializeInitializer(this.recurrentInitializer);
    own.biasInitializer = serializeInitializer(this.biasInitializer);
    own.unitForgetBias = this.unitForgetBias;
    own.kernelRegularizer = serializeRegularizer(this.kernelRegularizer);
    own.recurrentRegularizer = serializeRegularizer(this.recurrentRegularizer);
    own.biasRegularizer = serializeRegularizer(this.biasRegularizer);
    own.activityRegularizer = serializeRegularizer(this.activityRegularizer);
    own.kernelConstraint = serializeConstraint(this.kernelConstraint);
    own.recurrentConstraint = serializeConstraint(this.recurrentConstraint);
    own.biasConstraint = serializeConstraint(this.biasConstraint);
    own.dropout = this.dropout;
    own.recurrentDropout = this.recurrentDropout;
    own.implementation = this.implementation;
    return __assign({}, inherited, own);
  };
  LSTMCell2.className = "LSTMCell";
  return LSTMCell2;
}(RNNCell);
tfc.serialization.registerClass(LSTMCell);
var LSTM = function(_super) {
  __extends(LSTM2, _super);
  /**
   * Long Short-Term Memory layer driven by an LSTMCell built from `args`.
   * Warns when the deprecated `implementation: 0` is passed.
   * Note: the new cell is written onto `args.cell` (the caller's object).
   * @param {Object} args Layer + cell arguments.
   */
  function LSTM2(args) {
    var _this = this;
    if (args.implementation === 0) {
      console.warn("`implementation=0` has been deprecated, and now defaults to `implementation=1`. Please update your layer call.");
    }
    args.cell = new LSTMCell(args);
    _this = _super.call(this, args) || this;
    return _this;
  }
  /**
   * Runs the layer over a sequence: disposes dropout masks cached on the cell
   * by a previous call, then forwards only `mask`, `training` and
   * `initialState` (null when `kwargs` is null) to RNN.call.
   */
  LSTM2.prototype.call = function(inputs, kwargs) {
    var _this = this;
    return tfc.tidy(function() {
      if (_this.cell.dropoutMask != null) {
        tfc.dispose(_this.cell.dropoutMask);
        _this.cell.dropoutMask = null;
      }
      if (_this.cell.recurrentDropoutMask != null) {
        tfc.dispose(_this.cell.recurrentDropoutMask);
        _this.cell.recurrentDropoutMask = null;
      }
      var mask = kwargs == null ? null : kwargs["mask"];
      var training = kwargs == null ? null : kwargs["training"];
      var initialState = kwargs == null ? null : kwargs["initialState"];
      return _super.prototype.call.call(_this, inputs, {mask, training, initialState});
    });
  };
  /**
   * Rebuilds a layer from its serialized config, upgrading the deprecated
   * `implementation: 0` to 1. The check previously looked only at the
   * misspelled key "implmentation", so real configs were never upgraded; both
   * spellings are honoured now for backward compatibility.
   */
  LSTM2.fromConfig = function(cls, config) {
    if (config["implementation"] === 0 || config["implmentation"] === 0) {
      config["implementation"] = 1;
    }
    return new cls(config);
  };
  LSTM2.className = "LSTM";
  return LSTM2;
}(RNN);
tfc.serialization.registerClass(LSTM);
/**
 * Wrapper that runs a list of RNN cells as a single cell, feeding each cell's
 * output into the next one.
 *
 * State layout: the flat state list is ordered by cell in REVERSE (the last
 * cell's states come first). The `stateSize` getter reports this order, and
 * `call` consumes and emits states in it.
 */
var StackedRNNCells = function(_super) {
  __extends(StackedRNNCells2, _super);
  /**
   * @param {Object} args `args.cells`: the cells to stack, in execution order.
   */
  function StackedRNNCells2(args) {
    var _this = _super.call(this, args) || this;
    _this.cells = args.cells;
    return _this;
  }
  // Flattened state sizes of all cells, last cell first.
  Object.defineProperty(StackedRNNCells2.prototype, "stateSize", {
    get: function() {
      var stateSize = [];
      for (var _i = 0, _a = this.cells.slice().reverse(); _i < _a.length; _i++) {
        var cell = _a[_i];
        if (Array.isArray(cell.stateSize)) {
          stateSize.push.apply(stateSize, cell.stateSize);
        } else {
          stateSize.push(cell.stateSize);
        }
      }
      return stateSize;
    },
    enumerable: true,
    configurable: true
  });
  /**
   * Runs one step through every cell.
   * @param {Tensor[]} inputs `[x, ...states]`, states in the reversed layout.
   * @param {Object} kwargs Forwarded unchanged to every cell's `call`.
   * @returns {Tensor[]} `[output, ...newStates]`, in the same reversed layout.
   */
  StackedRNNCells2.prototype.call = function(inputs, kwargs) {
    var _this = this;
    return tfc.tidy(function() {
      var states = inputs.slice(1);
      // Regroup the flat state list per cell (reversed order), then flip the
      // groups back into execution order.
      var nestedStates = [];
      for (var _i = 0, _a = _this.cells.slice().reverse(); _i < _a.length; _i++) {
        var cell = _a[_i];
        if (Array.isArray(cell.stateSize)) {
          nestedStates.push(states.splice(0, cell.stateSize.length));
        } else {
          nestedStates.push(states.splice(0, 1));
        }
      }
      nestedStates.reverse();
      var newNestedStates = [];
      var callInputs;
      for (var i = 0; i < _this.cells.length; ++i) {
        var cell = _this.cells[i];
        states = nestedStates[i];
        // The first cell sees the layer input; later cells see the previous output.
        if (i === 0) {
          callInputs = [inputs[0]].concat(states);
        } else {
          callInputs = [callInputs[0]].concat(states);
        }
        callInputs = cell.call(callInputs, kwargs);
        newNestedStates.push(callInputs.slice(1));
      }
      states = [];
      for (var _b = 0, _c = newNestedStates.slice().reverse(); _b < _c.length; _b++) {
        var cellStates = _c[_b];
        states.push.apply(states, cellStates);
      }
      return [callInputs[0]].concat(states);
    });
  };
  /**
   * Builds each cell under the name scope "RNNCell_<i>". The input shape of
   * cell i + 1 is [batch, outputDim of cell i], where outputDim is that cell's
   * stateSize (or its first element when stateSize is an array).
   */
  StackedRNNCells2.prototype.build = function(inputShape) {
    if (isArrayOfShapes(inputShape)) {
      inputShape = inputShape[0];
    }
    var outputDim;
    this.cells.forEach(function(cell, i) {
      nameScope("RNNCell_" + i, function() {
        cell.build(inputShape);
        if (Array.isArray(cell.stateSize)) {
          outputDim = cell.stateSize[0];
        } else {
          outputDim = cell.stateSize;
        }
        inputShape = [inputShape[0], outputDim];
      });
    });
    this.built = true;
  };
  /** Serializable config: base config plus `cells` as {className, config} entries. */
  StackedRNNCells2.prototype.getConfig = function() {
    var baseConfig = _super.prototype.getConfig.call(this);
    var getCellConfig = function(cell) {
      return {
        className: cell.getClassName(),
        config: cell.getConfig()
      };
    };
    var cellConfigs = this.cells.map(getCellConfig);
    var config = {cells: cellConfigs};
    return __assign({}, baseConfig, config);
  };
  /** Rebuilds the wrapper by deserializing each entry of `config.cells`. */
  StackedRNNCells2.fromConfig = function(cls, config, customObjects) {
    if (customObjects === void 0) {
      customObjects = {};
    }
    var cells = [];
    for (var _i = 0, _a = config["cells"]; _i < _a.length; _i++) {
      var cellConfig = _a[_i];
      cells.push(deserialize(cellConfig, customObjects));
    }
    return new cls({cells});
  };
  // Concatenation of the cells' trainable weights; empty when not trainable.
  Object.defineProperty(StackedRNNCells2.prototype, "trainableWeights", {
    get: function() {
      if (!this.trainable) {
        return [];
      }
      var weights = [];
      for (var _i = 0, _a = this.cells; _i < _a.length; _i++) {
        var cell = _a[_i];
        weights.push.apply(weights, cell.trainableWeights);
      }
      return weights;
    },
    enumerable: true,
    configurable: true
  });
  // The cells' non-trainable weights; when the wrapper is not trainable, the
  // cells' trainable weights are prepended as well.
  Object.defineProperty(StackedRNNCells2.prototype, "nonTrainableWeights", {
    get: function() {
      var weights = [];
      for (var _i = 0, _a = this.cells; _i < _a.length; _i++) {
        var cell = _a[_i];
        weights.push.apply(weights, cell.nonTrainableWeights);
      }
      if (!this.trainable) {
        var trainableWeights = [];
        for (var _b = 0, _c = this.cells; _b < _c.length; _b++) {
          var cell = _c[_b];
          trainableWeights.push.apply(trainableWeights, cell.trainableWeights);
        }
        return trainableWeights.concat(weights);
      }
      return weights;
    },
    enumerable: true,
    configurable: true
  });
  /** Current values of all cells' weights, cell by cell in execution order. */
  StackedRNNCells2.prototype.getWeights = function() {
    var weights = [];
    for (var _i = 0, _a = this.cells; _i < _a.length; _i++) {
      var cell = _a[_i];
      weights.push.apply(weights, cell.weights);
    }
    return batchGetValue(weights);
  };
  /**
   * Assigns a flat weight list, in the same order as `getWeights()`: the first
   * `cells[0].weights.length` tensors go to cell 0, the next block to cell 1,
   * and so on.
   *
   * The previous implementation used `weights.splice(numParams)`, which removes
   * the TAIL starting at `numParams`, so cells were given the wrong tensors.
   * It also drained the caller's array. Indexing by a running offset fixes both.
   *
   * @param {Tensor[]} weights Flat list of new weight values (not modified).
   */
  StackedRNNCells2.prototype.setWeights = function(weights) {
    var tuples = [];
    var offset = 0;
    for (var _i = 0, _a = this.cells; _i < _a.length; _i++) {
      var cellWeights = _a[_i].weights;
      for (var i = 0; i < cellWeights.length; ++i) {
        tuples.push([cellWeights[i], weights[offset + i]]);
      }
      offset += cellWeights.length;
    }
    batchSetValue(tuples);
  };
  StackedRNNCells2.className = "StackedRNNCells";
  return StackedRNNCells2;
}(RNNCell);
tfc.serialization.registerClass(StackedRNNCells);
/**
 * Creates one or more dropout masks shaped like the tensor from `args.ones()`.
 *
 * Each mask is chosen by `inTrainPhase`: `dropout(ones(), rate)` in the
 * training phase, the plain ones tensor otherwise. Every returned tensor is a
 * `tfc.keep`-ed clone, so it survives an enclosing `tfc.tidy`.
 *
 * @param {Object} args
 * @param {Function} args.ones Factory returning a fresh ones tensor.
 * @param {number} args.rate Dropout rate passed to `dropout`.
 * @param {boolean} [args.training=false] Training flag passed to `inTrainPhase`.
 * @param {number} [args.count=1] Number of masks. A falsy value or one <= 1
 *   yields a single tensor rather than an array.
 * @returns {Tensor|Tensor[]} A single kept mask, or an array of `count` masks.
 */
function generateDropoutMask(args) {
  var makeOnes = args.ones;
  var rate = args.rate;
  var training = args.training === void 0 ? false : args.training;
  var count2 = args.count === void 0 ? 1 : args.count;
  function buildMask() {
    return inTrainPhase(function() {
      return dropout(makeOnes(), rate);
    }, makeOnes, training);
  }
  function keepClone(mask) {
    return tfc.keep(mask.clone());
  }
  if (!count2 || count2 <= 1) {
    return keepClone(buildMask());
  }
  var masks = new Array(count2);
  for (var k = 0; k < masks.length; k++) {
    masks[k] = buildMask();
  }
  return masks.map(keepClone);
}
/**
* @license
* Copyright 2020 Google LLC
*
* Use of this source code is governed by an MIT-style
* license that can be found in the LICENSE file or at
* https://opensource.org/licenses/MIT.
* =============================================================================
*/
/**
 * Base class for convolutional RNN cells. It adds no behaviour of its own
 * beyond RNNCell; the constructor simply forwards all arguments.
 */
var ConvRNN2DCell = function(_super) {
  __extends(ConvRNN2DCell2, _super);
  function ConvRNN2DCell2() {
    var instance = _super !== null && _super.apply(this, arguments);
    return instance || this;
  }
  return ConvRNN2DCell2;
}(RNNCell);
/**
 * Abstract base layer for convolutional recurrent layers (e.g. ConvLSTM2D).
 *
 * Inputs must be 5-D (enforced through `inputSpec` with `ndim: 5`). Unrolling
 * and stacked (array) cells are rejected in the constructor.
 */
var ConvRNN2D = function(_super) {
__extends(ConvRNN2D2, _super);
// @throws {NotImplementedError} if `args.unroll` is truthy or `args.cell` is an array.
function ConvRNN2D2(args) {
var _this = this;
if (args.unroll) {
throw new NotImplementedError("Unrolling is not possible with convolutional RNNs.");
}
if (Array.isArray(args.cell)) {
throw new NotImplementedError("It is not possible at the moment to stack convolutional cells.");
}
_this = _super.call(this, args) || this;
_this.inputSpec = [new InputSpec({ndim: 5})];
return _this;
}
// Runs the layer over a sequence. It disposes and clears any dropout masks
// cached on the cell by a previous call (ConvLSTM2DCell regenerates them when
// they are null). It rejects `constants` with a ValueError, then forwards only
// `mask`, `training` and `initialState` (null when kwargs is null) to RNN.call.
ConvRNN2D2.prototype.call = function(inputs, kwargs) {
var _this = this;
return tfc.tidy(function() {
if (_this.cell.dropoutMask != null) {
tfc.dispose(_this.cell.dropoutMask);
_this.cell.dropoutMask = null;
}
if (_this.cell.recurrentDropoutMask != null) {
tfc.dispose(_this.cell.recurrentDropoutMask);
_this.cell.recurrentDropoutMask = null;
}
if (kwargs && kwargs["constants"]) {
throw new ValueError("ConvRNN2D cell does not support constants");
}
var mask = kwargs == null ? null : kwargs["mask"];
var training = kwargs == null ? null : kwargs["training"];
var initialState = kwargs == null ? null : kwargs["initialState"];
return _super.prototype.call.call(_this, inputs, {mask, training, initialState});
});
};
// Output shape: the per-step shape from computeSingleOutputShape, with axis 1
// (time) removed unless `returnSequences`. With `returnState` the result is
// [outShape, stateShape, stateShape], where stateShape =
// [inputShape[0], ...last three dims of outShape]. Exactly two state shapes
// are always appended (matching the [h, c] pair of ConvLSTM2DCell), and both
// entries reference the same array object.
ConvRNN2D2.prototype.computeOutputShape = function(inputShape) {
var outShape = this.computeSingleOutputShape(inputShape);
if (!this.returnSequences) {
outShape = [outShape[0]].concat(outShape.slice(2));
}
if (this.returnState) {
outShape = [outShape].concat(Array(2).fill([inputShape[0]].concat(outShape.slice(-3))));
}
return outShape;
};
// Zero initial states shaped like one output step (the time axis is dropped
// from the single-output shape). When the cell's stateSize is an array, the
// returned array holds that many references to the SAME zeros tensor.
ConvRNN2D2.prototype.getInitialState = function(inputs) {
var _this = this;
return tfc.tidy(function() {
var stateSize = _this.cell.stateSize;
var inputShape = inputs.shape;
var outputShape = _this.computeSingleOutputShape(inputShape);
var stateShape = [outputShape[0]].concat(outputShape.slice(2));
var initialState = tfc.zeros(stateShape);
if (Array.isArray(stateSize)) {
return Array(stateSize.length).fill(initialState);
}
return [initialState];
});
};
// Resets the states of a stateful layer.
// - No current states: create zeros for every state.
// - `states` omitted: dispose current (and kept) states, then re-zero them.
// - `states` given: validate count and shape, then install them. The old
//   states are pushed onto `keptStates` when `training`, otherwise disposed.
// Finally, every state is replaced by a `tfc.keep`-ed clone so it outlives
// the tidy scope.
// @throws {AttributeError} if the layer is not stateful.
// @throws {ValueError} if the batch size is unknown, or on a state count/shape mismatch.
ConvRNN2D2.prototype.resetStates = function(states, training) {
var _this = this;
if (training === void 0) {
training = false;
}
tfc.tidy(function() {
if (!_this.stateful) {
throw new AttributeError("Cannot call resetStates() on an RNN Layer that is not stateful.");
}
var inputShape = _this.inputSpec[0].shape;
var outputShape = _this.computeSingleOutputShape(inputShape);
var stateShape = [outputShape[0]].concat(outputShape.slice(2));
var batchSize = inputShape[0];
if (batchSize == null) {
throw new ValueError("If an RNN is stateful, it needs to know its batch size. Specify the batch size of your input tensors: \n- If using a Sequential model, specify the batch size by passing a `batchInputShape` option to your first layer.\n- If using the functional API, specify the batch size by passing a `batchShape` option to your Input layer.");
}
if (_this.getStates() == null) {
if (Array.isArray(_this.cell.stateSize)) {
_this.states_ = _this.cell.stateSize.map(function() {
return tfc.zeros(stateShape);
});
} else {
_this.states_ = [tfc.zeros(stateShape)];
}
} else if (states == null) {
tfc.dispose(_this.states_);
if (_this.keptStates != null) {
tfc.dispose(_this.keptStates);
_this.keptStates = [];
}
if (Array.isArray(_this.cell.stateSize)) {
_this.states_ = _this.cell.stateSize.map(function() {
return tfc.zeros(stateShape);
});
} else {
_this.states_[0] = tfc.zeros(stateShape);
}
} else {
if (!Array.isArray(states)) {
states = [states];
}
if (states.length !== _this.states_.length) {
throw new ValueError("Layer " + _this.name + " expects " + _this.states_.length + " state(s), " + ("but it received " + states.length + " state value(s). Input ") + ("received: " + states));
}
if (training) {
_this.keptStates.push(_this.states_.slice());
} else {
tfc.dispose(_this.states_);
}
for (var index = 0; index < _this.states_.length; ++index) {
var value = states[index];
var expectedShape = stateShape;
if (!tfc.util.arraysEqual(value.shape, expectedShape)) {
throw new ValueError("State " + index + " is incompatible with layer " + _this.name + ": " + ("expected shape=" + expectedShape + ", received shape=" + value.shape));
}
_this.states_[index] = value;
}
}
_this.states_ = _this.states_.map(function(state) {
return tfc.keep(state.clone());
});
});
};
// Computes the [batch, time, ...] output shape for a 5-D input shape using the
// cell's conv hyperparameters. For "channelsFirst", height and width are read
// from input indices 3 and 4 and the result is [batch, time, filters, hOut, wOut].
// Otherwise they are read from indices 2 and 3 and the result is
// [batch, time, hOut, wOut, filters].
ConvRNN2D2.prototype.computeSingleOutputShape = function(inputShape) {
var _a = this.cell, dataFormat = _a.dataFormat, filters = _a.filters, kernelSize = _a.kernelSize, padding = _a.padding, strides = _a.strides, dilationRate = _a.dilationRate;
var isChannelsFirst = dataFormat === "channelsFirst";
var h = inputShape[isChannelsFirst ? 3 : 2];
var w = inputShape[isChannelsFirst ? 4 : 3];
var hOut = convOutputLength(h, kernelSize[0], padding, strides[0], dilationRate[0]);
var wOut = convOutputLength(w, kernelSize[1], padding, strides[1], dilationRate[1]);
var outShape = inputShape.slice(0, 2).concat(isChannelsFirst ? [filters, hOut, wOut] : [hOut, wOut, filters]);
return outShape;
};
ConvRNN2D2.className = "ConvRNN2D";
return ConvRNN2D2;
}(RNN);
/**
 * Convolutional LSTM cell: an LSTMCell whose input and recurrent matmuls are
 * replaced by 2-D convolutions.
 */
var ConvLSTM2DCell = function(_super) {
__extends(ConvLSTM2DCell2, _super);
// Passes `units: filters` to LSTMCell so the inherited activation, dropout,
// initializer and state-size setup apply. It then validates and stores the
// conv hyperparameters: filters (positive int); kernelSize (2 positive ints);
// strides (default 1); padding (default "valid"); dataFormat (default
// "channelsLast"); dilationRate (default 1).
function ConvLSTM2DCell2(args) {
var _this = this;
var filters = args.filters, kernelSize = args.kernelSize, strides = args.strides, padding = args.padding, dataFormat = args.dataFormat, dilationRate = args.dilationRate;
_this = _super.call(this, __assign({}, args, {units: filters})) || this;
_this.filters = filters;
assertPositiveInteger(_this.filters, "filters");
_this.kernelSize = normalizeArray(kernelSize, 2, "kernelSize");
_this.kernelSize.forEach(function(size) {
return assertPositiveInteger(size, "kernelSize");
});
_this.strides = normalizeArray(strides || 1, 2, "strides");
_this.strides.forEach(function(stride) {
return assertPositiveInteger(stride, "strides");
});
_this.padding = padding || "valid";
checkPaddingMode(_this.padding);
_this.dataFormat = dataFormat || "channelsLast";
checkDataFormat(_this.dataFormat);
_this.dilationRate = normalizeArray(dilationRate || 1, 2, "dilationRate");
_this.dilationRate.forEach(function(rate) {
return assertPositiveInteger(rate, "dilationRate");
});
return _this;
}
// Creates the weights for the four gates (i, f, c, o):
// - kernel [...kernelSize, inputDim, 4 * filters]
// - recurrent_kernel [...kernelSize, filters, 4 * filters]
// - bias [4 * filters] when useBias; with unitForgetBias, its forget slice starts at ones.
// The channel axis is 1 for "channelsFirst", otherwise the last axis.
// NOTE(review): unlike LSTMCell.build, `this.bias` is left unassigned (not
// set to null) when useBias is false.
// @throws {ValueError} if the channel dimension is undefined.
ConvLSTM2DCell2.prototype.build = function(inputShape) {
var _a;
inputShape = getExactlyOneShape(inputShape);
var channelAxis = this.dataFormat === "channelsFirst" ? 1 : inputShape.length - 1;
if (inputShape[channelAxis] == null) {
throw new ValueError("The channel dimension of the input should be defined. " + ("Found " + inputShape[channelAxis]));
}
var inputDim = inputShape[channelAxis];
var numOfKernels = 4;
var kernelShape = this.kernelSize.concat([inputDim, this.filters * numOfKernels]);
this.kernel = this.addWeight("kernel", kernelShape, null, this.kernelInitializer, this.kernelRegularizer, true, this.kernelConstraint);
var recurrentKernelShape = this.kernelSize.concat([this.filters, this.filters * numOfKernels]);
this.recurrentKernel = this.addWeight("recurrent_kernel", recurrentKernelShape, null, this.recurrentInitializer, this.recurrentRegularizer, true, this.recurrentConstraint);
if (this.useBias) {
var biasInitializer = void 0;
if (this.unitForgetBias) {
var init_1 = this.biasInitializer;
var filters_1 = this.filters;
// One-off initializer: concat(init([filters]), ones([filters]), init([2 * filters])).
biasInitializer = new (_a = function(_super2) {
__extends(CustomInit, _super2);
function CustomInit() {
return _super2 !== null && _super2.apply(this, arguments) || this;
}
CustomInit.prototype.apply = function(shape, dtype) {
var biasI = init_1.apply([filters_1]);
var biasF = tfc.ones([filters_1]);
var biasCAndO = init_1.apply([filters_1 * 2]);
return concatenate([biasI, biasF, biasCAndO]);
};
return CustomInit;
}(Initializer), _a.className = "CustomInit", _a)();
} else {
biasInitializer = this.biasInitializer;
}
this.bias = this.addWeight("bias", [this.filters * numOfKernels], null, biasInitializer, this.biasRegularizer, true, this.biasConstraint);
}
this.built = true;
};
// Runs one ConvLSTM step.
// inputs: [x, hTMinus1, cTMinus1]. kwargs.training (default false) selects
// random vs. all-ones dropout masks.
// Unlike LSTMCell, four separate masks (one per gate) are generated and applied
// for both the input and the recurrent dropout.
// Returns [h, h, c].
// @throws {ValueError} if `inputs` does not contain exactly three tensors.
ConvLSTM2DCell2.prototype.call = function(inputs, kwargs) {
var _this = this;
return tfc.tidy(function() {
if (inputs.length !== 3) {
throw new ValueError("ConvLSTM2DCell expects 3 input Tensors (inputs, h, c), got " + (inputs.length + "."));
}
var training = kwargs["training"] || false;
var x = inputs[0];
var hTMinus1 = inputs[1];
var cTMinus1 = inputs[2];
var numOfKernels = 4;
if (0 < _this.dropout && _this.dropout < 1 && _this.dropoutMask == null) {
_this.dropoutMask = generateDropoutMask({
ones: function() {
return tfc.onesLike(x);
},
rate: _this.dropout,
training,
count: numOfKernels
});
}
var dropoutMask = _this.dropoutMask;
// Multiplies by mask[index] when present; otherwise returns the input unchanged.
var applyDropout = function(x2, mask, index) {
if (!mask || !mask[index]) {
return x2;
}
return tfc.mul(mask[index], x2);
};
var xI = applyDropout(x, dropoutMask, 0);
var xF = applyDropout(x, dropoutMask, 1);
var xC = applyDropout(x, dropoutMask, 2);
var xO = applyDropout(x, dropoutMask, 3);
if (0 < _this.recurrentDropout && _this.recurrentDropout < 1 && _this.recurrentDropoutMask == null) {
_this.recurrentDropoutMask = generateDropoutMask({
ones: function() {
return tfc.onesLike(hTMinus1);
},
rate: _this.recurrentDropout,
training,
count: numOfKernels
});
}
var recDropoutMask = _this.recurrentDropoutMask;
var hI = applyDropout(hTMinus1, recDropoutMask, 0);
var hF = applyDropout(hTMinus1, recDropoutMask, 1);
var hC = applyDropout(hTMinus1, recDropoutMask, 2);
var hO = applyDropout(hTMinus1, recDropoutMask, 3);
// Axis 3 is the last (output-channel) axis of the rank-4 kernels built above.
var kernelChannelAxis = 3;
var _a = tfc.split(_this.kernel.read(), numOfKernels, kernelChannelAxis), kernelI = _a[0], kernelF = _a[1], kernelC = _a[2], kernelO = _a[3];
var _b = _this.useBias ? tfc.split(_this.bias.read(), numOfKernels) : [null, null, null, null], biasI = _b[0], biasF = _b[1], biasC = _b[2], biasO = _b[3];
xI = _this.inputConv(xI, kernelI, biasI, _this.padding);
xF = _this.inputConv(xF, kernelF, biasF, _this.padding);
xC = _this.inputConv(xC, kernelC, biasC, _this.padding);
xO = _this.inputConv(xO, kernelO, biasO, _this.padding);
var _c = tfc.split(_this.recurrentKernel.read(), numOfKernels, kernelChannelAxis), recKernelI = _c[0], recKernelF = _c[1], recKernelC = _c[2], recKernelO = _c[3];
hI = _this.recurrentConv(hI, recKernelI);
hF = _this.recurrentConv(hF, recKernelF);
hC = _this.recurrentConv(hC, recKernelC);
hO = _this.recurrentConv(hO, recKernelO);
// Standard LSTM update: c = f * cPrev + i * act(xC + hC); h = o * act(c).
var i = _this.recurrentActivation.apply(tfc.add(xI, hI));
var f = _this.recurrentActivation.apply(tfc.add(xF, hF));
var c = tfc.add(tfc.mul(f, cTMinus1), tfc.mul(i, _this.activation.apply(tfc.add(xC, hC))));
var h = tfc.mul(_this.recurrentActivation.apply(tfc.add(xO, hO)), _this.activation.apply(c));
return [h, h, c];
});
};
// Serializable config: LSTMCell's config minus `units` (which mirrors
// `filters` here), plus the conv hyperparameters.
ConvLSTM2DCell2.prototype.getConfig = function() {
var _a = _super.prototype.getConfig.call(this), _ = _a["units"], baseConfig = __rest(_a, ["units"]);
var config = {
filters: this.filters,
kernelSize: this.kernelSize,
padding: this.padding,
dataFormat: this.dataFormat,
dilationRate: this.dilationRate,
strides: this.strides
};
return __assign({}, baseConfig, config);
};
// Input-to-hidden convolution, using the layer's strides and dilation and the
// given padding (default "valid"). The bias `b` is added when provided.
ConvLSTM2DCell2.prototype.inputConv = function(x, w, b, padding) {
var out = tfc.conv2d(x, w, this.strides, padding || "valid", this.dataFormat === "channelsFirst" ? "NCHW" : "NHWC", this.dilationRate);
if (b) {
return biasAdd(out, b, this.dataFormat);
}
return out;
};
// Hidden-to-hidden convolution: always stride 1 and "same" padding (no dilation
// argument is passed), so the spatial shape of the state is preserved.
ConvLSTM2DCell2.prototype.recurrentConv = function(x, w) {
var strides = 1;
return tfc.conv2d(x, w, strides, "same", this.dataFormat === "channelsFirst" ? "NCHW" : "NHWC");
};
ConvLSTM2DCell2.className = "ConvLSTM2DCell";
return ConvLSTM2DCell2;
}(LSTMCell);
tfc.serialization.registerClass(ConvLSTM2DCell);
var ConvLSTM2D = function(_super) {
  __extends(ConvLSTM2D2, _super);
  /**
   * Convolutional LSTM layer. It builds a ConvLSTM2DCell from `args` and
   * passes ConvRNN2D a shallow copy of `args` with that `cell` added; the
   * caller's `args` object is left untouched.
   * @param {Object} args Layer + cell arguments.
   */
  function ConvLSTM2D2(args) {
    var layerArgs = __assign({}, args, {cell: new ConvLSTM2DCell(args)});
    return _super.call(this, layerArgs) || this;
  }
  /** Rebuilds a layer from its serialized config. */
  ConvLSTM2D2.fromConfig = function(cls, config) {
    return new cls(config);
  };
  ConvLSTM2D2.className = "ConvLSTM2D";
  return ConvLSTM2D2;
}(ConvRNN2D);
tfc.serialization.registerClass(ConvLSTM2D);
/**
* @license
* Copyright 2018 Google LLC
*
* Use of this source code is governed by an MIT-style
* license that can be found in the LICENSE file or at
* https://opensource.org/licenses/MIT.
* =============================================================================
*/
/**
 * Dropout layer. In the training phase it applies `dropout(input, rate,
 * noiseShape, seed)`; otherwise it passes the input through unchanged.
 */
var Dropout = function(_super) {
  __extends(Dropout2, _super);
  function Dropout2(args) {
    const self = _super.call(this, args) || this;
    // Clamp the rate into [0, 1].
    self.rate = Math.max(Math.min(args.rate, 1), 0);
    self.noiseShape = args.noiseShape;
    self.seed = args.seed;
    self.supportsMasking = true;
    return self;
  }
  /**
   * Resolves the configured noise shape against the input: null entries are
   * filled from the input's shape. Returns the (null) config value when no
   * noise shape was configured.
   */
  Dropout2.prototype.getNoiseShape = function(input2) {
    const configured = this.noiseShape;
    if (configured == null) {
      return configured;
    }
    const inputShape = input2.shape;
    return Array.from(configured, function(dim, axis) {
      return dim == null ? inputShape[axis] : dim;
    });
  };
  Dropout2.prototype.call = function(inputs, kwargs) {
    return tfc.tidy(() => {
      this.invokeCallHook(inputs, kwargs);
      const input2 = getExactlyOneTensor(inputs);
      // A rate at either clamp bound (0 or 1) skips dropout entirely and
      // returns `inputs` exactly as given.
      if (!(this.rate > 0 && this.rate < 1)) {
        return inputs;
      }
      const training = kwargs["training"] == null ? false : kwargs["training"];
      const noiseShape = this.getNoiseShape(input2);
      return inTrainPhase(
        () => dropout(input2, this.rate, noiseShape, this.seed),
        () => input2,
        training
      );
    });
  };
  Dropout2.prototype.getConfig = function() {
    const config = {
      rate: this.rate,
      noiseShape: this.noiseShape,
      seed: this.seed
    };
    return Object.assign(config, _super.prototype.getConfig.call(this));
  };
  Dropout2.prototype.dispose = function() {
    return _super.prototype.dispose.call(this);
  };
  Dropout2.className = "Dropout";
  return Dropout2;
}(Layer);
tfc.serialization.registerClass(Dropout);
/**
 * Dropout variant for rank-3 inputs whose noise shape has size 1 on axis 1,
 * so one dropout decision is broadcast along that axis.
 */
var SpatialDropout1D = function(_super) {
  __extends(SpatialDropout1D2, _super);
  function SpatialDropout1D2(args) {
    const self = _super.call(this, args) || this;
    self.inputSpec = [{ndim: 3}];
    return self;
  }
  SpatialDropout1D2.prototype.getNoiseShape = function(input2) {
    const shape = input2.shape;
    const first = shape[0];
    const last = shape[2];
    return [first, 1, last];
  };
  SpatialDropout1D2.className = "SpatialDropout1D";
  return SpatialDropout1D2;
}(Dropout);
tfc.serialization.registerClass(SpatialDropout1D);
/**
 * Densely-connected layer. Computes `dot(input, kernel)`, adds `bias` when
 * `useBias` is on, and applies `activation`. When the activation maps to a
 * fused kernel, activation and bias are passed to `dot` in one call instead.
 */
var Dense = function(_super) {
  __extends(Dense2, _super);
  function Dense2(args) {
    const self = _super.call(this, args) || this;
    self.activation = null;
    self.useBias = true;
    self.kernel = null;
    self.bias = null;
    self.DEFAULT_KERNEL_INITIALIZER = "glorotNormal";
    self.DEFAULT_BIAS_INITIALIZER = "zeros";
    // `inputDim` is only used as a shape shorthand when no explicit input shape was given.
    const onlyInputDimGiven = args.batchInputShape == null && args.inputShape == null && args.inputDim != null;
    if (onlyInputDimGiven) {
      const batchSize = args.batchSize != null ? args.batchSize : null;
      self.batchInputShape = [batchSize, args.inputDim];
    }
    self.units = args.units;
    assertPositiveInteger(self.units, "units");
    self.activation = getActivation(args.activation);
    if (args.useBias != null) {
      self.useBias = args.useBias;
    }
    self.kernelInitializer = getInitializer(args.kernelInitializer || self.DEFAULT_KERNEL_INITIALIZER);
    self.biasInitializer = getInitializer(args.biasInitializer || self.DEFAULT_BIAS_INITIALIZER);
    self.kernelConstraint = getConstraint(args.kernelConstraint);
    self.biasConstraint = getConstraint(args.biasConstraint);
    self.kernelRegularizer = getRegularizer(args.kernelRegularizer);
    self.biasRegularizer = getRegularizer(args.biasRegularizer);
    self.activityRegularizer = getRegularizer(args.activityRegularizer);
    self.supportsMasking = true;
    self.inputSpec = [{minNDim: 2}];
    return self;
  }
  /** Creates kernel [lastDim, units] (and bias [units]) once, then pins the last input dim. */
  Dense2.prototype.build = function(inputShape) {
    const shape = getExactlyOneShape(inputShape);
    const inputLastDim = shape[shape.length - 1];
    if (this.kernel == null) {
      this.kernel = this.addWeight("kernel", [inputLastDim, this.units], null, this.kernelInitializer, this.kernelRegularizer, true, this.kernelConstraint);
      if (this.useBias) {
        this.bias = this.addWeight("bias", [this.units], null, this.biasInitializer, this.biasRegularizer, true, this.biasConstraint);
      }
    }
    const axes = {};
    axes[-1] = inputLastDim;
    this.inputSpec = [{minNDim: 2, axes: axes}];
    this.built = true;
  };
  /** Same as the input shape with the last dimension replaced by `units`. */
  Dense2.prototype.computeOutputShape = function(inputShape) {
    const shape = getExactlyOneShape(inputShape);
    const result = shape.slice();
    result[result.length - 1] = this.units;
    return result;
  };
  Dense2.prototype.call = function(inputs, kwargs) {
    return tfc.tidy(() => {
      this.invokeCallHook(inputs, kwargs);
      const input2 = getExactlyOneTensor(inputs);
      const fusedActivationName = mapActivationToFusedKernel(this.activation.getClassName());
      if (fusedActivationName != null) {
        const biasValue = this.bias ? this.bias.read() : null;
        return dot(input2, this.kernel.read(), fusedActivationName, biasValue);
      }
      let result = dot(input2, this.kernel.read());
      if (this.bias != null) {
        result = biasAdd(result, this.bias.read());
      }
      if (this.activation != null) {
        result = this.activation.apply(result);
      }
      return result;
    });
  };
  Dense2.prototype.getConfig = function() {
    const config = {
      units: this.units,
      activation: serializeActivation(this.activation),
      useBias: this.useBias,
      kernelInitializer: serializeInitializer(this.kernelInitializer),
      biasInitializer: serializeInitializer(this.biasInitializer),
      kernelRegularizer: serializeRegularizer(this.kernelRegularizer),
      biasRegularizer: serializeRegularizer(this.biasRegularizer),
      activityRegularizer: serializeRegularizer(this.activityRegularizer),
      kernelConstraint: serializeConstraint(this.kernelConstraint),
      biasConstraint: serializeConstraint(this.biasConstraint)
    };
    return Object.assign(config, _super.prototype.getConfig.call(this));
  };
  Dense2.className = "Dense";
  return Dense2;
}(Layer);
tfc.serialization.registerClass(Dense);
/**
 * Flatten layer: collapses every non-batch axis into one. With
 * `dataFormat === "channelsFirst"`, axis 1 is first moved to the end.
 */
var Flatten = function(_super) {
  __extends(Flatten2, _super);
  function Flatten2(args) {
    const options = args || {};
    const self = _super.call(this, options) || this;
    self.inputSpec = [{minNDim: 3}];
    self.dataFormat = options.dataFormat;
    return self;
  }
  /** Returns [batch, product of remaining dims]; every non-batch dim must be known. */
  Flatten2.prototype.computeOutputShape = function(inputShape) {
    const shape = getExactlyOneShape(inputShape);
    const trailing = shape.slice(1);
    for (const dim of trailing) {
      if (dim == null) {
        throw new ValueError('The shape of the input to "Flatten" is not fully defined ' + ("(got " + shape.slice(1) + "). Make sure to pass a complete ") + '"input_shape" or "batch_input_shape" argument to the first layer in your model.');
      }
    }
    return [shape[0], arrayProd(shape, 1)];
  };
  Flatten2.prototype.call = function(inputs, kwargs) {
    return tfc.tidy(() => {
      this.invokeCallHook(inputs, kwargs);
      let input2 = getExactlyOneTensor(inputs);
      if (this.dataFormat === "channelsFirst" && input2.rank > 1) {
        // Permutation [0, 2, 3, ..., rank-1, 1]: axis 1 goes last.
        const permutation = [0];
        for (let axis = 2; axis < input2.rank; axis++) {
          permutation.push(axis);
        }
        permutation.push(1);
        input2 = input2.transpose(permutation);
      }
      return batchFlatten(input2);
    });
  };
  Flatten2.prototype.getConfig = function() {
    const config = {};
    if (this.dataFormat != null) {
      config["dataFormat"] = this.dataFormat;
    }
    return Object.assign(config, _super.prototype.getConfig.call(this));
  };
  Flatten2.className = "Flatten";
  return Flatten2;
}(Layer);
tfc.serialization.registerClass(Flatten);
/** Layer that applies its configured activation function to its single input. */
var Activation$1 = function(_super) {
  __extends(Activation2, _super);
  function Activation2(args) {
    const self = _super.call(this, args) || this;
    self.supportsMasking = true;
    self.activation = getActivation(args.activation);
    return self;
  }
  Activation2.prototype.call = function(inputs, kwargs) {
    return tfc.tidy(() => {
      this.invokeCallHook(inputs, kwargs);
      const input2 = getExactlyOneTensor(inputs);
      return this.activation.apply(input2);
    });
  };
  Activation2.prototype.getConfig = function() {
    const config = {activation: serializeActivation(this.activation)};
    return Object.assign(config, _super.prototype.getConfig.call(this));
  };
  Activation2.className = "Activation";
  return Activation2;
}(Layer);
tfc.serialization.registerClass(Activation$1);
/**
 * RepeatVector layer: takes a rank-2 input and repeats it `n` times,
 * producing shape [inputShape[0], n, inputShape[1]].
 */
var RepeatVector = function(_super) {
  __extends(RepeatVector2, _super);
  function RepeatVector2(args) {
    const self = _super.call(this, args) || this;
    self.n = args.n;
    self.inputSpec = [{ndim: 2}];
    return self;
  }
  RepeatVector2.prototype.computeOutputShape = function(inputShape) {
    const leading = inputShape[0];
    const trailing = inputShape[1];
    return [leading, this.n, trailing];
  };
  RepeatVector2.prototype.call = function(inputs, kwargs) {
    return tfc.tidy(() => repeat(getExactlyOneTensor(inputs), this.n));
  };
  RepeatVector2.prototype.getConfig = function() {
    const config = {n: this.n};
    return Object.assign(config, _super.prototype.getConfig.call(this));
  };
  RepeatVector2.className = "RepeatVector";
  return RepeatVector2;
}(Layer);
tfc.serialization.registerClass(RepeatVector);
/**
 * Reshape layer: reshapes everything after the batch axis to `targetShape`.
 * At most one entry of `targetShape` may be unknown (negative or null). That
 * entry is inferred from the total element count when the per-example input
 * dims are known.
 */
var Reshape = function(_super) {
  __extends(Reshape2, _super);
  function Reshape2(args) {
    const self = _super.call(this, args) || this;
    // Copy before normalizing unknown dims to null, so the caller's
    // `args.targetShape` array is not mutated.
    self.targetShape = args.targetShape.slice();
    for (let i = 0; i < self.targetShape.length; ++i) {
      if (self.isUnknown(self.targetShape[i])) {
        self.targetShape[i] = null;
      }
    }
    return self;
  }
  /** True for null/undefined or negative dimension values. */
  Reshape2.prototype.isUnknown = function(dim) {
    return dim < 0 || dim == null;
  };
  /**
   * Returns a copy of `outputShape` whose single unknown dim (if any) is
   * inferred so the total size matches `inputShape`.
   * @throws ValueError on more than one unknown dim or a size mismatch.
   */
  Reshape2.prototype.fixUnknownDimension = function(inputShape, outputShape) {
    const errorMsg = "Total size of new array must be unchanged.";
    const finalShape = outputShape.slice();
    let known = 1;
    let unknown = null;
    for (let i = 0; i < finalShape.length; ++i) {
      const dim = finalShape[i];
      if (this.isUnknown(dim)) {
        if (unknown !== null) {
          throw new ValueError("Can only specifiy one unknown dimension.");
        }
        unknown = i;
      } else {
        known *= dim;
      }
    }
    const originalSize = arrayProd(inputShape);
    if (unknown !== null) {
      if (known === 0 || originalSize % known !== 0) {
        throw new ValueError(errorMsg);
      }
      finalShape[unknown] = originalSize / known;
    } else if (originalSize !== known) {
      throw new ValueError(errorMsg);
    }
    return finalShape;
  };
  Reshape2.prototype.computeOutputShape = function(inputShape) {
    // Only the per-example dims (index >= 1) decide whether the target can be
    // resolved. The batch dim is usually null. Scanning it too (as before)
    // always took the "unknown" branch and left an inferable target dim as null.
    let anyUnknownDims = false;
    for (let i = 1; i < inputShape.length; ++i) {
      if (this.isUnknown(inputShape[i])) {
        anyUnknownDims = true;
        break;
      }
    }
    const batchDim = inputShape.slice(0, 1);
    if (anyUnknownDims) {
      return batchDim.concat(this.targetShape);
    }
    return batchDim.concat(this.fixUnknownDimension(inputShape.slice(1), this.targetShape));
  };
  Reshape2.prototype.call = function(inputs, kwargs) {
    return tfc.tidy(() => {
      this.invokeCallHook(inputs, kwargs);
      const input2 = getExactlyOneTensor(inputs);
      const inputShape = input2.shape;
      const outputShape = inputShape.slice(0, 1).concat(this.fixUnknownDimension(inputShape.slice(1), this.targetShape));
      return input2.reshape(outputShape);
    });
  };
  Reshape2.prototype.getConfig = function() {
    const config = {targetShape: this.targetShape};
    return Object.assign(config, _super.prototype.getConfig.call(this));
  };
  Reshape2.className = "Reshape";
  return Reshape2;
}(Layer);
tfc.serialization.registerClass(Reshape);
/**
 * Permute layer: reorders the non-batch axes of its input according to
 * `dims`, which must be a permutation of 1..dims.length. The batch axis 0
 * always stays first.
 */
var Permute = function(_super) {
  __extends(Permute2, _super);
  function Permute2(args) {
    const self = _super.call(this, args) || this;
    if (args.dims == null) {
      throw new Error("Required configuration field `dims` is missing during Permute constructor call.");
    }
    if (!Array.isArray(args.dims)) {
      throw new Error("Permute constructor requires `dims` to be an Array, but received " + (args.dims + " instead."));
    }
    const expectedSortedIndices = range(1, args.dims.length + 1);
    // The sort needs a numeric comparator. The default sort compares elements
    // as strings (10 sorts before 2), which rejected every valid permutation
    // with 10 or more axes.
    const sortedDims = args.dims.slice().sort(function(a, b) {
      return a - b;
    });
    if (!tfc.util.arraysEqual(sortedDims, expectedSortedIndices)) {
      throw new Error("Invalid permutation `dims`: " + JSON.stringify(args.dims) + " `dims` must contain consecutive integers starting from 1.");
    }
    self.dims = args.dims;
    self.dimsIncludingBatch = [0].concat(self.dims);
    self.inputSpec = [new InputSpec({ndim: self.dims.length + 1})];
    return self;
  }
  /** Output axis i+1 takes the size of input axis dims[i]. */
  Permute2.prototype.computeOutputShape = function(inputShape) {
    const shape = getExactlyOneShape(inputShape);
    const outputShape = shape.slice();
    this.dims.forEach(function(dim, i) {
      outputShape[i + 1] = shape[dim];
    });
    return outputShape;
  };
  Permute2.prototype.call = function(inputs, kwargs) {
    return tfc.transpose(getExactlyOneTensor(inputs), this.dimsIncludingBatch);
  };
  Permute2.prototype.getConfig = function() {
    const config = {dims: this.dims};
    return Object.assign(config, _super.prototype.getConfig.call(this));
  };
  Permute2.className = "Permute";
  return Permute2;
}(Layer);
tfc.serialization.registerClass(Permute);
/**
 * Masking layer. A position along the last axis is kept only if some element
 * there differs from `maskValue` (default 0); other positions are zeroed in
 * the output and reported as false by `computeMask`.
 */
var Masking = function(_super) {
  __extends(Masking2, _super);
  function Masking2(args) {
    const self = _super.call(this, args == null ? {} : args) || this;
    self.supportsMasking = true;
    self.maskValue = args != null && args.maskValue != null ? args.maskValue : 0;
    return self;
  }
  Masking2.prototype.computeOutputShape = function(inputShape) {
    return inputShape;
  };
  Masking2.prototype.getConfig = function() {
    const baseConfig = _super.prototype.getConfig.call(this);
    return Object.assign({maskValue: this.maskValue}, baseConfig);
  };
  Masking2.prototype.computeMask = function(inputs, mask) {
    const differs = tfc.notEqual(getExactlyOneTensor(inputs), this.maskValue);
    return tfc.any(differs, -1);
  };
  Masking2.prototype.call = function(inputs, kwargs) {
    return tfc.tidy(() => {
      this.invokeCallHook(inputs, kwargs);
      const input2 = getExactlyOneTensor(inputs);
      // keepDims = true so the mask broadcasts back over the last axis.
      const keep = tfc.any(tfc.notEqual(input2, this.maskValue), -1, true);
      return input2.mul(keep.asType(input2.dtype));
    });
  };
  Masking2.className = "Masking";
  return Masking2;
}(Layer);
tfc.serialization.registerClass(Masking);
/**
* @license
* Copyright 2018 Google LLC
*
* Use of this source code is governed by an MIT-style
* license that can be found in the LICENSE file or at
* https://opensource.org/licenses/MIT.
* =============================================================================
*/
/**
 * Embedding layer. Holds a trainable [inputDim, outputDim] table and maps each
 * integer index of the input to its row. The output shape is the input shape
 * with `outputDim` appended.
 */
var Embedding = function(_super) {
  __extends(Embedding2, _super);
  function Embedding2(args) {
    const self = _super.call(this, args) || this;
    self.embeddings = null;
    self.DEFAULT_EMBEDDINGS_INITIALIZER = "randomUniform";
    if (args.batchInputShape == null && args.inputShape == null) {
      const batchSize = args.batchSize != null ? args.batchSize : null;
      if (args.inputLength == null) {
        self.batchInputShape = [batchSize, null];
      } else {
        self.batchInputShape = [batchSize].concat(toList(args.inputLength));
      }
    }
    self.inputDim = args.inputDim;
    assertPositiveInteger(self.inputDim, "inputDim");
    self.outputDim = args.outputDim;
    assertPositiveInteger(self.outputDim, "outputDim");
    self.embeddingsInitializer = getInitializer(args.embeddingsInitializer || self.DEFAULT_EMBEDDINGS_INITIALIZER);
    self.embeddingsRegularizer = getRegularizer(args.embeddingsRegularizer);
    self.activityRegularizer = getRegularizer(args.activityRegularizer);
    self.embeddingsConstraint = getConstraint(args.embeddingsConstraint);
    self.maskZero = args.maskZero;
    self.supportsMasking = args.maskZero;
    self.inputLength = args.inputLength;
    return self;
  }
  Embedding2.prototype.build = function(inputShape) {
    this.embeddings = this.addWeight("embeddings", [this.inputDim, this.outputDim], this.dtype, this.embeddingsInitializer, this.embeddingsRegularizer, true, this.embeddingsConstraint);
    this.built = true;
  };
  // Intentionally a no-op override.
  Embedding2.prototype.warnOnIncompatibleInputShape = function(inputShape) {
  };
  /** With `maskZero`, index 0 is treated as padding: mask is `inputs != 0`; otherwise null. */
  Embedding2.prototype.computeMask = function(inputs, mask) {
    return tfc.tidy(() => {
      if (!this.maskZero) {
        return null;
      }
      const input2 = getExactlyOneTensor(inputs);
      return tfc.notEqual(input2, tfc.zerosLike(input2));
    });
  };
  /**
   * Returns [batch, ...lengths, outputDim]. Null entries of `inputLength` are
   * filled from the incoming shape.
   * @throws ValueError when `inputLength` disagrees with the input shape.
   */
  Embedding2.prototype.computeOutputShape = function(inputShape) {
    inputShape = getExactlyOneShape(inputShape);
    if (this.inputLength == null) {
      return inputShape.concat([this.outputDim]);
    }
    // Work on a copy. `toList` may hand back `this.inputLength` itself when it
    // is already an array (TODO confirm), and the fill-in below must not
    // rewrite the layer's configured `inputLength`.
    const inLens = toList(this.inputLength).slice();
    if (inLens.length !== inputShape.length - 1) {
      throw new ValueError('"inputLength" is ' + this.inputLength + ", but received " + ("input shape has shape " + inputShape));
    }
    for (let k = 0; k < inLens.length; ++k) {
      const s1 = inLens[k];
      const s2 = inputShape[k + 1];
      if (s1 != null && s2 != null && s1 !== s2) {
        throw new ValueError('"inputLength" is ' + this.inputLength + ", but received " + ("input shape has shape " + inputShape));
      }
      if (s1 == null) {
        inLens[k] = s2;
      }
    }
    return [inputShape[0]].concat(inLens, [this.outputDim]);
  };
  Embedding2.prototype.call = function(inputs, kwargs) {
    return tfc.tidy(() => {
      this.invokeCallHook(inputs, kwargs);
      let input2 = getExactlyOneTensor(inputs);
      if (input2.dtype !== "int32") {
        input2 = cast(input2, "int32");
      }
      // Gather rows for the flattened indices, then reshape to the computed output shape.
      const output = gather(this.embeddings.read(), input2.as1D());
      return output.reshape(getExactlyOneShape(this.computeOutputShape(input2.shape)));
    });
  };
  Embedding2.prototype.getConfig = function() {
    const config = {
      inputDim: this.inputDim,
      outputDim: this.outputDim,
      embeddingsInitializer: serializeInitializer(this.embeddingsInitializer),
      embeddingsRegularizer: serializeRegularizer(this.embeddingsRegularizer),
      activityRegularizer: serializeRegularizer(this.activityRegularizer),
      embeddingsConstraint: serializeConstraint(this.embeddingsConstraint),
      maskZero: this.maskZero,
      inputLength: this.inputLength
    };
    return Object.assign(config, _super.prototype.getConfig.call(this));
  };
  Embedding2.className = "Embedding";
  return Embedding2;
}(Layer);
tfc.serialization.registerClass(Embedding);
/**
* @license
* Copyright 2018 Google LLC
*
* Use of this source code is governed by an MIT-style
* license that can be found in the LICENSE file or at
* https://opensource.org/licenses/MIT.
* =============================================================================
*/
/**
 * Abstract base class for element-wise merge layers. Subclasses implement
 * `mergeFunction(inputs)`. This base class provides:
 * - broadcast shape inference;
 * - optional rank alignment of the inputs before merging;
 * - merging of input masks.
 */
var Merge = function(_super) {
  __extends(Merge2, _super);
  function Merge2(args) {
    const self = _super.call(this, args || {}) || this;
    self.supportsMasking = true;
    return self;
  }
  /** Combines the input tensors; must be overridden. */
  Merge2.prototype.mergeFunction = function(inputs) {
    throw new NotImplementedError();
  };
  /**
   * Broadcast shape of two batch-less shapes, aligned from the trailing axis.
   * Returns null if either shape is null. Unknown (null/negative) dims give a
   * null output dim.
   * @throws ValueError if two known dims differ and neither is 1.
   */
  Merge2.prototype.computeElementwiseOpOutputShape = function(shape1, shape2) {
    if (shape1 == null || shape2 == null) {
      return null;
    } else if (shape1.length < shape2.length) {
      return this.computeElementwiseOpOutputShape(shape2, shape1);
    } else if (shape2.length === 0) {
      return shape1;
    }
    const outputShape = shape1.slice(0, shape1.length - shape2.length);
    for (let k = 0; k < shape2.length; ++k) {
      const i = shape1[shape1.length - shape2.length + k];
      const j = shape2[k];
      if (i == null || j == null || i < 0 || j < 0) {
        outputShape.push(null);
      } else if (i === 1) {
        outputShape.push(j);
      } else if (j === 1) {
        outputShape.push(i);
      } else {
        if (i !== j) {
          throw new ValueError("Operands could not be broadcast together with shapes " + JSON.stringify(shape1) + " " + JSON.stringify(shape2));
        }
        outputShape.push(i);
      }
    }
    return outputShape;
  };
  Merge2.prototype.build = function(inputShape) {
    // A single (non-nested) shape is wrapped so the length check rejects it.
    if (Array.isArray(inputShape) && !Array.isArray(inputShape[0])) {
      inputShape = [getExactlyOneShape(inputShape)];
    }
    if (inputShape.length < 2) {
      throw new ValueError("A merge layer should be called on an Array of at least 2 inputs." + (" Got " + inputShape.length + " input(s)."));
    }
    const batchSizes = [];
    for (let i = 0; i < inputShape.length; ++i) {
      const shape = inputShape[i];
      if (shape != null && shape[0] !== null) {
        batchSizes.push(shape[0]);
      }
    }
    if (unique(batchSizes).length > 1) {
      throw new ValueError("Can not merge tensors with different batch sizes. " + ("Got tensors with shapes: " + JSON.stringify(inputShape) + "."));
    }
    // Run broadcast inference only for its validation side effect: it throws
    // on incompatible shapes, and the result itself is not stored.
    let outputShape = inputShape[0] == null ? null : inputShape[0].slice(1);
    for (let i = 1; i < inputShape.length; ++i) {
      const shape = inputShape[i] == null ? null : inputShape[i].slice(1);
      outputShape = this.computeElementwiseOpOutputShape(outputShape, shape);
    }
    // Rank alignment is skipped only when every shape is known and all ranks
    // agree. Check for unknown shapes before reading `.length`; the old code
    // read `.length` first and crashed on a null shape.
    let hasUnknownShape = false;
    const allRanks = [];
    for (let i = 0; i < inputShape.length; ++i) {
      if (inputShape[i] == null) {
        hasUnknownShape = true;
        break;
      }
      allRanks.push(inputShape[i].length);
    }
    this.reshapeRequired = hasUnknownShape || unique(allRanks).length !== 1;
  };
  Merge2.prototype.call = function(inputs, kwargs) {
    const self = this;
    return tfc.tidy(function() {
      if (!self.reshapeRequired) {
        return self.mergeFunction(inputs);
      }
      const reshapedInputs = [];
      const inputDims = inputs.map(function(input2) {
        return input2.rank;
      });
      if (inputDims.indexOf(null) === -1) {
        // All ranks known: insert size-1 axes at axis 1 until each input has the max rank.
        const maxNDim = max(inputDims);
        for (let n = 0; n < inputs.length; n++) {
          let x = inputs[n];
          const xNDim = x.rank;
          for (let k = 0; k < maxNDim - xNDim; ++k) {
            x = expandDims(x, 1);
          }
          reshapedInputs.push(x);
        }
        return self.mergeFunction(reshapedInputs);
      }
      // Some rank unknown: move the batch axis to the end, merge, then move it back.
      let transposed = false;
      for (let n = 0; n < inputs.length; n++) {
        const x = inputs[n];
        const xNDim = x.rank;
        if (xNDim == null) {
          const xShape = x.shape;
          const batchSize = xShape[0];
          const newShape = xShape.slice(1).concat([batchSize]);
          let xTransposed = x.reshape([batchSize].concat(arrayProd(xShape.slice(1))));
          xTransposed = tfc.transpose(xTransposed, [1, 0]);
          xTransposed = xTransposed.reshape(newShape);
          reshapedInputs.push(xTransposed);
          transposed = true;
        } else if (xNDim > 1) {
          const dims = range(1, xNDim).concat([0]);
          reshapedInputs.push(tfc.transpose(x, dims));
          transposed = true;
        } else {
          reshapedInputs.push(x);
        }
      }
      let y = self.mergeFunction(reshapedInputs);
      const yNDim = y.rank;
      if (transposed) {
        if (yNDim == null) {
          const yShape = y.shape;
          const batchSize = yShape[yShape.length - 1];
          const newShape = [batchSize].concat(yShape.slice(0, yShape.length - 1));
          y = tfc.transpose(y.reshape([-1, batchSize]), [1, 0]).reshape(newShape);
        } else if (yNDim > 1) {
          const dims = [yNDim - 1].concat(range(0, yNDim - 1));
          y = tfc.transpose(y, dims);
        }
      }
      return y;
    });
  };
  /** Broadcast shape of the non-batch dims, prefixed by the common batch size (or null). */
  Merge2.prototype.computeOutputShape = function(inputShape) {
    let outputShape = inputShape[0] == null ? null : inputShape[0].slice(1);
    for (let i = 1; i < inputShape.length; ++i) {
      const shape = inputShape[i] == null ? null : inputShape[i].slice(1);
      outputShape = this.computeElementwiseOpOutputShape(outputShape, shape);
    }
    let batchSizes = [];
    for (let i = 0; i < inputShape.length; ++i) {
      const shape = inputShape[i];
      if (shape != null && shape[0] !== null) {
        batchSizes.push(shape[0]);
      }
    }
    batchSizes = unique(batchSizes);
    if (batchSizes.length === 1) {
      return batchSizes.concat(outputShape);
    }
    return [null].concat(outputShape);
  };
  /**
   * Logical AND of all non-null input masks. Returns null when no masks are
   * given or all are null.
   * @throws ValueError if `mask`/`inputs` are not arrays of equal length.
   */
  Merge2.prototype.computeMask = function(inputs, mask) {
    return tfc.tidy(function() {
      if (mask == null) {
        return null;
      }
      if (!Array.isArray(mask)) {
        throw new ValueError("`mask` should be an Array");
      }
      if (!Array.isArray(inputs)) {
        throw new ValueError("`inputs` should be an Array");
      }
      if (mask.length !== inputs.length) {
        throw new ValueError("The Array 'inputs' and 'mask' are expected to have the same length, but have different lengths " + ("(" + inputs.length + " vs " + mask.length + ")"));
      }
      // Null masks impose no constraint and are skipped. The old loop ran only
      // to `mask.length - 1`, so it ignored the last mask and passed nulls to
      // logicalAnd. Stacking the masks on a new axis 0 and reducing with `all`
      // over that axis ANDs every mask and keeps the per-input mask shape.
      const stacked = [];
      for (const m of mask) {
        if (m != null) {
          stacked.push(tfc.expandDims(m, 0));
        }
      }
      if (stacked.length === 0) {
        return null;
      }
      return tfc.all(tfc.concat(stacked, 0), 0);
    });
  };
  return Merge2;
}(Layer);
/** Merge layer computing the element-wise sum of its inputs. */
var Add = function(_super) {
  __extends(Add2, _super);
  function Add2(args) {
    const instance = _super.call(this, args);
    return instance || this;
  }
  Add2.prototype.mergeFunction = function(inputs) {
    return tfc.tidy(() => {
      let total = inputs[0].clone();
      for (const tensor of inputs.slice(1)) {
        total = tfc.add(total, tensor);
      }
      return total;
    });
  };
  Add2.className = "Add";
  return Add2;
}(Merge);
tfc.serialization.registerClass(Add);
/** Merge layer computing the element-wise product of its inputs. */
var Multiply = function(_super) {
  __extends(Multiply2, _super);
  function Multiply2(args) {
    const instance = _super.call(this, args);
    return instance || this;
  }
  Multiply2.prototype.mergeFunction = function(inputs) {
    return tfc.tidy(() => {
      let product = inputs[0].clone();
      for (const tensor of inputs.slice(1)) {
        product = tfc.mul(product, tensor);
      }
      return product;
    });
  };
  Multiply2.className = "Multiply";
  return Multiply2;
}(Merge);
tfc.serialization.registerClass(Multiply);
/** Merge layer computing the element-wise mean of its inputs. */
var Average = function(_super) {
  __extends(Average2, _super);
  function Average2(args) {
    const instance = _super.call(this, args);
    return instance || this;
  }
  Average2.prototype.mergeFunction = function(inputs) {
    return tfc.tidy(() => {
      let total = inputs[0].clone();
      for (const tensor of inputs.slice(1)) {
        total = tfc.add(total, tensor);
      }
      const scale = 1 / inputs.length;
      return tfc.mul(scale, total);
    });
  };
  Average2.className = "Average";
  return Average2;
}(Merge);
tfc.serialization.registerClass(Average);
/** Merge layer computing the element-wise maximum of its inputs. */
var Maximum = function(_super) {
  __extends(Maximum2, _super);
  function Maximum2(args) {
    const instance = _super.call(this, args);
    return instance || this;
  }
  Maximum2.prototype.mergeFunction = function(inputs) {
    return tfc.tidy(() => {
      let largest = inputs[0];
      for (const tensor of inputs.slice(1)) {
        largest = tfc.maximum(largest, tensor);
      }
      return largest;
    });
  };
  Maximum2.className = "Maximum";
  return Maximum2;
}(Merge);
tfc.serialization.registerClass(Maximum);
/** Merge layer computing the element-wise minimum of its inputs. */
var Minimum = function(_super) {
  __extends(Minimum2, _super);
  function Minimum2(args) {
    const instance = _super.call(this, args);
    return instance || this;
  }
  Minimum2.prototype.mergeFunction = function(inputs) {
    return tfc.tidy(() => {
      let smallest = inputs[0];
      for (const tensor of inputs.slice(1)) {
        smallest = tfc.minimum(smallest, tensor);
      }
      return smallest;
    });
  };
  Minimum2.className = "Minimum";
  return Minimum2;
}(Merge);
tfc.serialization.registerClass(Minimum);
/**
 * Merge layer that concatenates its inputs along `axis` (default -1). All
 * other dimensions must match.
 */
var Concatenate = function(_super) {
  __extends(Concatenate2, _super);
  function Concatenate2(args) {
    const self = _super.call(this, args) || this;
    self.DEFAULT_AXIS = -1;
    const options = args == null ? {} : args;
    self.axis = options.axis == null ? self.DEFAULT_AXIS : options.axis;
    self.supportsMasking = true;
    self.reshapeRequired = false;
    return self;
  }
  /** Validates that at least two input shapes agree on every axis except `axis`. */
  Concatenate2.prototype.build = function(inputShape) {
    const isShapeList = Array.isArray(inputShape) && Array.isArray(inputShape[0]);
    if (!isShapeList || inputShape.length === 1) {
      throw new ValueError("A `Concatenate` layer should be called on a list of at least 2 inputs");
    }
    // Nothing to validate when every input shape is unknown.
    if (inputShape.every((shape) => shape == null)) {
      return;
    }
    const distinctShapes = [];
    for (const shape of inputShape) {
      const withoutAxis = shape.slice();
      withoutAxis.splice(this.axis, 1);
      const seen = distinctShapes.some((candidate) => tfc.util.arraysEqual(candidate, withoutAxis));
      if (!seen) {
        distinctShapes.push(withoutAxis);
      }
    }
    if (distinctShapes.length > 1) {
      throw new ValueError("A `Concatenate` layer requires inputs with matching shapes except for the concat axis. Got input shapes: " + JSON.stringify(inputShape));
    }
  };
  Concatenate2.prototype.mergeFunction = function(inputs) {
    return tfc.tidy(() => concatenate(inputs, this.axis));
  };
  /** First input's shape with the concat axis summed (null if any size is unknown). */
  Concatenate2.prototype.computeOutputShape = function(inputShape) {
    if (!(Array.isArray(inputShape) && Array.isArray(inputShape[0]))) {
      throw new ValueError("A `Concatenate` layer should be called on a list of inputs.");
    }
    const outputShape = inputShape[0].slice();
    const axis = this.axis < 0 ? outputShape.length + this.axis : this.axis;
    for (const shape of inputShape.slice(1)) {
      if (outputShape[axis] == null || shape[axis] == null) {
        outputShape[axis] = null;
        break;
      }
      outputShape[axis] += shape[axis];
    }
    return outputShape;
  };
  /**
   * Concatenates the input masks along `axis` and reduces them with `all`
   * over the last axis. Missing masks become all-true masks shaped like
   * their input.
   */
  Concatenate2.prototype.computeMask = function(inputs, mask) {
    if (mask == null) {
      return null;
    }
    if (!Array.isArray(mask)) {
      throw new ValueError("`mask` should be an array for Concatenate");
    }
    if (!Array.isArray(inputs)) {
      throw new ValueError("`inputs` should be an array for Concatenate");
    }
    if (mask.length !== inputs.length) {
      throw new ValueError("Mismatch in the length of mask (" + mask.length + ") " + ("and the legnth of inputs (" + inputs.length + ")"));
    }
    return tfc.tidy(() => {
      if (mask.every((m) => m == null)) {
        return null;
      }
      const outputMasks = [];
      for (let idx = 0; idx < inputs.length; ++idx) {
        const m = mask[idx];
        if (m == null) {
          outputMasks.push(tfc.onesLike(inputs[idx]).asType("bool"));
        } else if (m.rank < inputs[idx].rank) {
          outputMasks.push(tfc.expandDims(m, -1));
        } else {
          outputMasks.push(m);
        }
      }
      const concatenatedMasks = tfc.concat(outputMasks, this.axis);
      return tfc.all(concatenatedMasks, -1, false);
    });
  };
  Concatenate2.prototype.getConfig = function() {
    const config = {axis: this.axis};
    return Object.assign(config, _super.prototype.getConfig.call(this));
  };
  Concatenate2.className = "Concatenate";
  return Concatenate2;
}(Merge);
tfc.serialization.registerClass(Concatenate);
/**
 * Normalizes a possibly-negative axis index against a rank `dim`, Python
 * style (e.g. -1 with dim 3 -> 2). Non-negative axes are returned unchanged.
 *
 * @param {number} axis - axis index, may be negative.
 * @param {number} dim - rank to normalize against.
 * @returns {number} the equivalent non-negative axis.
 * @throws {Error} if `axis` is negative and `dim` is not positive. The
 *   previous repeated-add loop never terminated in that case.
 */
function interpretAxis(axis, dim) {
  if (!(axis < 0)) {
    return axis;
  }
  if (dim <= 0) {
    throw new Error("Cannot interpret negative axis " + axis + " for rank " + dim + ".");
  }
  // Same result as repeatedly adding `dim` until non-negative, in O(1).
  return ((axis % dim) + dim) % dim;
}
/**
 * Batched dot product of `x` and `y` over `axes`: a number (same axis for
 * both) or an [xAxis, yAxis] pair. Defaults to [rank(x) - 1, rank(y) - 2].
 * Only ranks 2 and 3 are supported. The lower-rank operand is padded with
 * trailing size-1 axes; those are squeezed from the result afterwards. A
 * rank-1 result is expanded to rank 2.
 *
 * @throws NotImplementedError for rank > 3 or complex64 inputs.
 */
function batchDot(x, y, axes) {
  if (x.shape.length > 3 || y.shape.length > 3) {
    throw new NotImplementedError("batchDot is not implemented for tensors of 4D or higher rank yet");
  }
  tfc.util.assert(x.shape.length >= 2, function() {
    return "batchDot requires the rank of x to be >= 2, " + ("but got " + x.shape.length);
  });
  // This check previously tested x's rank again, so a rank-1 y slipped through.
  tfc.util.assert(y.shape.length >= 2, function() {
    return "batchDot requires the rank of y to be >= 2, " + ("but got " + y.shape.length);
  });
  if (typeof axes === "number") {
    axes = [axes, axes];
  }
  if (x.dtype === "complex64" || y.dtype === "complex64") {
    throw new NotImplementedError("batchDot is not implemented for complex64-type Tensors yet.");
  }
  const xNDim = x.shape.length;
  const yNDim = y.shape.length;
  if (axes == null) {
    axes = [xNDim - 1, yNDim - 2];
  }
  const axesArray = axes;
  return tfc.tidy(function() {
    let diff;
    if (xNDim > yNDim) {
      diff = xNDim - yNDim;
      const diffShape = [];
      for (let i = 0; i < diff; ++i) {
        diffShape.push(1);
      }
      y = y.reshape(y.shape.concat(diffShape));
    } else if (yNDim > xNDim) {
      diff = yNDim - xNDim;
      const diffShape = [];
      for (let i = 0; i < diff; ++i) {
        diffShape.push(1);
      }
      x = x.reshape(x.shape.concat(diffShape));
    } else {
      diff = 0;
    }
    let out;
    if (x.shape.length === 2 && y.shape.length === 2) {
      if (axesArray[0] === axesArray[1]) {
        out = x.mul(y).sum(axesArray[0]);
      } else {
        out = x.transpose([1, 0]).mul(y).sum(axesArray[1]);
      }
    } else {
      const adjX = axesArray[0] !== x.shape.length - 1;
      const adjY = axesArray[1] === y.shape.length - 1;
      out = x.matMul(y, adjX, adjY);
    }
    if (diff > 0) {
      // Drop the padding axes that were added above.
      const idx = xNDim > yNDim ? xNDim + yNDim - 3 : xNDim - 1;
      const squeezeAxes = [];
      for (let i = idx; i < idx + diff; ++i) {
        squeezeAxes.push(i);
      }
      out = out.squeeze(squeezeAxes);
    }
    if (out.shape.length === 1) {
      out = out.expandDims(1);
    }
    return out;
  });
}
var Dot = function(_super) {
  __extends(Dot2, _super);
  /**
   * Asserts that `inputShape` is a list of exactly two shapes (arrays).
   * Shared by build() and computeOutputShape().
   */
  function assertTwoInputShapes(inputShape) {
    tfc.util.assert(Array.isArray(inputShape) && inputShape.length === 2 && Array.isArray(inputShape[0]) && Array.isArray(inputShape[1]), function() {
      return "A `Dot` layer should be called on a list of exactly 2 inputs.";
    });
  }
  /**
   * Merge layer computing a batched dot product (via batchDot) of its two inputs.
   *
   * @param {Object} args - Layer config. `args.axes` is a single axis or an
   *   `[axis1, axis2]` pair (negative values count from the end).
   *   `args.normalize` (default false) applies l2Normalize to each input
   *   along its axis before the product.
   */
  function Dot2(args) {
    var _this = _super.call(this, args) || this;
    _this.axes = args.axes;
    _this.normalize = args.normalize == null ? false : args.normalize;
    _this.supportsMasking = true;
    _this.reshapeRequired = false;
    return _this;
  }
  /**
   * Validates the two input shapes: rank <= 3 and equal sizes along the
   * contracted axes.
   *
   * @throws {NotImplementedError} If either input has rank > 3.
   * @throws {ValueError} If the contracted dimensions differ.
   */
  Dot2.prototype.build = function(inputShape) {
    assertTwoInputShapes(inputShape);
    var shape1 = inputShape[0];
    var shape2 = inputShape[1];
    if (shape1.length > 3 || shape2.length > 3) {
      throw new NotImplementedError("Dot layer does not support tensors of 4D or higher rank yet.");
    }
    var axes = this.interpretAxes(shape1, shape2);
    if (shape1[axes[0]] !== shape2[axes[1]]) {
      throw new ValueError("Dimension incompatibility: " + (shape1[axes[0]] + " !== " + shape2[axes[1]]));
    }
  };
  /**
   * Computes the dot product of exactly two input tensors.
   *
   * @throws {ValueError} If `inputs` does not contain exactly two tensors.
   */
  Dot2.prototype.mergeFunction = function(inputs) {
    if (inputs.length !== 2) {
      throw new ValueError("A `Dot` layer must be called on exactly 2 inputs, " + ("but received " + inputs.length + " input(s)."));
    }
    var x1 = inputs[0];
    var x2 = inputs[1];
    var axes = this.interpretAxes(x1.shape, x2.shape);
    if (this.normalize) {
      x1 = l2Normalize(x1, axes[0]);
      x2 = l2Normalize(x2, axes[1]);
    }
    return batchDot(x1, x2, axes);
  };
  /**
   * Resolves `this.axes` (scalar or pair) into a pair of non-negative axis
   * indices for the given shapes. Negative entries are resolved in both the
   * scalar and the array form, so build(), computeOutputShape() and
   * mergeFunction() all agree.
   *
   * @returns {number[]} `[axisForShape1, axisForShape2]`.
   */
  Dot2.prototype.interpretAxes = function(shape1, shape2) {
    var pair = Array.isArray(this.axes) ? this.axes : [this.axes, this.axes];
    return [
      interpretAxis(pair[0], shape1.length),
      interpretAxis(pair[1], shape2.length)
    ];
  };
  /**
   * Output shape: shape1 without its contracted axis, followed by shape2
   * without its contracted axis and without its leading (batch) entry.
   * A length-1 result gets a trailing 1 appended.
   */
  Dot2.prototype.computeOutputShape = function(inputShape) {
    assertTwoInputShapes(inputShape);
    var shape1 = inputShape[0].slice();
    var shape2 = inputShape[1].slice();
    if (shape1.length > 3 || shape2.length > 3) {
      throw new NotImplementedError("Dot layer does not support tensors of 4D or higher rank yet.");
    }
    var axes = this.interpretAxes(shape1, shape2);
    shape1.splice(axes[0], 1);
    shape2.splice(axes[1], 1);
    shape2.splice(0, 1);
    var outputShape = shape1.concat(shape2);
    if (outputShape.length === 1) {
      outputShape.push(1);
    }
    return outputShape;
  };
  /** Masks are not propagated through this layer. */
  Dot2.prototype.computeMask = function(inputs, mask) {
    return null;
  };
  Dot2.prototype.getConfig = function() {
    var config = {
      axes: this.axes,
      normalize: this.normalize
    };
    var baseConfig = _super.prototype.getConfig.call(this);
    Object.assign(config, baseConfig);
    return config;
  };
  Dot2.className = "Dot";
  return Dot2;
}(Merge);
tfc.serialization.registerClass(Dot);
/**
* @license
* Copyright 2018 Google LLC
*
* Use of this source code is governed by an MIT-style
* license that can be found in the LICENSE file or at
* https://opensource.org/licenses/MIT.
* =============================================================================
*/
var GaussianNoise = function(_super) {
  __extends(GaussianNoise2, _super);
  /**
   * Layer that adds noise from randomNormal(shape, 0, stddev) to its input.
   * The noisy or the identity branch is chosen by inTrainPhase based on
   * kwargs.training.
   *
   * @param {Object} args - Layer config; `args.stddev` is stored as `this.stddev`.
   */
  function GaussianNoise2(args) {
    var self = _super.call(this, args) || this;
    self.supportsMasking = true;
    self.stddev = args.stddev;
    return self;
  }
  /** Element-wise layer: the output shape equals the input shape. */
  GaussianNoise2.prototype.computeOutputShape = function(inputShape) {
    return inputShape;
  };
  /** Serializable config: base-layer config plus `stddev`. */
  GaussianNoise2.prototype.getConfig = function() {
    var baseConfig = _super.prototype.getConfig.call(this);
    return Object.assign({stddev: this.stddev}, baseConfig);
  };
  GaussianNoise2.prototype.call = function(inputs, kwargs) {
    var layer = this;
    return tfc.tidy(function() {
      layer.invokeCallHook(inputs, kwargs);
      var input2 = getExactlyOneTensor(inputs);
      var addNoise = function() {
        return randomNormal(input2.shape, 0, layer.stddev).add(input2);
      };
      var identity = function() {
        return input2;
      };
      return inTrainPhase(addNoise, identity, kwargs["training"] || false);
    });
  };
  GaussianNoise2.className = "GaussianNoise";
  return GaussianNoise2;
}(Layer);
tfc.serialization.registerClass(GaussianNoise);
var GaussianDropout = function(_super) {
  __extends(GaussianDropout2, _super);
  /**
   * Layer that, when 0 < rate < 1, multiplies its input by noise from
   * randomNormal(shape, 1, sqrt(rate / (1 - rate))). The noisy or identity
   * branch is chosen by inTrainPhase based on kwargs.training; any other rate
   * passes the input through.
   *
   * @param {Object} args - Layer config; `args.rate` is stored as `this.rate`.
   */
  function GaussianDropout2(args) {
    var self = _super.call(this, args) || this;
    self.supportsMasking = true;
    self.rate = args.rate;
    return self;
  }
  /** Element-wise layer: the output shape equals the input shape. */
  GaussianDropout2.prototype.computeOutputShape = function(inputShape) {
    return inputShape;
  };
  /** Serializable config: base-layer config plus `rate`. */
  GaussianDropout2.prototype.getConfig = function() {
    var baseConfig = _super.prototype.getConfig.call(this);
    return Object.assign({rate: this.rate}, baseConfig);
  };
  GaussianDropout2.prototype.call = function(inputs, kwargs) {
    var layer = this;
    return tfc.tidy(function() {
      layer.invokeCallHook(inputs, kwargs);
      var input2 = getExactlyOneTensor(inputs);
      var rateInRange = layer.rate > 0 && layer.rate < 1;
      if (!rateInRange) {
        return input2;
      }
      var applyNoise = function() {
        var stddev = Math.sqrt(layer.rate / (1 - layer.rate));
        return input2.mul(randomNormal(input2.shape, 1, stddev));
      };
      var identity = function() {
        return input2;
      };
      return inTrainPhase(applyNoise, identity, kwargs["training"] || false);
    });
  };
  GaussianDropout2.className = "GaussianDropout";
  return GaussianDropout2;
}(Layer);
tfc.serialization.registerClass(GaussianDropout);
var AlphaDropout = function(_super) {
  __extends(AlphaDropout2, _super);
  // SELU alpha and scale constants; dropped units are set to alphaP = -ALPHA * SCALE.
  var ALPHA = 1.6732632423543772;
  var SCALE = 1.0507009873554805;
  /**
   * Alpha dropout. When 0 < rate < 1 and the training branch is selected by
   * inTrainPhase, each element is dropped with probability `rate`. A dropped
   * element is set to alphaP, and the affine map a * x + b is then applied.
   * Otherwise the single input tensor is passed through unchanged.
   *
   * @param {Object} args - Layer config: `rate`, optional `noiseShape` (used
   *   for the random mask instead of the input shape).
   */
  function AlphaDropout2(args) {
    var _this = _super.call(this, args) || this;
    _this.supportsMasking = true;
    _this.rate = args.rate;
    _this.noiseShape = args.noiseShape;
    return _this;
  }
  /** Shape of the random mask: `noiseShape` if set, else the input's shape. */
  AlphaDropout2.prototype._getNoiseShape = function(inputs) {
    return this.noiseShape || getExactlyOneTensor(inputs).shape;
  };
  /** Element-wise layer: the output shape equals the input shape. */
  AlphaDropout2.prototype.computeOutputShape = function(inputShape) {
    return inputShape;
  };
  AlphaDropout2.prototype.getConfig = function() {
    var baseConfig = _super.prototype.getConfig.call(this);
    var config = {rate: this.rate};
    Object.assign(config, baseConfig);
    return config;
  };
  AlphaDropout2.prototype.call = function(inputs, kwargs) {
    var _this = this;
    return tfc.tidy(function() {
      var input2 = getExactlyOneTensor(inputs);
      if (!(_this.rate < 1 && _this.rate > 0)) {
        // Nothing to drop: return the single input tensor, as the inference
        // branch below and the sibling Gaussian layers do.
        return input2;
      }
      var noiseShape = _this._getNoiseShape(inputs);
      var droppedInputs = function() {
        var rate = _this.rate;
        var alphaP = -ALPHA * SCALE;
        // 1 where the element is kept, 0 where it is dropped.
        var keptIdx = cast(tfc.greaterEqual(tfc.randomUniform(noiseShape), rate), "float32");
        // Affine correction parameters computed for dropped units equal to alphaP.
        var a = Math.pow((1 - rate) * (1 + rate * Math.pow(alphaP, 2)), -0.5);
        var b = -a * alphaP * rate;
        // x = input * kept + alphaP * (1 - kept): dropped units take the value
        // alphaP, which is what `a` and `b` above assume.
        var x = input2.mul(keptIdx).add(keptIdx.mul(-1).add(1).mul(alphaP));
        return x.mul(a).add(b);
      };
      return inTrainPhase(droppedInputs, function() {
        return input2;
      }, kwargs["training"] || false);
    });
  };
  AlphaDropout2.className = "AlphaDropout";
  return AlphaDropout2;
}(Layer);
tfc.serialization.registerClass(AlphaDropout);
/**
* @license
* Copyright 2018 Google LLC
*
* Use of this source code is governed by an MIT-style
* license that can be found in the LICENSE file or at
* https://opensource.org/licenses/MIT.
* =============================================================================
*/
/**
 * Dispatches to tfc.batchNorm2d / 3d / 4d according to the rank of `x`.
 *
 * @param {Tensor} x - Input of rank 2, 3 or 4.
 * @param {Tensor} mean
 * @param {Tensor} variance
 * @param {Tensor|null} beta - Offset, passed through as-is.
 * @param {Tensor|null} gamma - Scale, passed through as-is.
 * @param {number} [epsilon2=1e-3]
 * @returns {Tensor} The normalized tensor.
 * @throws {NotImplementedError} For any other rank.
 */
function batchNormalization(x, mean, variance, beta, gamma, epsilon2) {
  var eps = epsilon2 === void 0 ? 1e-3 : epsilon2;
  switch (x.rank) {
    case 2:
      return tfc.batchNorm2d(x, mean, variance, beta, gamma, eps);
    case 3:
      return tfc.batchNorm3d(x, mean, variance, beta, gamma, eps);
    case 4:
      return tfc.batchNorm4d(x, mean, variance, beta, gamma, eps);
    default:
      throw new NotImplementedError("batchNormalization is not implemented for array of rank " + x.rank + " yet");
  }
}
/**
 * Training-time batch normalization when the statistics need no reshaping:
 * moments over `reductionAxes` are fed straight into batchNormalization.
 *
 * @returns {Tensor[]} `[normed, mean, variance]`.
 */
function regularNormalizeBatchInTraining(x, gamma, beta, reductionAxes, epsilon2) {
  var eps = epsilon2 === void 0 ? 1e-3 : epsilon2;
  return tfc.tidy(function() {
    var stats = tfc.moments(x, reductionAxes);
    var batchMean = stats.mean;
    var batchVariance = stats.variance;
    return [
      batchNormalization(x, batchMean, batchVariance, beta, gamma, eps),
      batchMean,
      batchVariance
    ];
  });
}
/**
 * Training-time batch normalization for a non-standard set of reduction axes.
 * The moments (and gamma/beta when non-null) are reshaped to a shape that is
 * 1 along every reduced axis and the input size elsewhere, then passed to
 * batchNormalization.
 *
 * @returns {Tensor[]} `[normed, mean, variance]`, where mean/variance are the
 *   un-reshaped moments.
 */
function broadcastNormalizeBatchInTraining(x, gamma, beta, reductionAxes, epsilon2) {
  var eps = epsilon2 === void 0 ? 1e-3 : epsilon2;
  return tfc.tidy(function() {
    var stats = tfc.moments(x, reductionAxes);
    var mean = stats.mean;
    var variance = stats.variance;
    var targetShape = range(0, x.rank).map(function(axis) {
      return reductionAxes.indexOf(axis) === -1 ? x.shape[axis] : 1;
    });
    var toTarget = function(t) {
      return t.reshape(targetShape);
    };
    var broadcastMean = toTarget(mean);
    var broadcastVariance = toTarget(variance);
    var broadcastGamma = gamma == null ? null : toTarget(gamma);
    var broadcastBeta = beta == null ? null : toTarget(beta);
    var normed = batchNormalization(x, broadcastMean, broadcastVariance, broadcastBeta, broadcastGamma, eps);
    return [normed, mean, variance];
  });
}
/**
 * Training-time batch normalization. Uses the regular path when the reduction
 * axes are exactly `[0, ..., rank - 2]` (all but the last axis, in any order),
 * and the broadcasting path otherwise.
 *
 * @returns {Tensor[]} `[normed, mean, variance]`.
 */
function normalizeBatchInTraining(x, gamma, beta, reductionAxes, epsilon2) {
  if (epsilon2 === void 0) {
    epsilon2 = 1e-3;
  }
  // Numeric comparator: the default sort compares as strings, so e.g. [2, 10]
  // would come out as [10, 2] and the regular path would be missed.
  var sortedAxes = reductionAxes.slice().sort(function(a, b) {
    return a - b;
  });
  if (tfc.util.arraysEqual(sortedAxes, range(0, x.rank - 1))) {
    return regularNormalizeBatchInTraining(x, gamma, beta, reductionAxes, epsilon2);
  }
  return broadcastNormalizeBatchInTraining(x, gamma, beta, reductionAxes, epsilon2);
}
var BatchNormalization = function(_super) {
  __extends(BatchNormalization2, _super);
  /**
   * Batch normalization layer.
   *
   * @param {Object} [args] - Layer config. Defaults: axis -1, momentum 0.99,
   *   epsilon 1e-3, center true, scale true, beta/moving-mean initializers
   *   "zeros", gamma/moving-variance initializers "ones".
   */
  function BatchNormalization2(args) {
    if (args == null) {
      args = {};
    }
    var _this = _super.call(this, args) || this;
    _this.supportsMasking = true;
    _this.axis = args.axis == null ? -1 : args.axis;
    _this.momentum = args.momentum == null ? 0.99 : args.momentum;
    _this.epsilon = args.epsilon == null ? 1e-3 : args.epsilon;
    _this.center = args.center == null ? true : args.center;
    _this.scale = args.scale == null ? true : args.scale;
    _this.betaInitializer = getInitializer(args.betaInitializer || "zeros");
    _this.gammaInitializer = getInitializer(args.gammaInitializer || "ones");
    _this.movingMeanInitializer = getInitializer(args.movingMeanInitializer || "zeros");
    _this.movingVarianceInitializer = getInitializer(args.movingVarianceInitializer || "ones");
    _this.betaConstraint = getConstraint(args.betaConstraint);
    _this.gammaConstraint = getConstraint(args.gammaConstraint);
    _this.betaRegularizer = getRegularizer(args.betaRegularizer);
    _this.gammaRegularizer = getRegularizer(args.gammaRegularizer);
    return _this;
  }
  /**
   * Creates the weights. gamma is created only when `scale` is set and beta only
   * when `center` is set; moving_mean and moving_variance are always created
   * (non-trainable). Their shape is the input size along `axis`.
   *
   * @throws {ValueError} If the input size along `axis` is undefined.
   */
  BatchNormalization2.prototype.build = function(inputShape) {
    var _a;
    inputShape = getExactlyOneShape(inputShape);
    var axis = this.axis >= 0 ? this.axis : this.axis + inputShape.length;
    var dim = inputShape[axis];
    if (dim == null) {
      throw new ValueError("Axis " + axis + " of input tensor should have a defined dimension but the layer received an input with shape " + (JSON.stringify(inputShape) + "."));
    }
    this.inputSpec = [new InputSpec({ndim: inputShape.length, axes: (_a = {}, _a[axis] = dim, _a)})];
    var shape = [dim];
    if (this.scale) {
      this.gamma = this.addWeight("gamma", shape, null, this.gammaInitializer, this.gammaRegularizer, true, this.gammaConstraint);
    }
    if (this.center) {
      this.beta = this.addWeight("beta", shape, null, this.betaInitializer, this.betaRegularizer, true, this.betaConstraint);
    }
    this.movingMean = this.addWeight("moving_mean", shape, null, this.movingMeanInitializer, null, false);
    this.movingVariance = this.addWeight("moving_variance", shape, null, this.movingVarianceInitializer, null, false);
    this.built = true;
  };
  /**
   * Normalizes the input. Without kwargs.training, the moving statistics are
   * used. With kwargs.training, batch statistics are used, and the moving
   * statistics are updated as new = old - (old - batch) * (1 - momentum).
   */
  BatchNormalization2.prototype.call = function(inputs, kwargs) {
    var _this = this;
    return tfc.tidy(function() {
      var training = kwargs["training"] == null ? false : kwargs["training"];
      var input2 = getExactlyOneTensor(inputs);
      var inputShape = input2.shape;
      var ndim = inputShape.length;
      var reductionAxes = range(0, ndim);
      var axis = _this.axis >= 0 ? _this.axis : _this.axis + ndim;
      reductionAxes.splice(axis, 1);
      var broadcastShape = pyListRepeat(1, ndim);
      broadcastShape[axis] = inputShape[axis];
      // Numeric sort; the default sort compares elements as strings.
      var sortedReductionAxes = reductionAxes.slice().sort(function(a, b) {
        return a - b;
      });
      var needsBroadcasting = !tfc.util.arraysEqual(sortedReductionAxes, range(0, ndim).slice(0, ndim - 1));
      var normalizeInference = function() {
        if (needsBroadcasting) {
          var broadcastMovingMean = _this.movingMean.read().reshape(broadcastShape);
          var broadcastMovingVariance = _this.movingVariance.read().reshape(broadcastShape);
          var broadcastBeta = _this.center ? _this.beta.read().reshape(broadcastShape) : null;
          var broadcastGamma = _this.scale ? _this.gamma.read().reshape(broadcastShape) : null;
          return batchNormalization(input2, broadcastMovingMean, broadcastMovingVariance, broadcastBeta, broadcastGamma, _this.epsilon);
        }
        return batchNormalization(input2, _this.movingMean.read(), _this.movingVariance.read(), _this.beta == null ? null : _this.beta.read(), _this.gamma == null ? null : _this.gamma.read(), _this.epsilon);
      };
      if (!training) {
        return normalizeInference();
      }
      // gamma/beta are only created when scale/center are enabled (see build),
      // so read them conditionally, as the inference path does.
      var gammaValue = _this.gamma == null ? null : _this.gamma.read();
      var betaValue = _this.beta == null ? null : _this.beta.read();
      var _a = normalizeBatchInTraining(input2, gammaValue, betaValue, reductionAxes, _this.epsilon);
      var normedTraining = _a[0];
      var mean = _a[1];
      var variance = _a[2];
      var doMovingAverage = function(variable, value, momentum) {
        tfc.tidy(function() {
          var decay = 1 - momentum;
          var origValue = variable.read();
          var updateDelta = origValue.sub(value).mul(decay);
          variable.write(origValue.sub(updateDelta));
        });
      };
      doMovingAverage(_this.movingMean, mean, _this.momentum);
      doMovingAverage(_this.movingVariance, variance, _this.momentum);
      return normedTraining;
    });
  };
  BatchNormalization2.prototype.getConfig = function() {
    var config = {
      axis: this.axis,
      momentum: this.momentum,
      epsilon: this.epsilon,
      center: this.center,
      scale: this.scale,
      betaInitializer: serializeInitializer(this.betaInitializer),
      gammaInitializer: serializeInitializer(this.gammaInitializer),
      movingMeanInitializer: serializeInitializer(this.movingMeanInitializer),
      movingVarianceInitializer: serializeInitializer(this.movingVarianceInitializer),
      betaRegularizer: serializeRegularizer(this.betaRegularizer),
      gammaRegularizer: serializeRegularizer(this.gammaRegularizer),
      betaConstraint: serializeConstraint(this.betaConstraint),
      gammaConstraint: serializeConstraint(this.gammaConstraint)
    };
    var baseConfig = _super.prototype.getConfig.call(this);
    Object.assign(config, baseConfig);
    return config;
  };
  BatchNormalization2.className = "BatchNormalization";
  return BatchNormalization2;
}(Layer);
tfc.serialization.registerClass(BatchNormalization);
var LayerNormalization = function(_super) {
  __extends(LayerNormalization2, _super);
  /**
   * Layer normalization over one or more axes.
   *
   * @param {Object} [args] - Layer config. `axis` is an integer or an array
   *   of integers (default -1). Defaults: epsilon 1e-3, center true, scale
   *   true, beta initializer "zeros", gamma initializer "ones".
   * @throws {Error} If `axis` is not an integer or an array of integers.
   */
  function LayerNormalization2(args) {
    if (args == null) {
      args = {};
    }
    var _this = _super.call(this, args) || this;
    _this.axis = args.axis == null ? -1 : args.axis;
    if (typeof _this.axis === "number") {
      if (!Number.isInteger(_this.axis)) {
        throw new Error("Expected axis to be an integer, but received " + _this.axis);
      }
    } else if (Array.isArray(_this.axis)) {
      for (var _i = 0, _a = _this.axis; _i < _a.length; _i++) {
        if (!Number.isInteger(_a[_i])) {
          throw new Error("Expected axis to be an array of integers, " + ("but received " + JSON.stringify(_this.axis)));
        }
      }
    } else {
      throw new Error("Expected axis to be an integer or an array of integers, " + ("but received " + JSON.stringify(_this.axis)));
    }
    _this.epsilon = args.epsilon == null ? 1e-3 : args.epsilon;
    _this.center = args.center == null ? true : args.center;
    _this.scale = args.scale == null ? true : args.scale;
    _this.betaInitializer = getInitializer(args.betaInitializer || "zeros");
    _this.gammaInitializer = getInitializer(args.gammaInitializer || "ones");
    _this.betaRegularizer = getRegularizer(args.betaRegularizer);
    _this.gammaRegularizer = getRegularizer(args.gammaRegularizer);
    _this.supportsMasking = true;
    return _this;
  }
  /**
   * Resolves `this.axis` into an array of non-negative axes, validates them,
   * and creates gamma (if `scale`) and beta (if `center`) with shape equal to
   * the input sizes along those axes. Each weight is set to null when disabled.
   *
   * @throws {Error} On an out-of-range or duplicated axis.
   */
  LayerNormalization2.prototype.build = function(inputShape) {
    inputShape = getExactlyOneShape(inputShape);
    var nDims = inputShape.length;
    // Work on a copy so resolving negative axes does not mutate an array the
    // caller passed in as `args.axis`.
    var axes = typeof this.axis === "number" ? [this.axis] : this.axis.slice();
    for (var i = 0; i < axes.length; ++i) {
      if (axes[i] < 0) {
        axes[i] += nDims;
      }
    }
    this.axis = axes;
    for (var _i = 0; _i < axes.length; _i++) {
      if (axes[_i] < 0 || axes[_i] >= nDims) {
        throw new Error("Invalid axis: " + axes[_i]);
      }
    }
    if (axes.length !== unique(axes).length) {
      throw new Error("Found duplicate axes in: " + axes);
    }
    var paramShape = axes.map(function(axis2) {
      return inputShape[axis2];
    });
    var trainable = true;
    if (this.scale) {
      this.gamma = this.addWeight("gamma", paramShape, "float32", this.gammaInitializer, this.gammaRegularizer, trainable);
    } else {
      this.gamma = null;
    }
    if (this.center) {
      this.beta = this.addWeight("beta", paramShape, "float32", this.betaInitializer, this.betaRegularizer, trainable);
    } else {
      this.beta = null;
    }
    this.built = true;
  };
  /**
   * Normalizes the input with its own moments over `this.axis`, then applies
   * gamma/beta when they exist.
   */
  LayerNormalization2.prototype.call = function(inputs, kwargs) {
    var _this = this;
    var input2 = getExactlyOneTensor(inputs);
    var inputShape = input2.shape;
    var nDims = inputShape.length;
    return tfc.tidy(function() {
      var keepDims = true;
      var stats = tfc.moments(input2, _this.axis, keepDims);
      var mean = stats.mean;
      var variance = stats.variance;
      var broadcastShape = pyListRepeat(1, nDims);
      for (var _i = 0, _b = _this.axis; _i < _b.length; _i++) {
        var dim = _b[_i];
        broadcastShape[dim] = inputShape[dim];
      }
      // Bring a lower-rank parameter up to rank nDims; the tile() calls below
      // use tiling vectors of length nDims.
      var broadcast = function(v) {
        return v.shape.length !== nDims ? v.reshape(broadcastShape) : v;
      };
      // gamma/beta are null when scale/center are disabled (see build()).
      var scale = _this.gamma == null ? null : broadcast(_this.gamma.read());
      var offset = _this.beta == null ? null : broadcast(_this.beta.read());
      var momentsTiling = [];
      var scaleOffsetTiling = [];
      for (var i = 0; i < nDims; ++i) {
        if (_this.axis.indexOf(i) !== -1) {
          momentsTiling.push(inputShape[i]);
          scaleOffsetTiling.push(1);
        } else {
          momentsTiling.push(1);
          scaleOffsetTiling.push(inputShape[i]);
        }
      }
      mean = mean.tile(momentsTiling);
      variance = variance.tile(momentsTiling);
      if (scale != null) {
        scale = scale.tile(scaleOffsetTiling);
      }
      if (offset != null) {
        offset = offset.tile(scaleOffsetTiling);
      }
      return batchNormalization(input2, mean, variance, offset, scale, _this.epsilon);
    });
  };
  LayerNormalization2.prototype.getConfig = function() {
    var config = {
      axis: this.axis,
      epsilon: this.epsilon,
      center: this.center,
      scale: this.scale,
      betaInitializer: serializeInitializer(this.betaInitializer),
      gammaInitializer: serializeInitializer(this.gammaInitializer),
      betaRegularizer: serializeRegularizer(this.betaRegularizer),
      gammaRegularizer: serializeRegularizer(this.gammaRegularizer)
    };
    var baseConfig = _super.prototype.getConfig.call(this);
    Object.assign(config, baseConfig);
    return config;
  };
  LayerNormalization2.className = "LayerNormalization";
  return LayerNormalization2;
}(Layer);
tfc.serialization.registerClass(LayerNormalization);
/**
* @license
* Copyright 2018 Google LLC
*
* Use of this source code is governed by an MIT-style
* license that can be found in the LICENSE file or at
* https://opensource.org/licenses/MIT.
* =============================================================================
*/
/**
 * Zero-pads the two spatial dimensions of a 4-D tensor via tfc.pad.
 *
 * @param {Tensor} x - Rank-4 input.
 * @param {number[][]} [padding] - `[[top, bottom], [left, right]]`;
 *   defaults to `[[1, 1], [1, 1]]`.
 * @param {string} [dataFormat] - "channelsLast" or "channelsFirst";
 *   defaults to imageDataFormat().
 * @returns {Tensor} The padded tensor.
 * @throws {ValueError} On wrong rank, malformed padding or unknown format.
 */
function spatial2dPadding(x, padding, dataFormat) {
  return tfc.tidy(function() {
    if (x.rank !== 4) {
      throw new ValueError("spatial2dPadding expects input tensor to be 4-D, but received a " + (x.rank + "-D tensor."));
    }
    var pad = padding == null ? [[1, 1], [1, 1]] : padding;
    if (pad.length !== 2 || pad[0].length !== 2 || pad[1].length !== 2) {
      throw new ValueError("spatial2dPadding expects `padding` to be an Array of two Arrays, each of which is an Array of two integers.");
    }
    var format = dataFormat == null ? imageDataFormat() : dataFormat;
    if (format !== "channelsLast" && format !== "channelsFirst") {
      throw new ValueError("Unknown data format: " + format + ". Supported data formats are 'channelsLast' and 'channelsFirst'.");
    }
    // Pad only the spatial axes: 2 and 3 for channelsFirst, 1 and 2 for channelsLast.
    var pattern = format === "channelsFirst" ? [[0, 0], [0, 0], pad[0], pad[1]] : [[0, 0], pad[0], pad[1], [0, 0]];
    return tfc.pad(x, pattern);
  });
}
var ZeroPadding2D = function(_super) {
  __extends(ZeroPadding2D2, _super);
  /**
   * Size of one spatial dimension after padding, or null when the input size
   * is unknown (null/undefined) or negative.
   */
  function paddedLength(size, pad) {
    return size != null && size >= 0 ? size + pad[0] + pad[1] : null;
  }
  /**
   * Layer that zero-pads the rows and columns of a 4-D input.
   *
   * @param {Object} [args] - `padding` is a number (same on all sides), a pair
   *   of numbers `[rows, cols]` (symmetric), or `[[top, bottom], [left, right]]`;
   *   default `[[1, 1], [1, 1]]`. `dataFormat` defaults to imageDataFormat().
   * @throws {ValueError} If an array padding or either of its nested arrays
   *   does not have length 2.
   */
  function ZeroPadding2D2(args) {
    if (args == null) {
      args = {};
    }
    var _this = _super.call(this, args) || this;
    _this.dataFormat = args.dataFormat == null ? imageDataFormat() : args.dataFormat;
    // Read into a local; the caller's `args` object is never written to.
    var padding = args.padding;
    if (padding == null) {
      _this.padding = [[1, 1], [1, 1]];
    } else if (typeof padding === "number") {
      _this.padding = [[padding, padding], [padding, padding]];
    } else {
      if (padding.length !== 2) {
        throw new ValueError("ZeroPadding2D expects padding to be a length-2 array, but " + ("received a length-" + padding.length + " array."));
      }
      if (typeof padding[0] === "number") {
        _this.padding = [[padding[0], padding[0]], [padding[1], padding[1]]];
      } else {
        if (padding[0].length !== 2) {
          throw new ValueError("ZeroPadding2D expects height padding to be a length-2 array, " + ("but received a length-" + padding[0].length + " array."));
        }
        if (padding[1].length !== 2) {
          throw new ValueError("ZeroPadding2D expects width padding to be a length-2 array, " + ("but received a length-" + padding[1].length + " array."));
        }
        _this.padding = [padding[0], padding[1]];
      }
    }
    _this.inputSpec = [new InputSpec({ndim: 4})];
    return _this;
  }
  /** Input shape with the row/column entries grown by the configured padding. */
  ZeroPadding2D2.prototype.computeOutputShape = function(inputShape) {
    inputShape = getExactlyOneShape(inputShape);
    if (this.dataFormat === "channelsFirst") {
      return [inputShape[0], inputShape[1], paddedLength(inputShape[2], this.padding[0]), paddedLength(inputShape[3], this.padding[1])];
    }
    return [inputShape[0], paddedLength(inputShape[1], this.padding[0]), paddedLength(inputShape[2], this.padding[1]), inputShape[3]];
  };
  ZeroPadding2D2.prototype.call = function(inputs, kwargs) {
    var _this = this;
    return tfc.tidy(function() {
      return spatial2dPadding(getExactlyOneTensor(inputs), _this.padding, _this.dataFormat);
    });
  };
  ZeroPadding2D2.prototype.getConfig = function() {
    var config = {
      padding: this.padding,
      dataFormat: this.dataFormat
    };
    var baseConfig = _super.prototype.getConfig.call(this);
    Object.assign(config, baseConfig);
    return config;
  };
  ZeroPadding2D2.className = "ZeroPadding2D";
  return ZeroPadding2D2;
}(Layer);
tfc.serialization.registerClass(ZeroPadding2D);
/**
* @license
* Copyright 2018 Google LLC
*
* Use of this source code is governed by an MIT-style
* license that can be found in the LICENSE file or at
* https://opensource.org/licenses/MIT.
* =============================================================================
*/
/**
 * 2-D max/average pooling helper (used by the 1-D and 2-D pooling layers).
 *
 * The raw arguments are validated first; defaults are applied afterwards:
 * strides [1, 1], dataFormat imageDataFormat(), poolMode "max". Any padding
 * other than "same" (including none) pools as "valid". The input goes through
 * preprocessConv2DInput. For channelsFirst, the result is transposed with
 * [0, 3, 1, 2].
 */
function pool2d(x, poolSize, strides, padding, dataFormat, poolMode) {
  return tfc.tidy(function() {
    checkDataFormat(dataFormat);
    checkPoolMode(poolMode);
    checkPaddingMode(padding);
    var effectiveStrides = strides == null ? [1, 1] : strides;
    var effectiveFormat = dataFormat == null ? imageDataFormat() : dataFormat;
    var effectiveMode = poolMode == null ? "max" : poolMode;
    var paddingString = padding === "same" ? "same" : "valid";
    var prepared = preprocessConv2DInput(x, effectiveFormat);
    var pooled = effectiveMode === "max" ? tfc.maxPool(prepared, poolSize, effectiveStrides, paddingString) : tfc.avgPool(prepared, poolSize, effectiveStrides, paddingString);
    return effectiveFormat === "channelsFirst" ? tfc.transpose(pooled, [0, 3, 1, 2]) : pooled;
  });
}
/**
 * 3-D max/average pooling helper (used by the 3-D pooling layers).
 *
 * The raw arguments are validated first; defaults are applied afterwards:
 * strides [1, 1, 1], dataFormat imageDataFormat(), poolMode "max". Any padding
 * other than "same" (including none) pools as "valid". The input goes through
 * preprocessConv3DInput. For channelsFirst, the result is transposed with
 * [0, 4, 1, 2, 3].
 */
function pool3d(x, poolSize, strides, padding, dataFormat, poolMode) {
  return tfc.tidy(function() {
    checkDataFormat(dataFormat);
    checkPoolMode(poolMode);
    checkPaddingMode(padding);
    var effectiveStrides = strides == null ? [1, 1, 1] : strides;
    var effectiveFormat = dataFormat == null ? imageDataFormat() : dataFormat;
    var effectiveMode = poolMode == null ? "max" : poolMode;
    var paddingString = padding === "same" ? "same" : "valid";
    var prepared = preprocessConv3DInput(x, effectiveFormat);
    var pooled = effectiveMode === "max" ? tfc.maxPool3d(prepared, poolSize, effectiveStrides, paddingString) : tfc.avgPool3d(prepared, poolSize, effectiveStrides, paddingString);
    return effectiveFormat === "channelsFirst" ? tfc.transpose(pooled, [0, 4, 1, 2, 3]) : pooled;
  });
}
var Pooling1D = function(_super) {
  __extends(Pooling1D2, _super);
  /**
   * Abstract base for 1-D pooling layers. Subclasses provide `poolingFunction`,
   * which this class calls on a 4-D view of the input.
   *
   * @param {Object} args - `poolSize` (number or single-element array,
   *   default 2), `strides` (same forms, default = poolSize), `padding`
   *   (default "valid").
   * @throws {ValueError} If `poolSize` or `strides` has an unsupported form.
   */
  function Pooling1D2(args) {
    var _this = _super.call(this, args) || this;
    // Apply the default locally rather than writing it into the caller's `args`.
    var poolSize = args.poolSize == null ? 2 : args.poolSize;
    if (typeof poolSize === "number") {
      _this.poolSize = [poolSize];
    } else if (Array.isArray(poolSize) && poolSize.length === 1 && typeof poolSize[0] === "number") {
      _this.poolSize = poolSize;
    } else {
      throw new ValueError("poolSize for 1D convolutional layer must be a number or an Array of a single number, but received " + ("" + JSON.stringify(poolSize)));
    }
    assertPositiveInteger(_this.poolSize, "poolSize");
    if (args.strides == null) {
      _this.strides = _this.poolSize;
    } else if (typeof args.strides === "number") {
      _this.strides = [args.strides];
    } else if (Array.isArray(args.strides) && args.strides.length === 1 && typeof args.strides[0] === "number") {
      _this.strides = args.strides;
    } else {
      throw new ValueError("strides for 1D convolutional layer must be a number or an Array of a single number, but received " + ("" + JSON.stringify(args.strides)));
    }
    assertPositiveInteger(_this.strides, "strides");
    _this.padding = args.padding == null ? "valid" : args.padding;
    checkPaddingMode(_this.padding);
    _this.inputSpec = [new InputSpec({ndim: 3})];
    return _this;
  }
  /** `[batch, pooledLength, channels]`, pooled length via convOutputLength. */
  Pooling1D2.prototype.computeOutputShape = function(inputShape) {
    inputShape = getExactlyOneShape(inputShape);
    var length = convOutputLength(inputShape[1], this.poolSize[0], this.padding, this.strides[0]);
    return [inputShape[0], length, inputShape[2]];
  };
  /**
   * Inserts a size-1 axis at index 2 with expandDims and pools with a
   * `[poolSize, 1]` window in channelsLast. The axis is then removed with
   * squeeze.
   */
  Pooling1D2.prototype.call = function(inputs, kwargs) {
    var _this = this;
    return tfc.tidy(function() {
      _this.invokeCallHook(inputs, kwargs);
      var expanded = expandDims(getExactlyOneTensor(inputs), 2);
      var output = _this.poolingFunction(expanded, [_this.poolSize[0], 1], [_this.strides[0], 1], _this.padding, "channelsLast");
      return tfc.squeeze(output, [2]);
    });
  };
  Pooling1D2.prototype.getConfig = function() {
    var config = {
      poolSize: this.poolSize,
      padding: this.padding,
      strides: this.strides
    };
    var baseConfig = _super.prototype.getConfig.call(this);
    Object.assign(config, baseConfig);
    return config;
  };
  return Pooling1D2;
}(Layer);
var MaxPooling1D = function(_super) {
  __extends(MaxPooling1D2, _super);
  /** 1-D max pooling; argument handling lives in Pooling1D. */
  function MaxPooling1D2(args) {
    var self = _super.call(this, args);
    return self || this;
  }
  MaxPooling1D2.className = "MaxPooling1D";
  /** Validates `dataFormat` and `padding`, then runs pool2d in "max" mode. */
  MaxPooling1D2.prototype.poolingFunction = function(inputs, poolSize, strides, padding, dataFormat) {
    checkDataFormat(dataFormat);
    checkPaddingMode(padding);
    var mode = "max";
    return pool2d(inputs, poolSize, strides, padding, dataFormat, mode);
  };
  return MaxPooling1D2;
}(Pooling1D);
tfc.serialization.registerClass(MaxPooling1D);
var AveragePooling1D = function(_super) {
  __extends(AveragePooling1D2, _super);
  /** 1-D average pooling; argument handling lives in Pooling1D. */
  function AveragePooling1D2(args) {
    var self = _super.call(this, args);
    return self || this;
  }
  AveragePooling1D2.className = "AveragePooling1D";
  /** Validates `dataFormat` and `padding`, then runs pool2d in "avg" mode. */
  AveragePooling1D2.prototype.poolingFunction = function(inputs, poolSize, strides, padding, dataFormat) {
    checkDataFormat(dataFormat);
    checkPaddingMode(padding);
    var mode = "avg";
    return pool2d(inputs, poolSize, strides, padding, dataFormat, mode);
  };
  return AveragePooling1D2;
}(Pooling1D);
tfc.serialization.registerClass(AveragePooling1D);
var Pooling2D = function(_super) {
  __extends(Pooling2D2, _super);
  /**
   * Abstract base for 2-D pooling layers; subclasses provide `poolingFunction`.
   *
   * @param {Object} args - `poolSize` (number or length-2 array, default
   *   [2, 2]), `strides` (number or length-2 array, default = poolSize),
   *   `padding` (default "valid"), `dataFormat` (default "channelsLast").
   * @throws {ValueError} If `poolSize` or `strides` is an array whose length is not 2.
   */
  function Pooling2D2(args) {
    var _this = _super.call(this, args) || this;
    // Apply the default locally rather than writing it into the caller's `args`.
    var poolSize = args.poolSize == null ? [2, 2] : args.poolSize;
    if (Array.isArray(poolSize)) {
      // Same length check that is applied to `strides` below.
      if (poolSize.length !== 2) {
        throw new ValueError("If the poolSize property of a 2D pooling layer is an Array, it is expected to have a length of 2, but received length " + (poolSize.length + "."));
      }
      _this.poolSize = poolSize;
    } else {
      _this.poolSize = [poolSize, poolSize];
    }
    if (args.strides == null) {
      _this.strides = _this.poolSize;
    } else if (Array.isArray(args.strides)) {
      if (args.strides.length !== 2) {
        throw new ValueError("If the strides property of a 2D pooling layer is an Array, it is expected to have a length of 2, but received length " + (args.strides.length + "."));
      }
      _this.strides = args.strides;
    } else {
      _this.strides = [args.strides, args.strides];
    }
    assertPositiveInteger(_this.poolSize, "poolSize");
    assertPositiveInteger(_this.strides, "strides");
    _this.padding = args.padding == null ? "valid" : args.padding;
    _this.dataFormat = args.dataFormat == null ? "channelsLast" : args.dataFormat;
    checkDataFormat(_this.dataFormat);
    checkPaddingMode(_this.padding);
    _this.inputSpec = [new InputSpec({ndim: 4})];
    return _this;
  }
  /** Input shape with rows/cols replaced by their convOutputLength values. */
  Pooling2D2.prototype.computeOutputShape = function(inputShape) {
    inputShape = getExactlyOneShape(inputShape);
    var channelsFirst = this.dataFormat === "channelsFirst";
    var rows = channelsFirst ? inputShape[2] : inputShape[1];
    var cols = channelsFirst ? inputShape[3] : inputShape[2];
    rows = convOutputLength(rows, this.poolSize[0], this.padding, this.strides[0]);
    cols = convOutputLength(cols, this.poolSize[1], this.padding, this.strides[1]);
    if (channelsFirst) {
      return [inputShape[0], inputShape[1], rows, cols];
    }
    return [inputShape[0], rows, cols, inputShape[3]];
  };
  Pooling2D2.prototype.call = function(inputs, kwargs) {
    var _this = this;
    return tfc.tidy(function() {
      _this.invokeCallHook(inputs, kwargs);
      return _this.poolingFunction(getExactlyOneTensor(inputs), _this.poolSize, _this.strides, _this.padding, _this.dataFormat);
    });
  };
  Pooling2D2.prototype.getConfig = function() {
    var config = {
      poolSize: this.poolSize,
      padding: this.padding,
      strides: this.strides,
      dataFormat: this.dataFormat
    };
    var baseConfig = _super.prototype.getConfig.call(this);
    Object.assign(config, baseConfig);
    return config;
  };
  return Pooling2D2;
}(Layer);
var MaxPooling2D = function(_super) {
  __extends(MaxPooling2D2, _super);
  /** 2-D max pooling; argument handling lives in Pooling2D. */
  function MaxPooling2D2(args) {
    var self = _super.call(this, args);
    return self || this;
  }
  MaxPooling2D2.className = "MaxPooling2D";
  /** Validates `dataFormat` and `padding`, then runs pool2d in "max" mode. */
  MaxPooling2D2.prototype.poolingFunction = function(inputs, poolSize, strides, padding, dataFormat) {
    checkDataFormat(dataFormat);
    checkPaddingMode(padding);
    var mode = "max";
    return pool2d(inputs, poolSize, strides, padding, dataFormat, mode);
  };
  return MaxPooling2D2;
}(Pooling2D);
tfc.serialization.registerClass(MaxPooling2D);
var AveragePooling2D = function(_super) {
  __extends(AveragePooling2D2, _super);
  /** 2-D average pooling; argument handling lives in Pooling2D. */
  function AveragePooling2D2(args) {
    var self = _super.call(this, args);
    return self || this;
  }
  AveragePooling2D2.className = "AveragePooling2D";
  /** Validates `dataFormat` and `padding`, then runs pool2d in "avg" mode. */
  AveragePooling2D2.prototype.poolingFunction = function(inputs, poolSize, strides, padding, dataFormat) {
    checkDataFormat(dataFormat);
    checkPaddingMode(padding);
    var mode = "avg";
    return pool2d(inputs, poolSize, strides, padding, dataFormat, mode);
  };
  return AveragePooling2D2;
}(Pooling2D);
tfc.serialization.registerClass(AveragePooling2D);
var Pooling3D = function(_super) {
  __extends(Pooling3D2, _super);
  /**
   * Abstract base for 3-D pooling layers; subclasses provide `poolingFunction`.
   *
   * @param {Object} args - `poolSize` (number or length-3 array, default
   *   [2, 2, 2]), `strides` (number or length-3 array, default = poolSize),
   *   `padding` (default "valid"), `dataFormat` (default "channelsLast").
   * @throws {ValueError} If `poolSize` or `strides` is an array whose length is not 3.
   */
  function Pooling3D2(args) {
    var _this = _super.call(this, args) || this;
    // Apply the default locally rather than writing it into the caller's `args`.
    var poolSize = args.poolSize == null ? [2, 2, 2] : args.poolSize;
    if (Array.isArray(poolSize)) {
      // Same length check that is applied to `strides` below.
      if (poolSize.length !== 3) {
        throw new ValueError("If the poolSize property of a 3D pooling layer is an Array, it is expected to have a length of 3, but received length " + (poolSize.length + "."));
      }
      _this.poolSize = poolSize;
    } else {
      _this.poolSize = [poolSize, poolSize, poolSize];
    }
    if (args.strides == null) {
      _this.strides = _this.poolSize;
    } else if (Array.isArray(args.strides)) {
      if (args.strides.length !== 3) {
        throw new ValueError("If the strides property of a 3D pooling layer is an Array, it is expected to have a length of 3, but received length " + (args.strides.length + "."));
      }
      _this.strides = args.strides;
    } else {
      _this.strides = [args.strides, args.strides, args.strides];
    }
    assertPositiveInteger(_this.poolSize, "poolSize");
    assertPositiveInteger(_this.strides, "strides");
    _this.padding = args.padding == null ? "valid" : args.padding;
    _this.dataFormat = args.dataFormat == null ? "channelsLast" : args.dataFormat;
    checkDataFormat(_this.dataFormat);
    checkPaddingMode(_this.padding);
    _this.inputSpec = [new InputSpec({ndim: 5})];
    return _this;
  }
  /** Input shape with depth/rows/cols replaced by their convOutputLength values. */
  Pooling3D2.prototype.computeOutputShape = function(inputShape) {
    inputShape = getExactlyOneShape(inputShape);
    var channelsFirst = this.dataFormat === "channelsFirst";
    var depths = channelsFirst ? inputShape[2] : inputShape[1];
    var rows = channelsFirst ? inputShape[3] : inputShape[2];
    var cols = channelsFirst ? inputShape[4] : inputShape[3];
    depths = convOutputLength(depths, this.poolSize[0], this.padding, this.strides[0]);
    rows = convOutputLength(rows, this.poolSize[1], this.padding, this.strides[1]);
    cols = convOutputLength(cols, this.poolSize[2], this.padding, this.strides[2]);
    if (channelsFirst) {
      return [inputShape[0], inputShape[1], depths, rows, cols];
    }
    return [inputShape[0], depths, rows, cols, inputShape[4]];
  };
  Pooling3D2.prototype.call = function(inputs, kwargs) {
    var _this = this;
    return tfc.tidy(function() {
      _this.invokeCallHook(inputs, kwargs);
      return _this.poolingFunction(getExactlyOneTensor(inputs), _this.poolSize, _this.strides, _this.padding, _this.dataFormat);
    });
  };
  Pooling3D2.prototype.getConfig = function() {
    var config = {
      poolSize: this.poolSize,
      padding: this.padding,
      strides: this.strides,
      dataFormat: this.dataFormat
    };
    var baseConfig = _super.prototype.getConfig.call(this);
    Object.assign(config, baseConfig);
    return config;
  };
  return Pooling3D2;
}(Layer);
var MaxPooling3D = function(_super) {
  __extends(MaxPooling3D2, _super);
  /**
   * 3D max-pooling layer. Argument handling is inherited from Pooling3D; this
   * subclass only selects the "max" pooling mode.
   */
  function MaxPooling3D2(args) {
    var created = _super.call(this, args);
    return created || this;
  }
  MaxPooling3D2.className = "MaxPooling3D";
  /**
   * Validates `dataFormat` and `padding`, then delegates to pool3d in "max" mode.
   */
  MaxPooling3D2.prototype.poolingFunction = function(inputs, poolSize, strides, padding, dataFormat) {
    checkDataFormat(dataFormat);
    checkPaddingMode(padding);
    var poolMode = "max";
    return pool3d(inputs, poolSize, strides, padding, dataFormat, poolMode);
  };
  return MaxPooling3D2;
}(Pooling3D);
tfc.serialization.registerClass(MaxPooling3D);
var AveragePooling3D = function(_super) {
  __extends(AveragePooling3D2, _super);
  /**
   * 3D average-pooling layer. Argument handling is inherited from Pooling3D;
   * this subclass only selects the "avg" pooling mode.
   */
  function AveragePooling3D2(args) {
    var created = _super.call(this, args);
    return created || this;
  }
  AveragePooling3D2.className = "AveragePooling3D";
  /**
   * Validates `dataFormat` and `padding`, then delegates to pool3d in "avg" mode.
   */
  AveragePooling3D2.prototype.poolingFunction = function(inputs, poolSize, strides, padding, dataFormat) {
    checkDataFormat(dataFormat);
    checkPaddingMode(padding);
    var poolMode = "avg";
    return pool3d(inputs, poolSize, strides, padding, dataFormat, poolMode);
  };
  return AveragePooling3D2;
}(Pooling3D);
tfc.serialization.registerClass(AveragePooling3D);
var GlobalPooling1D = function(_super) {
  __extends(GlobalPooling1D2, _super);
  /**
   * Abstract base for global 1D pooling layers. Inputs must be rank 3; the
   * output keeps axes 0 and 2 of the input shape.
   */
  function GlobalPooling1D2(args) {
    var layer = _super.call(this, args) || this;
    layer.inputSpec = [new InputSpec({ndim: 3})];
    return layer;
  }
  GlobalPooling1D2.prototype.computeOutputShape = function(inputShape) {
    var batch = inputShape[0];
    var features = inputShape[2];
    return [batch, features];
  };
  // Subclasses supply the actual reduction.
  GlobalPooling1D2.prototype.call = function(inputs, kwargs) {
    throw new NotImplementedError();
  };
  return GlobalPooling1D2;
}(Layer);
var GlobalAveragePooling1D = function(_super) {
  __extends(GlobalAveragePooling1D2, _super);
  /** Global average pooling over axis 1 of a rank-3 input; `args` is optional. */
  function GlobalAveragePooling1D2(args) {
    var layerArgs = args || {};
    return _super.call(this, layerArgs) || this;
  }
  GlobalAveragePooling1D2.className = "GlobalAveragePooling1D";
  GlobalAveragePooling1D2.prototype.call = function(inputs, kwargs) {
    return tfc.tidy(function() {
      return tfc.mean(getExactlyOneTensor(inputs), 1);
    });
  };
  return GlobalAveragePooling1D2;
}(GlobalPooling1D);
tfc.serialization.registerClass(GlobalAveragePooling1D);
var GlobalMaxPooling1D = function(_super) {
  __extends(GlobalMaxPooling1D2, _super);
  /** Global max pooling over axis 1 of a rank-3 input; `args` is optional. */
  function GlobalMaxPooling1D2(args) {
    var layerArgs = args || {};
    return _super.call(this, layerArgs) || this;
  }
  GlobalMaxPooling1D2.className = "GlobalMaxPooling1D";
  GlobalMaxPooling1D2.prototype.call = function(inputs, kwargs) {
    return tfc.tidy(function() {
      return tfc.max(getExactlyOneTensor(inputs), 1);
    });
  };
  return GlobalMaxPooling1D2;
}(GlobalPooling1D);
tfc.serialization.registerClass(GlobalMaxPooling1D);
var GlobalPooling2D = function(_super) {
  __extends(GlobalPooling2D2, _super);
  /**
   * Abstract base for global 2D pooling layers over rank-4 inputs.
   *
   * @param args Optional layer args; `dataFormat` defaults to "channelsLast".
   *   A missing args object is treated as `{}`, matching the GlobalPooling1D
   *   subclasses, so factory calls with no arguments do not throw a TypeError.
   * @throws via checkDataFormat if `dataFormat` is invalid.
   */
  function GlobalPooling2D2(args) {
    if (args == null) {
      args = {};
    }
    var _this = _super.call(this, args) || this;
    _this.dataFormat = args.dataFormat == null ? "channelsLast" : args.dataFormat;
    checkDataFormat(_this.dataFormat);
    _this.inputSpec = [new InputSpec({ndim: 4})];
    return _this;
  }
  /**
   * Returns [batch, channels]. Channels are read from axis 3 for
   * "channelsLast" and from axis 1 otherwise.
   */
  GlobalPooling2D2.prototype.computeOutputShape = function(inputShape) {
    var channelAxis = this.dataFormat === "channelsLast" ? 3 : 1;
    return [inputShape[0], inputShape[channelAxis]];
  };
  // Subclasses supply the actual reduction.
  GlobalPooling2D2.prototype.call = function(inputs, kwargs) {
    throw new NotImplementedError();
  };
  /** Serializes `dataFormat` merged with the base Layer config. */
  GlobalPooling2D2.prototype.getConfig = function() {
    var config = {dataFormat: this.dataFormat};
    var baseConfig = _super.prototype.getConfig.call(this);
    Object.assign(config, baseConfig);
    return config;
  };
  return GlobalPooling2D2;
}(Layer);
var GlobalAveragePooling2D = function(_super) {
  __extends(GlobalAveragePooling2D2, _super);
  /** Global average pooling over the two spatial axes of a rank-4 input. */
  function GlobalAveragePooling2D2() {
    var created = _super !== null ? _super.apply(this, arguments) : false;
    return created || this;
  }
  GlobalAveragePooling2D2.className = "GlobalAveragePooling2D";
  GlobalAveragePooling2D2.prototype.call = function(inputs, kwargs) {
    // Spatial axes are [1, 2] for channelsLast and [2, 3] for channelsFirst.
    var spatialAxes = this.dataFormat === "channelsLast" ? [1, 2] : [2, 3];
    return tfc.tidy(function() {
      return tfc.mean(getExactlyOneTensor(inputs), spatialAxes);
    });
  };
  return GlobalAveragePooling2D2;
}(GlobalPooling2D);
tfc.serialization.registerClass(GlobalAveragePooling2D);
var GlobalMaxPooling2D = function(_super) {
  __extends(GlobalMaxPooling2D2, _super);
  /** Global max pooling over the two spatial axes of a rank-4 input. */
  function GlobalMaxPooling2D2() {
    var created = _super !== null ? _super.apply(this, arguments) : false;
    return created || this;
  }
  GlobalMaxPooling2D2.className = "GlobalMaxPooling2D";
  GlobalMaxPooling2D2.prototype.call = function(inputs, kwargs) {
    // Spatial axes are [1, 2] for channelsLast and [2, 3] for channelsFirst.
    var spatialAxes = this.dataFormat === "channelsLast" ? [1, 2] : [2, 3];
    return tfc.tidy(function() {
      return tfc.max(getExactlyOneTensor(inputs), spatialAxes);
    });
  };
  return GlobalMaxPooling2D2;
}(GlobalPooling2D);
tfc.serialization.registerClass(GlobalMaxPooling2D);
/**
* @license
* Copyright 2018 Google LLC
*
* Use of this source code is governed by an MIT-style
* license that can be found in the LICENSE file or at
* https://opensource.org/licenses/MIT.
* =============================================================================
*/
var Wrapper = function(_super) {
  __extends(Wrapper2, _super);
  /**
   * Abstract base for layers that wrap another layer (e.g. TimeDistributed,
   * Bidirectional). Trainability, weights, updates and losses are delegated to
   * the wrapped layer.
   *
   * @param args Layer args; `args.layer` is the layer being wrapped.
   */
  function Wrapper2(args) {
    var _this = _super.call(this, args) || this;
    _this.layer = args.layer;
    return _this;
  }
  /** Marks only the wrapper itself as built. */
  Wrapper2.prototype.build = function(inputShape) {
    this.built = true;
  };
  // Proxies the wrapped layer's flag; reads as false when no layer is set.
  Object.defineProperty(Wrapper2.prototype, "trainable", {
    get: function() {
      if (this.layer != null) {
        return this.layer.trainable;
      } else {
        return false;
      }
    },
    set: function(value) {
      if (this.layer != null) {
        this.layer.trainable = value;
      }
    },
    enumerable: true,
    configurable: true
  });
  Object.defineProperty(Wrapper2.prototype, "trainableWeights", {
    get: function() {
      return this.layer.trainableWeights;
    },
    enumerable: true,
    configurable: true
  });
  Object.defineProperty(Wrapper2.prototype, "nonTrainableWeights", {
    get: function() {
      return this.layer.nonTrainableWeights;
    },
    enumerable: true,
    configurable: true
  });
  Object.defineProperty(Wrapper2.prototype, "updates", {
    get: function() {
      return this.layer._updates;
    },
    enumerable: true,
    configurable: true
  });
  Object.defineProperty(Wrapper2.prototype, "losses", {
    get: function() {
      return this.layer.losses;
    },
    enumerable: true,
    configurable: true
  });
  Wrapper2.prototype.getWeights = function() {
    return this.layer.getWeights();
  };
  Wrapper2.prototype.setWeights = function(weights) {
    this.layer.setWeights(weights);
  };
  /**
   * Serializes the wrapper as a nested `layer` entry (class name + config of
   * the wrapped layer) merged with the base Layer config.
   */
  Wrapper2.prototype.getConfig = function() {
    var config = {
      layer: {
        className: this.layer.getClassName(),
        config: this.layer.getConfig()
      }
    };
    var baseConfig = _super.prototype.getConfig.call(this);
    Object.assign(config, baseConfig);
    return config;
  };
  /** Applies the flag to the wrapper and, if present, to the wrapped layer. */
  Wrapper2.prototype.setFastWeightInitDuringBuild = function(value) {
    _super.prototype.setFastWeightInitDuringBuild.call(this, value);
    if (this.layer != null) {
      this.layer.setFastWeightInitDuringBuild(value);
    }
  };
  /**
   * Rebuilds a wrapper from a serialized config.
   *
   * @param cls Constructor to instantiate.
   * @param config Serialized config; its `layer` entry is deserialized with
   *   `customObjects`. The caller's object is copied, not modified, so the
   *   same config can be deserialized more than once.
   * @param customObjects Optional custom class map (defaults to `{}`).
   */
  Wrapper2.fromConfig = function(cls, config, customObjects) {
    if (customObjects === void 0) {
      customObjects = {};
    }
    var layer = deserialize(config["layer"], customObjects);
    var newConfig = Object.assign({}, config);
    newConfig["layer"] = layer;
    return new cls(newConfig);
  };
  return Wrapper2;
}(Layer);
var TimeDistributed = function(_super) {
  __extends(TimeDistributed2, _super);
  /**
   * Applies the wrapped layer independently to every step along axis 1 of a
   * rank >= 3 input. Masking is supported.
   */
  function TimeDistributed2(args) {
    var wrapper = _super.call(this, args) || this;
    wrapper.supportsMasking = true;
    return wrapper;
  }
  // Shape seen by the wrapped layer: the batch dim followed by everything after axis 1.
  function dropTimeAxis(shape) {
    return [shape[0]].concat(shape.slice(2));
  }
  TimeDistributed2.className = "TimeDistributed";
  /**
   * Validates the input rank and builds the wrapped layer (once) with the
   * time axis removed.
   * @throws ValueError when the input has fewer than 3 dimensions.
   */
  TimeDistributed2.prototype.build = function(inputShape) {
    var shape = getExactlyOneShape(inputShape);
    if (shape.length < 3) {
      throw new ValueError("TimeDistributed layer expects an input shape >= 3D, but received " + ("input shape " + JSON.stringify(shape)));
    }
    this.inputSpec = [{shape: shape}];
    if (!this.layer.built) {
      this.layer.build(dropTimeAxis(shape));
      this.layer.built = true;
    }
    _super.prototype.build.call(this, shape);
  };
  /** Re-inserts the input's axis-1 size into the wrapped layer's output shape. */
  TimeDistributed2.prototype.computeOutputShape = function(inputShape) {
    var shape = getExactlyOneShape(inputShape);
    var innerShape = this.layer.computeOutputShape(dropTimeAxis(shape));
    return [innerShape[0], shape[1]].concat(innerShape.slice(1));
  };
  /** Uses rnn() with a stateless step that calls the wrapped layer on each slice. */
  TimeDistributed2.prototype.call = function(inputs, kwargs) {
    var wrapper = this;
    return tfc.tidy(function() {
      var tensor = getExactlyOneTensor(inputs);
      function step(stepInputs, states) {
        var stepOutput = getExactlyOneTensor(wrapper.layer.call(stepInputs, kwargs));
        return [stepOutput, []];
      }
      return rnn(step, tensor, [], false, null, null, false, true)[1];
    });
  };
  return TimeDistributed2;
}(Wrapper);
tfc.serialization.registerClass(TimeDistributed);
/** Validates `value` against VALID_BIDIRECTIONAL_MERGE_MODES (delegated to checkStringTypeUnionValue). */
function checkBidirectionalMergeMode(value) {
  var typeLabel = "BidirectionalMergeMode";
  checkStringTypeUnionValue(VALID_BIDIRECTIONAL_MERGE_MODES, typeLabel, value);
}
// Merge mode used when a Bidirectional layer is created without `mergeMode`.
var DEFAULT_BIDIRECTIONAL_MERGE_MODE = "concat";
var Bidirectional = function(_super) {
  __extends(Bidirectional2, _super);
  /**
   * Bidirectional wrapper for RNN layers.
   *
   * Two copies of `args.layer` are deserialized from its config. The backward
   * copy has `goBackwards` flipped. The directions' outputs are combined
   * according to `mergeMode` ("concat" by default; null returns both).
   *
   * @param args Layer args: `layer` (the RNN to wrap) and optional `mergeMode`.
   * @throws NotImplementedError if `args.weights` is provided.
   */
  function Bidirectional2(args) {
    var _this = _super.call(this, args) || this;
    var layerConfig = args.layer.getConfig();
    var forwDict = {};
    forwDict["className"] = args.layer.getClassName();
    forwDict["config"] = layerConfig;
    _this.forwardLayer = deserialize(forwDict);
    // The backward copy reuses the same config with the direction flipped.
    layerConfig["goBackwards"] = layerConfig["goBackwards"] === true ? false : true;
    var backDict = {};
    backDict["className"] = args.layer.getClassName();
    backDict["config"] = layerConfig;
    _this.backwardLayer = deserialize(backDict);
    _this.forwardLayer.name = "forward_" + _this.forwardLayer.name;
    _this.backwardLayer.name = "backward_" + _this.backwardLayer.name;
    _this.mergeMode = args.mergeMode === void 0 ? DEFAULT_BIDIRECTIONAL_MERGE_MODE : args.mergeMode;
    checkBidirectionalMergeMode(_this.mergeMode);
    if (args.weights) {
      throw new NotImplementedError("weights support is not implemented for Bidirectional layer yet.");
    }
    _this._stateful = args.layer.stateful;
    _this.returnSequences = args.layer.returnSequences;
    _this.returnState = args.layer.returnState;
    _this.supportsMasking = true;
    _this._trainable = true;
    _this.inputSpec = args.layer.inputSpec;
    _this.numConstants = null;
    return _this;
  }
  // Setting `trainable` propagates the flag to both direction layers.
  Object.defineProperty(Bidirectional2.prototype, "trainable", {
    get: function() {
      return this._trainable;
    },
    set: function(value) {
      this._trainable = value;
      if (this.forwardLayer != null) {
        this.forwardLayer.trainable = value;
      }
      if (this.backwardLayer != null) {
        this.backwardLayer.trainable = value;
      }
    },
    enumerable: true,
    configurable: true
  });
  /** Forward-layer weights followed by backward-layer weights. */
  Bidirectional2.prototype.getWeights = function() {
    return this.forwardLayer.getWeights().concat(this.backwardLayer.getWeights());
  };
  /** First floor(n / 2) weights go to the forward layer, the rest to the backward layer. */
  Bidirectional2.prototype.setWeights = function(weights) {
    var half = Math.floor(weights.length / 2);
    this.forwardLayer.setWeights(weights.slice(0, half));
    this.backwardLayer.setWeights(weights.slice(half));
  };
  /**
   * Derives the output shape(s) from the forward layer's output shape.
   * "concat" doubles the last dimension; null mergeMode yields two output
   * shapes; with returnState the state shapes of both directions follow.
   */
  Bidirectional2.prototype.computeOutputShape = function(inputShape) {
    var layerShapes = this.forwardLayer.computeOutputShape(inputShape);
    if (!(Array.isArray(layerShapes) && Array.isArray(layerShapes[0]))) {
      layerShapes = [layerShapes];
    }
    var outputShape = layerShapes[0];
    var stateShape;
    if (this.returnState) {
      stateShape = layerShapes.slice(1);
    }
    var outputShapes;
    if (this.mergeMode === "concat") {
      // Copy before doubling so the array returned by the forward layer is not modified.
      outputShape = outputShape.slice();
      outputShape[outputShape.length - 1] *= 2;
      outputShapes = [outputShape];
    } else if (this.mergeMode == null) {
      outputShapes = [outputShape, outputShape.slice()];
    } else {
      outputShapes = [outputShape];
    }
    if (this.returnState) {
      if (this.mergeMode == null) {
        return outputShapes.concat(stateShape).concat(stateShape.slice());
      }
      return [outputShape].concat(stateShape).concat(stateShape.slice());
    }
    return singletonOrArray(outputShapes);
  };
  /**
   * Handles an optional `initialState`: it is split evenly between the two
   * directions, and symbolic states are appended to the layer inputs.
   * @throws ValueError for an odd number of states or mixed symbolic/concrete states.
   * @throws NotImplementedError when `constants` are given.
   */
  Bidirectional2.prototype.apply = function(inputs, kwargs) {
    var initialState = kwargs == null ? null : kwargs["initialState"];
    var constants = kwargs == null ? null : kwargs["constants"];
    // Shallow copy: "initialState" is written below and the caller's object
    // should stay untouched.
    kwargs = kwargs == null ? {} : Object.assign({}, kwargs);
    var standardized = standardizeArgs(inputs, initialState, constants, this.numConstants);
    inputs = standardized.inputs;
    initialState = standardized.initialState;
    constants = standardized.constants;
    if (Array.isArray(inputs)) {
      initialState = inputs.slice(1);
      inputs = inputs[0];
    }
    if ((initialState == null || initialState.length === 0) && constants == null) {
      return _super.prototype.apply.call(this, inputs, kwargs);
    }
    var additionalInputs = [];
    var additionalSpecs = [];
    if (initialState != null) {
      var numStates = initialState.length;
      if (numStates % 2 > 0) {
        throw new ValueError("When passing `initialState` to a Bidirectional RNN, the state should be an Array containing the states of the underlying RNNs.");
      }
      kwargs["initialState"] = initialState;
      additionalInputs.push.apply(additionalInputs, initialState);
      var stateSpecs = initialState.map(function(state) {
        return new InputSpec({shape: state.shape});
      });
      this.forwardLayer.stateSpec = stateSpecs.slice(0, numStates / 2);
      this.backwardLayer.stateSpec = stateSpecs.slice(numStates / 2);
      additionalSpecs.push.apply(additionalSpecs, stateSpecs);
    }
    if (constants != null) {
      throw new NotImplementedError("Support for constants in Bidirectional layers is not implemented yet.");
    }
    var isSymbolicTensor = additionalInputs[0] instanceof SymbolicTensor;
    for (var _i = 0, additionalInputs_1 = additionalInputs; _i < additionalInputs_1.length; _i++) {
      var tensor = additionalInputs_1[_i];
      if (tensor instanceof SymbolicTensor !== isSymbolicTensor) {
        throw new ValueError("The initial state of a Bidirectional layer cannot be specified as a mix of symbolic and non-symbolic tensors");
      }
    }
    if (isSymbolicTensor) {
      var fullInput = [inputs].concat(additionalInputs);
      var fullInputSpec = this.inputSpec.concat(additionalSpecs);
      var originalInputSpec = this.inputSpec;
      // Temporarily widen inputSpec so the appended state inputs validate.
      this.inputSpec = fullInputSpec;
      var output = _super.prototype.apply.call(this, fullInput, kwargs);
      this.inputSpec = originalInputSpec;
      return output;
    } else {
      return _super.prototype.apply.call(this, inputs, kwargs);
    }
  };
  /**
   * Runs both directions, reverses the backward output along axis 1 when
   * returning sequences, and merges according to `mergeMode`.
   */
  Bidirectional2.prototype.call = function(inputs, kwargs) {
    var _this = this;
    return tfc.tidy(function() {
      var initialState = kwargs["initialState"];
      var y;
      var yRev;
      if (initialState == null) {
        y = _this.forwardLayer.call(inputs, kwargs);
        yRev = _this.backwardLayer.call(inputs, kwargs);
      } else {
        var half = initialState.length / 2;
        var forwardState = initialState.slice(0, half);
        var backwardState = initialState.slice(half);
        // Give each direction its own kwargs copy so the caller's object is
        // not left holding the backward half of the state.
        y = _this.forwardLayer.call(inputs, Object.assign({}, kwargs, {initialState: forwardState}));
        yRev = _this.backwardLayer.call(inputs, Object.assign({}, kwargs, {initialState: backwardState}));
      }
      var states;
      if (_this.returnState) {
        if (Array.isArray(y)) {
          states = y.slice(1).concat(yRev.slice(1));
        }
        y = y[0];
        yRev = yRev[0];
      }
      if (_this.returnSequences) {
        yRev = tfc.reverse(yRev, 1);
      }
      var output;
      if (_this.mergeMode === "concat") {
        output = concatenate([y, yRev]);
      } else if (_this.mergeMode === "sum") {
        output = tfc.add(y, yRev);
      } else if (_this.mergeMode === "ave") {
        output = tfc.mul(0.5, tfc.add(y, yRev));
      } else if (_this.mergeMode === "mul") {
        output = tfc.mul(y, yRev);
      } else if (_this.mergeMode == null) {
        output = [y, yRev];
      }
      if (_this.returnState) {
        if (_this.mergeMode == null) {
          return output.concat(states);
        }
        return [output].concat(states);
      }
      return output;
    });
  };
  Bidirectional2.prototype.resetStates = function(states) {
    this.forwardLayer.resetStates();
    this.backwardLayer.resetStates();
  };
  /** Builds each direction layer inside its own name scope. */
  Bidirectional2.prototype.build = function(inputShape) {
    var _this = this;
    nameScope(this.forwardLayer.name, function() {
      _this.forwardLayer.build(inputShape);
    });
    nameScope(this.backwardLayer.name, function() {
      _this.backwardLayer.build(inputShape);
    });
    this.built = true;
  };
  /**
   * Passes the (first) input mask through for sequence outputs and returns null
   * masks otherwise; state outputs (one set per direction) get null masks.
   */
  Bidirectional2.prototype.computeMask = function(inputs, mask) {
    if (Array.isArray(mask)) {
      mask = mask[0];
    }
    var outputMask;
    if (this.returnSequences) {
      if (this.mergeMode == null) {
        outputMask = [mask, mask];
      } else {
        outputMask = mask;
      }
    } else {
      if (this.mergeMode == null) {
        outputMask = [null, null];
      } else {
        outputMask = null;
      }
    }
    if (this.returnState) {
      var states = this.forwardLayer.states;
      var stateMask = states.map(function(state) {
        return null;
      });
      if (Array.isArray(outputMask)) {
        return outputMask.concat(stateMask).concat(stateMask);
      } else {
        return [outputMask].concat(stateMask).concat(stateMask);
      }
    } else {
      return outputMask;
    }
  };
  Object.defineProperty(Bidirectional2.prototype, "trainableWeights", {
    get: function() {
      return this.forwardLayer.trainableWeights.concat(this.backwardLayer.trainableWeights);
    },
    enumerable: true,
    configurable: true
  });
  Object.defineProperty(Bidirectional2.prototype, "nonTrainableWeights", {
    get: function() {
      return this.forwardLayer.nonTrainableWeights.concat(this.backwardLayer.nonTrainableWeights);
    },
    enumerable: true,
    configurable: true
  });
  Bidirectional2.prototype.setFastWeightInitDuringBuild = function(value) {
    _super.prototype.setFastWeightInitDuringBuild.call(this, value);
    if (this.forwardLayer != null) {
      this.forwardLayer.setFastWeightInitDuringBuild(value);
    }
    if (this.backwardLayer != null) {
      this.backwardLayer.setFastWeightInitDuringBuild(value);
    }
  };
  /** Adds `mergeMode` to the Wrapper config. */
  Bidirectional2.prototype.getConfig = function() {
    var config = {
      mergeMode: this.mergeMode
    };
    var baseConfig = _super.prototype.getConfig.call(this);
    Object.assign(config, baseConfig);
    return config;
  };
  /**
   * Rebuilds a Bidirectional layer from a serialized config.
   *
   * @param cls Constructor to instantiate.
   * @param config Serialized config. It is copied, not modified, so it can be
   *   reused.
   * @param customObjects Optional custom class map, forwarded to deserialize
   *   (defaults to `{}`, as in Wrapper.fromConfig).
   * @throws NotImplementedError when `numConstants` is present.
   */
  Bidirectional2.fromConfig = function(cls, config, customObjects) {
    if (customObjects === void 0) {
      customObjects = {};
    }
    var rnnLayer = deserialize(config["layer"], customObjects);
    if (config["numConstants"] != null) {
      throw new NotImplementedError("Deserialization of a Bidirectional layer with numConstants present is not supported yet.");
    }
    var newConfig = Object.assign({}, config);
    newConfig["layer"] = rnnLayer;
    return new cls(newConfig);
  };
  Bidirectional2.className = "Bidirectional";
  return Bidirectional2;
}(Wrapper);
tfc.serialization.registerClass(Bidirectional);
/**
* @license
* Copyright 2018 Google LLC
*
* Use of this source code is governed by an MIT-style
* license that can be found in the LICENSE file or at
* https://opensource.org/licenses/MIT.
* =============================================================================
*/
// ---- Layer factory functions: each constructs the corresponding layer class from `args`. ----

// Input and activation layers.
function inputLayer(args) { return new InputLayer(args); }
function elu$1(args) { return new ELU(args); }
function reLU(args) { return new ReLU(args); }
function leakyReLU(args) { return new LeakyReLU(args); }
function prelu(args) { return new PReLU(args); }
function softmax(args) { return new Softmax$1(args); }
function thresholdedReLU(args) { return new ThresholdedReLU(args); }

// Convolutional layers.
function conv1d(args) { return new Conv1D(args); }
function conv2d(args) { return new Conv2D(args); }
function conv2dTranspose(args) { return new Conv2DTranspose(args); }
function conv3d(args) { return new Conv3D(args); }
function separableConv2d(args) { return new SeparableConv2D(args); }
function cropping2D(args) { return new Cropping2D(args); }
function upSampling2d(args) { return new UpSampling2D(args); }
function depthwiseConv2d$1(args) { return new DepthwiseConv2D(args); }

// Core layers.
function activation(args) { return new Activation$1(args); }
function dense(args) { return new Dense(args); }
function dropout$1(args) { return new Dropout(args); }
function spatialDropout1d(args) { return new SpatialDropout1D(args); }
function flatten$1(args) { return new Flatten(args); }
function repeatVector(args) { return new RepeatVector(args); }
function reshape(args) { return new Reshape(args); }
function permute(args) { return new Permute(args); }
function embedding(args) { return new Embedding(args); }

// Merge layers.
function add(args) { return new Add(args); }
function average(args) { return new Average(args); }
function concatenate$1(args) { return new Concatenate(args); }
function maximum(args) { return new Maximum(args); }
function minimum(args) { return new Minimum(args); }
function multiply(args) { return new Multiply(args); }
function dot$1(args) { return new Dot(args); }

// Normalization and padding layers.
function batchNormalization$1(args) { return new BatchNormalization(args); }
function layerNormalization(args) { return new LayerNormalization(args); }
function zeroPadding2d(args) { return new ZeroPadding2D(args); }

// Pooling layers; the avgPool*/avgPooling* names are aliases of averagePooling*.
function averagePooling1d(args) { return new AveragePooling1D(args); }
function avgPool1d(args) { return averagePooling1d(args); }
function avgPooling1d(args) { return averagePooling1d(args); }
function averagePooling2d(args) { return new AveragePooling2D(args); }
function avgPool2d(args) { return averagePooling2d(args); }
function avgPooling2d(args) { return averagePooling2d(args); }
function averagePooling3d(args) { return new AveragePooling3D(args); }
function avgPool3d(args) { return averagePooling3d(args); }
function avgPooling3d(args) { return averagePooling3d(args); }
function globalAveragePooling1d(args) { return new GlobalAveragePooling1D(args); }
function globalAveragePooling2d(args) { return new GlobalAveragePooling2D(args); }
function globalMaxPooling1d(args) { return new GlobalMaxPooling1D(args); }
function globalMaxPooling2d(args) { return new GlobalMaxPooling2D(args); }
function maxPooling1d(args) { return new MaxPooling1D(args); }
function maxPooling2d(args) { return new MaxPooling2D(args); }
function maxPooling3d(args) { return new MaxPooling3D(args); }

// Recurrent layers and cells.
function gru(args) { return new GRU(args); }
function gruCell(args) { return new GRUCell(args); }
function lstm(args) { return new LSTM(args); }
function lstmCell(args) { return new LSTMCell(args); }
function simpleRNN(args) { return new SimpleRNN(args); }
function simpleRNNCell(args) { return new SimpleRNNCell(args); }
function convLstm2d(args) { return new ConvLSTM2D(args); }
function convLstm2dCell(args) { return new ConvLSTM2DCell(args); }
function rnn$1(args) { return new RNN(args); }
function stackedRNNCells(args) { return new StackedRNNCells(args); }

// Wrapper layers.
function bidirectional(args) { return new Bidirectional(args); }
function timeDistributed(args) { return new TimeDistributed(args); }

// Short-name aliases for the pooling factories.
var globalMaxPool1d = globalMaxPooling1d;
var globalMaxPool2d = globalMaxPooling2d;
var maxPool1d = maxPooling1d;
var maxPool2d = maxPooling2d;

// Noise and masking layers.
function gaussianNoise(args) { return new GaussianNoise(args); }
function gaussianDropout(args) { return new GaussianDropout(args); }
function alphaDropout(args) { return new AlphaDropout(args); }
function masking(args) { return new Masking(args); }
// Namespace object for the layer factories, exported as `exports.layers`.
// `__proto__: null` gives it a null prototype, so only the listed keys resolve.
// The `name: name$1` entries map bundler-renamed locals back to their public names.
var exports_layers = {
__proto__: null,
inputLayer,
elu: elu$1,
reLU,
leakyReLU,
prelu,
softmax,
thresholdedReLU,
conv1d,
conv2d,
conv2dTranspose,
conv3d,
separableConv2d,
cropping2D,
upSampling2d,
depthwiseConv2d: depthwiseConv2d$1,
activation,
dense,
dropout: dropout$1,
spatialDropout1d,
flatten: flatten$1,
repeatVector,
reshape,
permute,
embedding,
add,
average,
concatenate: concatenate$1,
maximum,
minimum,
multiply,
dot: dot$1,
batchNormalization: batchNormalization$1,
layerNormalization,
zeroPadding2d,
averagePooling1d,
avgPool1d,
avgPooling1d,
averagePooling2d,
avgPool2d,
avgPooling2d,
averagePooling3d,
avgPool3d,
avgPooling3d,
globalAveragePooling1d,
globalAveragePooling2d,
globalMaxPooling1d,
globalMaxPooling2d,
maxPooling1d,
maxPooling2d,
maxPooling3d,
gru,
gruCell,
lstm,
lstmCell,
simpleRNN,
simpleRNNCell,
convLstm2d,
convLstm2dCell,
rnn: rnn$1,
stackedRNNCells,
bidirectional,
timeDistributed,
globalMaxPool1d,
globalMaxPool2d,
maxPool1d,
maxPool2d,
// Base classes and the `input` helper are also exposed on this namespace.
Layer,
RNN,
RNNCell,
input,
gaussianNoise,
gaussianDropout,
alphaDropout,
masking
};
// ---- Public metric functions: thin pass-throughs to the internal metric implementations. ----
function binaryAccuracy$1(yTrue, yPred) { return binaryAccuracy(yTrue, yPred); }
function binaryCrossentropy$2(yTrue, yPred) { return binaryCrossentropy$1(yTrue, yPred); }
function sparseCategoricalAccuracy$1(yTrue, yPred) { return sparseCategoricalAccuracy(yTrue, yPred); }
function categoricalAccuracy$1(yTrue, yPred) { return categoricalAccuracy(yTrue, yPred); }
function categoricalCrossentropy$2(yTrue, yPred) { return categoricalCrossentropy$1(yTrue, yPred); }
function precision$1(yTrue, yPred) { return precision(yTrue, yPred); }
function recall$1(yTrue, yPred) { return recall(yTrue, yPred); }
function cosineProximity$1(yTrue, yPred) { return cosineProximity(yTrue, yPred); }
function meanAbsoluteError$1(yTrue, yPred) { return meanAbsoluteError(yTrue, yPred); }

// MAPE / mape are aliases for meanAbsolutePercentageError.
function meanAbsolutePercentageError$1(yTrue, yPred) { return meanAbsolutePercentageError(yTrue, yPred); }
function MAPE$1(yTrue, yPred) { return meanAbsolutePercentageError(yTrue, yPred); }
function mape$1(yTrue, yPred) { return meanAbsolutePercentageError(yTrue, yPred); }

// MSE / mse are aliases for meanSquaredError.
function meanSquaredError$1(yTrue, yPred) { return meanSquaredError(yTrue, yPred); }
function MSE$1(yTrue, yPred) { return meanSquaredError(yTrue, yPred); }
function mse$1(yTrue, yPred) { return meanSquaredError(yTrue, yPred); }
// Null-prototype namespace of metric functions, exported as `exports.metrics`.
// Keys are the public names; values are the `$n`-suffixed wrappers defined above.
var exports_metrics = {
__proto__: null,
binaryAccuracy: binaryAccuracy$1,
binaryCrossentropy: binaryCrossentropy$2,
sparseCategoricalAccuracy: sparseCategoricalAccuracy$1,
categoricalAccuracy: categoricalAccuracy$1,
categoricalCrossentropy: categoricalCrossentropy$2,
precision: precision$1,
recall: recall$1,
cosineProximity: cosineProximity$1,
meanAbsoluteError: meanAbsoluteError$1,
meanAbsolutePercentageError: meanAbsolutePercentageError$1,
MAPE: MAPE$1,
mape: mape$1,
meanSquaredError: meanSquaredError$1,
MSE: MSE$1,
mse: mse$1
};
/**
* @license
* Copyright 2018 Google LLC
*
* Use of this source code is governed by an MIT-style
* license that can be found in the LICENSE file or at
* https://opensource.org/licenses/MIT.
* =============================================================================
*/
// Null-prototype namespace exported as `exports.models`; currently only exposes modelFromJSON.
var exports_models = {
__proto__: null,
modelFromJSON
};
/**
* @license
* Copyright 2018 Google LLC
*
* Use of this source code is governed by an MIT-style
* license that can be found in the LICENSE file or at
* https://opensource.org/licenses/MIT.
* =============================================================================
*/
/** Builds an L1L2 regularizer from `config`. */
function l1l2(config) { return new L1L2(config); }
/** Public wrapper around the internal `l1` regularizer factory. */
function l1$1(config) { return l1(config); }
/** Public wrapper around the internal `l2` regularizer factory. */
function l2$1(config) { return l2(config); }
// Null-prototype namespace of regularizer factories, exported as `exports.regularizers`.
var exports_regularizers = {
__proto__: null,
l1l2,
l1: l1$1,
l2: l2$1
};
/**
* @license
* Copyright 2018 Google LLC
*
* Use of this source code is governed by an MIT-style
* license that can be found in the LICENSE file or at
* https://opensource.org/licenses/MIT.
* =============================================================================
*/
var Callback = function(_super) {
  __extends(Callback2, _super);
  /** Base class for training callbacks that hold a reference to a LayersModel. */
  function Callback2() {
    var instance = _super !== null ? _super.apply(this, arguments) : false;
    instance = instance || this;
    instance.model = null;
    return instance;
  }
  /**
   * Attaches the model this callback observes.
   * @throws Error if `model2` is not a LayersModel.
   */
  Callback2.prototype.setModel = function(model2) {
    if (model2 instanceof LayersModel) {
      this.model = model2;
      return;
    }
    throw new Error("model must be a LayersModel, not some other Container");
  };
  return Callback2;
}(BaseCallback);
/** True when `currVal` is strictly below `prevVal` (improvement test for minimised metrics). */
function less(currVal, prevVal) {
  var isLower = currVal < prevVal;
  return isLower;
}
/** True when `currVal` is strictly above `prevVal` (improvement test for maximised metrics). */
function greater(currVal, prevVal) {
  var isHigher = currVal > prevVal;
  return isHigher;
}
/**
 * Callback that sets `model.stopTraining = true` once the monitored metric has
 * not improved for `patience` consecutive epochs.
 *
 * Constructor args (`args` may be null/undefined):
 *   monitor            key read from the epoch logs (default "val_loss")
 *   minDelta           minimum change that counts as improvement; its absolute
 *                      value is used (default 0)
 *   patience           non-improving epochs tolerated before stopping (default 0)
 *   verbose            when truthy, onTrainEnd logs the stop epoch (default 0)
 *   mode               "auto" | "min" | "max"; other values warn and fall back to "auto"
 *   baseline           optional initial value for `best`
 *   restoreBestWeights unsupported; a truthy value throws NotImplementedError
 */
var EarlyStopping = function(_super) {
__extends(EarlyStopping2, _super);
function EarlyStopping2(args) {
var _this = _super.call(this) || this;
if (args == null) {
args = {};
}
if (args.restoreBestWeights) {
throw new NotImplementedError("restoreBestWeights = True is not implemented in EarlyStopping yet.");
}
_this.monitor = args.monitor || "val_loss";
_this.minDelta = Math.abs(args.minDelta || 0);
_this.patience = args.patience || 0;
_this.verbose = args.verbose || 0;
_this.mode = args.mode || "auto";
_this.baseline = args.baseline;
if (["auto", "min", "max"].indexOf(_this.mode) === -1) {
console.warn("EarlyStopping mode '" + _this.mode + "' is invalid. Falling back to mode 'auto'.");
_this.mode = "auto";
}
// Choose the comparison: "min" -> less, "max" -> greater; in "auto" mode a
// monitor name containing "acc" is maximised, anything else minimised.
if (_this.mode === "min") {
_this.monitorFunc = less;
} else if (_this.mode === "max") {
_this.monitorFunc = greater;
} else {
if (_this.monitor.indexOf("acc") !== -1) {
_this.monitorFunc = greater;
} else {
_this.monitorFunc = less;
}
}
// When minimising, negate minDelta so that `current - minDelta` in
// onEpochEnd becomes current + |minDelta|. In both directions an epoch then
// counts as an improvement only if it beats `best` by more than |minDelta|.
if (_this.monitorFunc === less) {
_this.minDelta *= -1;
}
return _this;
}
// Resets the counters. `best` starts at `baseline` if one was given,
// otherwise at the worst possible value for the monitoring direction.
EarlyStopping2.prototype.onTrainBegin = function(logs) {
return __awaiter(this, void 0, void 0, function() {
return __generator(this, function(_a) {
this.wait = 0;
this.stoppedEpoch = 0;
if (this.baseline != null) {
this.best = this.baseline;
} else {
this.best = this.monitorFunc === less ? Infinity : -Infinity;
}
return [2];
});
});
};
// Compiled async body (tslib-style __generator state machine; presumably
// [4, p] awaits p and [2] returns). Label 0 awaits resolveScalarsInLogs(logs),
// then label 1 compares the monitored value against `best`.
EarlyStopping2.prototype.onEpochEnd = function(epoch, logs) {
return __awaiter(this, void 0, void 0, function() {
var current;
return __generator(this, function(_a) {
switch (_a.label) {
case 0:
return [4, resolveScalarsInLogs(logs)];
case 1:
_a.sent();
current = this.getMonitorValue(logs);
// Missing metric: skip this epoch (getMonitorValue has already warned).
if (current == null) {
return [2];
}
if (this.monitorFunc(current - this.minDelta, this.best)) {
this.best = current;
this.wait = 0;
} else {
this.wait++;
// Stop once `wait` reaches `patience`; assumes setModel was called so this.model is set.
if (this.wait >= this.patience) {
this.stoppedEpoch = epoch;
this.model.stopTraining = true;
}
}
return [2];
}
});
});
};
// Logs the stop epoch when verbose. A stop recorded at epoch 0 is not
// reported, because of the `> 0` check.
EarlyStopping2.prototype.onTrainEnd = function(logs) {
return __awaiter(this, void 0, void 0, function() {
return __generator(this, function(_a) {
if (this.stoppedEpoch > 0 && this.verbose) {
console.log("Epoch " + this.stoppedEpoch + ": early stopping.");
}
return [2];
});
});
};
// Returns logs[monitor]. If the key is missing (null/undefined) it warns,
// listing the available keys, and does not throw.
EarlyStopping2.prototype.getMonitorValue = function(logs) {
if (logs == null) {
logs = {};
}
var monitorValue = logs[this.monitor];
if (monitorValue == null) {
console.warn("Metric for EarlyStopping " + this.monitor + " is not available. " + ("Available metrics are: " + Object.keys(logs)));
}
return monitorValue;
};
return EarlyStopping2;
}(Callback);
/** Factory that constructs an EarlyStopping callback from `args`. */
function earlyStopping(args) {
  var callback = new EarlyStopping(args);
  return callback;
}
// Namespace of callback factories (exported as `callbacks`).
var callbacks = {earlyStopping: earlyStopping};
// Public exports of this CommonJS module: core classes, factory namespaces
// (layers, metrics, models, regularizers, constraints, initializers,
// callbacks) and model-construction helpers.
exports.Callback = Callback;
exports.CallbackList = CallbackList;
exports.CustomCallback = CustomCallback;
exports.EarlyStopping = EarlyStopping;
exports.History = History;
exports.InputSpec = InputSpec;
exports.LayerVariable = LayerVariable;
exports.LayersModel = LayersModel;
exports.RNN = RNN;
exports.Sequential = Sequential;
exports.SymbolicTensor = SymbolicTensor;
exports.callbacks = callbacks;
exports.constraints = exports_constraints;
exports.initializers = exports_initializers;
exports.input = input;
exports.layers = exports_layers;
exports.loadLayersModel = loadLayersModel;
exports.metrics = exports_metrics;
exports.model = model;
exports.models = exports_models;
exports.registerCallbackConstructor = registerCallbackConstructor;
exports.regularizers = exports_regularizers;
exports.sequential = sequential;
// The module-local `version` is exposed as `version_layers`.
exports.version_layers = version;
});
// node_modules/@tensorflow/tfjs-converter/dist/tf-converter.node.js
var require_tf_converter_node = __commonJS((exports) => {
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
"use strict";
Object.defineProperty(exports, "__esModule", {value: true});
var tfOps = require_tf_core_node();
/*! *****************************************************************************
Copyright (c) Microsoft Corporation. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use
this file except in compliance with the License. You may obtain a copy of the
License at http://www.apache.org/licenses/LICENSE-2.0
THIS CODE IS PROVIDED ON AN *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
MERCHANTABLITY OR NON-INFRINGEMENT.
See the Apache Version 2.0 License for specific language governing permissions
and limitations under the License.
***************************************************************************** */
/**
 * Object.assign, or a fallback for engines without it: copies own enumerable
 * properties from every argument after the first onto `target` and returns it.
 */
var __assign = Object.assign || function __assign2(target) {
  var argCount = arguments.length;
  var idx = 1;
  while (idx < argCount) {
    var source = arguments[idx++];
    // for-in tolerates null/undefined sources (no iterations).
    for (var key in source) {
      if (Object.prototype.hasOwnProperty.call(source, key)) {
        target[key] = source[key];
      }
    }
  }
  return target;
};
/**
 * Runs a generator as an async function: each yielded value is awaited and
 * fed back into the generator (rejections are thrown into it).
 * @param thisArg    Receiver for the generator function.
 * @param _arguments Arguments for the generator function (defaults to []).
 * @param P          Promise constructor to use (defaults to the global Promise).
 * @param generator  Generator function to drive.
 * @returns A P-promise settling with the generator's return value or error.
 */
function __awaiter(thisArg, _arguments, P, generator) {
  if (!P) {
    P = Promise;
  }
  return new P(function(resolve, reject) {
    // Created inside the executor so synchronous failures reject the promise.
    var iterator = generator.apply(thisArg, _arguments || []);
    function advance(result) {
      if (result.done) {
        resolve(result.value);
        return;
      }
      new P(function(innerResolve) {
        innerResolve(result.value);
      }).then(onFulfilled, onRejected);
    }
    function onFulfilled(value) {
      try {
        advance(iterator.next(value));
      } catch (e) {
        reject(e);
      }
    }
    function onRejected(reason) {
      try {
        advance(iterator["throw"](reason));
      } catch (e) {
        reject(e);
      }
    }
    advance(iterator.next());
  });
}
/**
 * tslib-style `__generator` runtime: emulates a generator object on top of a
 * compiled state machine, for targets without native generators.
 *
 * `body` is invoked as `body.call(thisArg, _)` and returns an instruction
 * tuple `[opcode, value]`. Opcodes handled by `step` below:
 *   0 next        - resume with `value` (also issued by g.next)
 *   1 throw       - resume by throwing `value` (also issued by g.throw)
 *   2 return      - finish with `value` (also issued by g.return)
 *   3 break       - jump to label `value`
 *   4 yield       - suspend, handing `value` to the caller
 *   5 yield*      - delegate to the iterator `value`
 *   6 exception   - something threw `value` (produced by the catch below)
 *   7 endfinally  - a finally block finished; resume the deferred op
 *
 * @param thisArg Receiver passed to every `body` invocation.
 * @param body    Compiled state machine `(state) => [opcode, value]`.
 * @returns Iterator object with next/throw/return (plus Symbol.iterator when
 *          Symbol is available).
 */
function __generator(thisArg, body) {
// `_` is the state handed to `body`: `label` is the resume point; `sent()`
// returns the value passed to next() (or the caught exception inside a catch
// block) and rethrows the value passed to throw(); `trys` stacks active try
// regions; `ops` stacks ops deferred while a finally block runs.
// f = "currently executing" flag, y = delegate iterator for yield*,
// t = scratch slot (current op / try frame / delegate result), g = result.
var _ = {label: 0, sent: function() {
if (t[0] & 1)
throw t[1];
return t[1];
}, trys: [], ops: []}, f, y, t, g;
// Build the public iterator; it is its own iterable when Symbol exists.
return g = {next: verb(0), throw: verb(1), return: verb(2)}, typeof Symbol === "function" && (g[Symbol.iterator] = function() {
return this;
}), g;
// Creates the next/throw/return method that feeds opcode `n` into `step`.
function verb(n) {
return function(v) {
return step([n, v]);
};
}
// Runs the state machine until it yields (directly or via a delegate) or completes.
function step(op) {
// Re-entrant calls (e.g. next() from inside the body) are rejected.
if (f)
throw new TypeError("Generator is already executing.");
// `_` is set to 0 once the generator has completed, which ends this loop.
while (_)
try {
// While delegating (yield*), forward the op to the delegate `y` using the
// matching method; if the delegate is not done, hand its result straight back.
if (f = 1, y && (t = y[op[0] & 2 ? "return" : op[0] ? "throw" : "next"]) && !(t = t.call(y, op[1])).done)
return t;
// Stop delegating; if a delegate just finished, resume the body with its final value.
if (y = 0, t)
op = [0, t.value];
switch (op[0]) {
// next/throw: record the op so `_.sent()` can return or rethrow its value.
case 0:
case 1:
t = op;
break;
// yield: advance past the yield point and suspend.
case 4:
_.label++;
return {value: op[1], done: false};
// yield*: advance, start delegating to op[1], then pull its first value.
case 5:
_.label++;
y = op[1];
op = [0];
continue;
// endfinally: restore the op that was deferred when the finally block was entered.
case 7:
op = _.ops.pop();
_.trys.pop();
continue;
default:
// return (2) or exception (6) with no enclosing try region: the generator completes.
if (!(t = _.trys, t = t.length > 0 && t[t.length - 1]) && (op[0] === 6 || op[0] === 2)) {
_ = 0;
continue;
}
// Try frames appear to be [tryLabel, catchLabel, finallyLabel, endLabel]
// (tslib layout). A break with no try region, or one staying inside the
// current try region, is a plain jump.
if (op[0] === 3 && (!t || op[1] > t[0] && op[1] < t[3])) {
_.label = op[1];
break;
}
// Exception raised before the catch label: jump into the catch block.
if (op[0] === 6 && _.label < t[1]) {
_.label = t[1];
t = op;
break;
}
// Run the finally block first, deferring the op until opcode 7.
if (t && _.label < t[2]) {
_.label = t[2];
_.ops.push(op);
break;
}
// Past this frame's handlers: discard the frame (and its deferred op when it
// has a finally) and re-dispatch the op against the outer frame.
if (t[2])
_.ops.pop();
_.trys.pop();
continue;
}
// Re-enter the compiled body at `_.label` to get its next instruction.
op = body.call(thisArg, _);
} catch (e) {
// Anything thrown by the body or a delegate becomes opcode 6 and ends delegation.
op = [6, e];
y = 0;
} finally {
f = t = 0;
}
// Completed: throw (1) and uncaught exceptions (6) propagate to the caller;
// otherwise finish with done: true (carrying a value only for return).
if (op[0] & 5)
throw op[1];
return {value: op[0] ? op[1] : void 0, done: true};
}
}
/**
 * Reads up to `n` values (all when `n` is undefined) from iterable `o` into a new
 * array, closing the iterator via `return()` if it was not exhausted.
 * Non-iterables (or environments without Symbol) are returned unchanged.
 */
function __read(o, n) {
  var iterFn = typeof Symbol === "function" && o[Symbol.iterator];
  if (!iterFn) {
    return o;
  }
  var iterator = iterFn.call(o);
  var values = [];
  var stepResult;
  var caught;
  try {
    while (n === void 0 || n-- > 0) {
      stepResult = iterator.next();
      if (stepResult.done) {
        break;
      }
      values.push(stepResult.value);
    }
  } catch (error) {
    caught = {error: error};
  } finally {
    try {
      var closer = stepResult && !stepResult.done && iterator["return"];
      if (closer) {
        closer.call(iterator);
      }
    } finally {
      // Re-raise an iteration error only after the iterator has been closed.
      if (caught) {
        throw caught.error;
      }
    }
  }
  return values;
}
/**
 * Concatenates the fully-read contents of every argument (spread emulation).
 */
function __spread() {
  return Array.prototype.reduce.call(arguments, function(acc, item) {
    return acc.concat(__read(item));
  }, []);
}
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
* =============================================================================
*/
/**
 * Numeric enum of TensorFlow DT_* data types with reverse mapping
 * (DataType.DT_FLOAT === 1 and DataType[1] === "DT_FLOAT").
 * Codes 0..14 are base types; each base type except DT_INVALID also has a
 * "_REF" variant at code + 100.
 */
var DataType;
(function(DataType2) {
  var baseNames = [
    "DT_INVALID", "DT_FLOAT", "DT_DOUBLE", "DT_INT32", "DT_UINT8",
    "DT_INT16", "DT_INT8", "DT_STRING", "DT_COMPLEX64", "DT_INT64",
    "DT_BOOL", "DT_QINT8", "DT_QUINT8", "DT_QINT32", "DT_BFLOAT16"
  ];
  function define(label, code) {
    DataType2[label] = code;
    DataType2[code] = label;
  }
  baseNames.forEach(function(label, code) {
    define(label, code);
  });
  for (var baseCode = 1; baseCode < baseNames.length; baseCode++) {
    define(baseNames[baseCode] + "_REF", baseCode + 100);
  }
})(DataType || (DataType = {}));
/**
 * SaverDef namespace holding the CheckpointFormatVersion enum
 * (LEGACY = 0, V1 = 1, V2 = 2) with reverse mapping.
 */
var SaverDef;
(function(SaverDef2) {
  var formats = SaverDef2.CheckpointFormatVersion || (SaverDef2.CheckpointFormatVersion = {});
  ["LEGACY", "V1", "V2"].forEach(function(label, code) {
    formats[label] = code;
    formats[code] = label;
  });
})(SaverDef || (SaverDef = {}));
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Registry of user-supplied op executors keyed by TF op name.
 * Created without a prototype, so names such as "constructor", "toString" or
 * "__proto__" are ordinary keys. They no longer resolve to Object.prototype
 * members or re-parent the registry.
 */
var CUSTOM_OPS = Object.create(null);
/**
 * Registers (or replaces) a custom executor for op `name`.
 * @param name   TF op name as it appears in the graph.
 * @param opFunc Executor stored as the mapper's `customExecutor`.
 */
function registerOp(name, opFunc) {
  var opMapper = {
    tfOpName: name,
    category: "custom",
    inputs: [],
    attrs: [],
    customExecutor: opFunc
  };
  CUSTOM_OPS[name] = opMapper;
}
/**
 * @param name TF op name.
 * @returns The registered op mapper, or undefined when none is registered.
 */
function getRegisteredOp(name) {
  return CUSTOM_OPS[name];
}
/**
 * Removes the custom executor registered for `name`, if any.
 * @param name TF op name.
 */
function deregisterOp(name) {
  delete CUSTOM_OPS[name];
}
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Resolves parameter `paramName` of a graph node.
 * - If the node binds the parameter to graph inputs (inputParams with an
 *   inputIndexStart), the value comes from those input tensors:
 *   "tensor" -> a single tensor, "tensors" -> a list of tensors (end index 0
 *   means "to the end", an absent end means a single input), otherwise the
 *   input tensor's data ("number" -> first element, other types -> nested array).
 * - Otherwise it returns the attribute value from attrParams (undefined if absent).
 */
function getParamValue(paramName, node, tensorMap, context, resourceManager) {
  var inputParam = node.inputParams[paramName];
  var fromInputs = inputParam && inputParam.inputIndexStart !== void 0;
  if (!fromInputs) {
    var attrParam = node.attrParams[paramName];
    return attrParam && attrParam.value;
  }
  var start = inputParam.inputIndexStart;
  var lookup = function(inputName) {
    return getTensor(inputName, tensorMap, context, resourceManager);
  };
  switch (inputParam.type) {
    case "tensor":
      return lookup(node.inputNames[start]);
    case "tensors": {
      var endIndex = inputParam.inputIndexEnd;
      var stop;
      if (endIndex === 0) {
        stop = void 0;
      } else if (endIndex === void 0) {
        stop = start + 1;
      } else {
        stop = endIndex;
      }
      return node.inputNames.slice(start, stop).map(lookup);
    }
    default: {
      // slice(start)[0] (not inputNames[start]) so a negative start counts from the end.
      var tensor = lookup(node.inputNames.slice(start)[0]);
      var values = tensor.dataSync();
      if (inputParam.type === "number") {
        return values[0];
      }
      return tfOps.util.toNestedArray(tensor.shape, values);
    }
  }
}
/**
 * Looks up the tensor for `name` ("node" or "node:outputIndex").
 * A hash-table handle from `resourceManager` (when given) takes precedence.
 * Otherwise it uses the first id in context.currentContextIds whose
 * context-scoped node name has an entry in `tensorsMap`.
 * @returns The tensor at the parsed output index, or undefined if no context matches.
 */
function getTensor(name, tensorsMap, context, resourceManager) {
  var parsed = parseNodeName(name);
  var nodeName = parsed[0];
  var index = parsed[1];
  if (resourceManager != null) {
    var handle = resourceManager.getHashTableHandleByName(nodeName);
    if (handle != null) {
      return handle;
    }
  }
  var matchedId = context.currentContextIds.find(function(candidateId) {
    return !!tensorsMap[getNodeNameWithContextId(nodeName, candidateId)];
  });
  if (matchedId === void 0) {
    return void 0;
  }
  return tensorsMap[getNodeNameWithContextId(nodeName, matchedId)][index];
}
/**
 * Returns the tensor list stored for node `name` in the current execution
 * context. (The misspelled function name is kept for compatibility with callers.)
 */
function getTensorsForCurrentContenxt(name, tensorsMap, context) {
  var scopedName = getNodeNameWithContextId(name, context.currentContextId);
  return tensorsMap[scopedName];
}
/**
 * Splits an input name into [contextScopedNodeName, outputIndex], scoping the
 * node name with context.currentContextId when a context is given.
 */
function getNodeNameAndIndex(inputName, context) {
  var parsed = parseNodeName(inputName);
  var contextId = context && context.currentContextId;
  return [getNodeNameWithContextId(parsed[0], contextId), parsed[1]];
}
/**
 * Appends "-<contextId>" to `name` when contextId is truthy; otherwise
 * returns `name` unchanged.
 */
function getNodeNameWithContextId(name, contextId) {
  if (!contextId) {
    return name;
  }
  return name + "-" + contextId;
}
/**
 * Parses "node" or "node:...:index" into [nodeName, outputIndex].
 * nodeName is the text before the first ':'; the index is Number() of the
 * text after the last ':'. A name without ':' maps to index 0.
 */
function parseNodeName(name) {
  var firstColon = name.indexOf(":");
  if (firstColon === -1) {
    return [name, 0];
  }
  var lastColon = name.lastIndexOf(":");
  return [name.slice(0, firstColon), Number(name.slice(lastColon + 1))];
}
/**
 * Returns the node's "pad" parameter. When it is "explicit", the flat
 * "explicitPaddings" list is regrouped into 4 pairs
 * [[p0, p1], [p2, p3], [p4, p5], [p6, p7]] (one pair per dimension).
 */
function getPadding(node, tensorMap, context) {
  var pad = getParamValue("pad", node, tensorMap, context);
  if (pad !== "explicit") {
    return pad;
  }
  var flat = getParamValue("explicitPaddings", node, tensorMap, context);
  var perDim = [];
  for (var dim = 0; dim < 4; dim++) {
    perDim.push([flat[dim * 2], flat[dim * 2 + 1]]);
  }
  return perDim;
}
/**
 * Returns `tensor` itself when it is marked `kept`; otherwise a clone made
 * with tfOps.clone.
 */
function cloneTensor(tensor) {
  if (tensor.kept) {
    return tensor;
  }
  return tfOps.clone(tensor);
}
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Op-mapper definitions for the "arithmetic" category. Every op except AddN
 * takes tensor inputs "a" (position 0) and "b" (position 1). All of them
 * except Maximum/Minimum carry the unsupported "T" dtype attribute. Entries
 * are built by a factory that returns fresh objects on every call.
 */
var json = (function() {
  function binaryOp(tfOpName, hasDtypeAttr) {
    var spec = {
      tfOpName: tfOpName,
      category: "arithmetic",
      inputs: [
        {start: 0, name: "a", type: "tensor"},
        {start: 1, name: "b", type: "tensor"}
      ]
    };
    if (hasDtypeAttr) {
      spec.attrs = [{tfName: "T", name: "dtype", type: "dtype", notSupported: true}];
    }
    return spec;
  }
  return [
    binaryOp("Add", true),
    binaryOp("AddV2", true),
    {
      tfOpName: "AddN",
      category: "arithmetic",
      inputs: [{start: 0, end: 0, name: "tensors", type: "tensors"}]
    },
    binaryOp("BiasAdd", true),
    binaryOp("Sub", true),
    binaryOp("RealDiv", true),
    binaryOp("Div", true),
    binaryOp("DivNoNan", true),
    binaryOp("FloorDiv", true),
    binaryOp("Mul", true),
    binaryOp("Maximum", false),
    binaryOp("Minimum", false),
    binaryOp("Pow", true),
    binaryOp("SquaredDifference", true),
    binaryOp("Mod", true),
    binaryOp("FloorMod", true)
  ];
})();
// Prototype-less namespace object exposing the table as `json`.
var arithmetic = Object.create(null);
arithmetic.json = json;
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Op-mapper definitions for the "basic_math" category. Most ops take a single
 * tensor input "x" and carry only the unsupported "T" dtype attribute, so they
 * are produced by small factories. The irregular ops are spelled out. Every
 * factory call returns fresh objects, so no two entries share nested state.
 */
var json$1 = (function() {
  var CATEGORY = "basic_math";
  function tensorInput(start, name) {
    return {start: start, name: name, type: "tensor"};
  }
  function dtypeAttr() {
    return {tfName: "T", name: "dtype", type: "dtype", notSupported: true};
  }
  function entry(tfOpName, inputs, attrs) {
    return {tfOpName: tfOpName, category: CATEGORY, inputs: inputs, attrs: attrs};
  }
  // Single tensor input "x"; only the unsupported "T" attribute.
  function unary(tfOpName) {
    return entry(tfOpName, [tensorInput(0, "x")], [dtypeAttr()]);
  }
  // Two tensor inputs at positions 0 and 1; only the unsupported "T" attribute.
  function binary(tfOpName, firstName, secondName) {
    return entry(tfOpName, [tensorInput(0, firstName), tensorInput(1, secondName)], [dtypeAttr()]);
  }
  // Imag/Real: like unary, plus an unsupported "Tout" output dtype attribute.
  function complexPart(tfOpName) {
    return entry(tfOpName, [tensorInput(0, "x")], [
      dtypeAttr(),
      {tfName: "Tout", name: "outputType", type: "dtype", notSupported: true}
    ]);
  }
  return [
    unary("Abs"),
    unary("Acos"),
    unary("Asin"),
    unary("Atan"),
    binary("Atan2", "x", "y"),
    unary("Ceil"),
    entry("ClipByValue", [tensorInput(0, "x")], [
      {tfName: "clip_value_min", name: "clipValueMin", type: "number"},
      {tfName: "clip_value_max", name: "clipValueMax", type: "number"}
    ]),
    binary("Complex", "real", "imag"),
    unary("ComplexAbs"),
    unary("Cos"),
    unary("Cosh"),
    unary("Elu"),
    unary("Exp"),
    unary("Floor"),
    unary("Log"),
    complexPart("Imag"),
    unary("Neg"),
    complexPart("Real"),
    binary("Prelu", "x", "alpha"),
    unary("Relu"),
    entry("Relu6", [tensorInput(0, "x")], [
      dtypeAttr(),
      {tfName: "clipValueMin", name: "clipValueMin", type: "number", defaultValue: 0},
      {tfName: "clipValueMax", name: "clipValueMax", type: "number", defaultValue: 6}
    ]),
    unary("Selu"),
    unary("Sigmoid"),
    unary("Sin"),
    unary("Sinh"),
    unary("Sqrt"),
    unary("Rsqrt"),
    unary("Square"),
    unary("Tan"),
    unary("Tanh"),
    unary("Sign"),
    unary("Round"),
    unary("Expm1"),
    unary("Log1p"),
    unary("Reciprocal"),
    unary("Softplus"),
    unary("Asinh"),
    unary("Acosh"),
    unary("Atanh"),
    unary("Erf"),
    entry("Prod", [tensorInput(0, "x"), {start: 1, name: "axes", type: "number[]"}], [
      {tfName: "keep_dims", name: "keepDims", type: "bool", notSupported: true},
      dtypeAttr()
    ]),
    entry("LeakyRelu", [tensorInput(0, "x")], [
      {tfName: "alpha", name: "alpha", type: "number", defaultValue: 0.2},
      dtypeAttr()
    ])
  ];
})();
// Prototype-less namespace object exposing the table as `json`.
var basicMath = Object.create(null);
basicMath.json = json$1;
var json$2 = [
{
tfOpName: "LoopCond",
category: "control",
inputs: [{start: 0, name: "pred", type: "tensor"}]
},
{
tfOpName: "Switch",
category: "control",
inputs: [
{start: 0, name: "data", type: "tensor"},
{start: 1, name: "pred", type: "tensor"}
]
},
{
tfOpName: "Merge",
category: "control",
inputs: [{start: 0, end: 0, name: "tensors", type: "tensors"}]
},
{
tfOpName: "Enter",
category: "control",
inputs: [
{start: 0, name: "tensor", type: "tensor"}
],
attrs: [
{tfName: "T", name: "dtype", type: "dtype", notSupported: true},
{tfName: "frame_name", name: "frameName", type: "string"},
{tfName: "is_constant", name: "isConstant", type: "bool"}
]
},
{
tfOpName: "Exit",
category: "control",
inputs: [
{start: 0, name: "tensor", type: "tensor"}
],
attrs: [
{tfName: "T", name: "dtype", type: "dtype", notSupported: true}
]
},
{
tfOpName: "NextIteration",
category: "control",
inputs: [
{start: 0, name: "tensor", type: "tensor"}
],
attrs: [
{tfName: "T", name: "dtype", type: "dtype", notSupported: true}
]
},
{
tfOpName: "TensorArrayV3",
category: "control",
inputs: [
{start: 0, name: "size", type: "number"}
],
attrs: [
{tfName: "dtype", name: "dtype", type: "dtype"},
{tfName: "element_shape", name: "elementShape", type: "shape"},
{tfName: "dynamic_size", name: "dynamicSize", type: "bool"},
{tfName: "clear_after_read", name: "clearAfterRead", type: "bool"},
{
tfName: "identical_element_shapes",
name: "identicalElementShapes",
type: "bool"
},
{tfName: "tensor_array_name", name: "name", type: "string"}
]
},
{
tfOpName: "TensorArrayWriteV3",
category: "control",
inputs: [
{start: 0, name: "tensorArrayId", type: "tensor"},
{start: 1, name: "index", type: "number"},
{start: 2, name: "tensor", type: "tensor"},
{start: 3, name: "flowIn", type: "number"}
],
attrs: [
{tfName: "T", name: "dtype", type: "dtype", notSupported: true}
]
},
{
tfOpName: "TensorArrayReadV3",
category: "control",
inputs: [
{start: 0, name: "tensorArrayId", type: "tensor"},
{start: 1, name: "index", type: "number"},
{start: 2, name: "flowIn", type: "number"}
],
attrs: [{
tfName: "dtype",
name: "dtype",
type: "dtype",
notSupported: true
}]
},
{
tfOpName: "TensorArrayGatherV3",
category: "control",
inputs: [
{start: 0, name: "tensorArrayId", type: "tensor"},
{start: 1, name: "indices", type: "number[]"},
{start: 2, name: "flowIn", type: "number"}
],
attrs: [
{tfName: "dtype", name: "dtype", type: "dtype"},
{tfName: "element_shape", name: "elementShape", type: "shape"}
]
},
{
tfOpName: "TensorArrayScatterV3",
category: "control",
inputs: [
{start: 0, name: "tensorArrayId", type: "tensor"},
{start: 1, name: "indices", type: "number[]"},
{start: 2, name: "tensor", type: "tensor"},
{start: 3, name: "flowIn", type: "number"}
],
attrs: [{tfName: "T", name: "dtype", type: "dtype"}]
},
{
tfOpName: "TensorArrayConcatV3",
category: "control",
inputs: [
{start: 0, name: "tensorArrayId", type: "tensor"},
{start: 1, name: "flowIn", type: "number"}
],
attrs: [
{tfName: "dtype", name: "dtype", type: "dtype"},
{
tfName: "element_shape_except0",
name: "elementShapeExcept0",
type: "shape",
notSupported: true
}
]
},
{
tfOpName: "TensorArraySplitV3",
category: "control",
inputs: [
{start: 0, name: "tensorArrayId", type: "tensor"},
{start: 1, name: "tensor", type: "tensor"},
{start: 2, name: "lengths", type: "number[]"},
{start: 3, name: "flowIn", type: "number"}
],
attrs: [{tfName: "T", name: "dtype", type: "dtype"}]
},
{
tfOpName: "TensorArraySizeV3",
category: "control",
inputs: [
{start: 0, name: "tensorArrayId", type: "tensor"},
{start: 1, name: "flowIn", type: "number"}
]
},
{
tfOpName: "TensorArrayCloseV3",
category: "control",
inputs: [{start: 0, name: "tensorArrayId", type: "tensor"}]
},
{
tfOpName: "StatelessIf",
category: "control",
inputs: [
{start: 0, name: "cond", type: "tensor"},
{start: 1, end: 0, name: "args", type: "tensors"}
],
attrs: [
{tfName: "then_branch", name: "thenBranch", type: "func"},
{tfName: "else_branch", name: "elseBranch", type: "func"}
]
},
{
tfOpName: "If",
category: "control",
inputs: [
{start: 0, name: "cond", type: "tensor"},
{start: 1, end: 0, name: "args", type: "tensors"}
],
attrs: [
{tfName: "then_branch", name: "thenBranch", type: "func"},
{tfName: "else_branch", name: "elseBranch", type: "func"}
]
},
{
tfOpName: "StatelessWhile",
category: "control",
inputs: [
{start: 0, end: 0, name: "args", type: "tensors"}
],
attrs: [
{tfName: "cond", name: "cond", type: "func"},
{tfName: "body", name: "body", type: "func"}
]
},
{
tfOpName: "While",
category: "control",
inputs: [
{start: 0, end: 0, name: "args", type: "tensors"}
],
attrs: [
{tfName: "cond", name: "cond", type: "func"},
{tfName: "body", name: "body", type: "func"}
]
},
{
tfOpName: "TensorListScatter",
category: "control",
inputs: [
{start: 0, name: "tensor", type: "tensor"},
{start: 1, name: "indices", type: "number[]"},
{start: 2, name: "elementShape", type: "shape"}
],
attrs: [{tfName: "element_dtype", name: "elementDType", type: "dtype"}]
},
{
tfOpName: "TensorListScatterV2",
category: "control",
inputs: [
{start: 0, name: "tensor", type: "tensor"},
{start: 1, name: "indices", type: "number[]"},
{start: 2, name: "elementShape", type: "shape"},
{start: 3, name: "numElements", type: "number"}
],
attrs: [{tfName: "element_dtype", name: "elementDType", type: "dtype"}]
},
{
tfOpName: "TensorListGather",
category: "control",
inputs: [
{start: 0, name: "tensorListId", type: "tensor"},
{start: 1, name: "indices", type: "number[]"},
{start: 2, name: "elementShape", type: "shape"}
],
attrs: [{tfName: "element_dtype", name: "elementDType", type: "dtype"}]
},
{
tfOpName: "TensorListGetItem",
category: "control",
inputs: [
{start: 0, name: "tensorListId", type: "tensor"},
{start: 1, name: "index", type: "number"},
{start: 2, name: "elementShape", type: "shape"}
],
attrs: [{tfName: "element_dtype", name: "elementDType", type: "dtype"}]
},
{
tfOpName: "TensorListSetItem",
category: "control",
inputs: [
{start: 0, name: "tensorListId", type: "tensor"},
{start: 1, name: "index", type: "number"},
{start: 2, name: "tensor", type: "tensor"}
],
attrs: [{tfName: "element_dtype", name: "elementDType", type: "dtype"}]
},
{
tfOpName: "TensorListReserve",
category: "control",
inputs: [
{start: 0, name: "elementShape", type: "shape"},
{start: 1, name: "numElements", type: "number"}
],
attrs: [{tfName: "element_dtype", name: "elementDType", type: "dtype"}]
},
{
tfOpName: "TensorListFromTensor",
category: "control",
inputs: [
{start: 0, name: "tensor", type: "tensor"},
{start: 1, name: "elementShape", type: "shape"}
],
attrs: [{tfName: "element_dtype", name: "elementDType", type: "dtype"}]
},
{
tfOpName: "TensorListStack",
category: "control",
inputs: [
{start: 0, name: "tensorListId", type: "tensor"},
{start: 1, name: "elementShape", type: "shape"}
],
attrs: [
{tfName: "element_dtype", name: "elementDType", type: "dtype"},
{tfName: "num_elements", name: "numElements", type: "dtype"}
]
},
{
tfOpName: "TensorListSplit",
category: "control",
inputs: [
{start: 0, name: "tensor", type: "tensor"},
{start: 1, name: "elementShape", type: "shape"},
{start: 2, name: "lengths", type: "number[]"}
],
attrs: [{tfName: "element_dtype", name: "elementDType", type: "dtype"}]
},
{
tfOpName: "TensorListConcat",
category: "control",
inputs: [
{start: 0, name: "tensorListId", type: "tensor"}
],
attrs: [
{tfName: "element_shape", name: "elementShape", type: "shape"},
{tfName: "element_dtype", name: "elementDType", type: "dtype"}
]
},
{
tfOpName: "TensorListPopBack",
category: "control",
inputs: [
{start: 0, name: "tensorListId", type: "tensor"},
{start: 1, name: "elementShape", type: "shape"}
],
attrs: [{tfName: "element_dtype", name: "elementDType", type: "dtype"}]
},
{
tfOpName: "TensorListPushBack",
category: "control",
inputs: [
{start: 0, name: "tensorListId", type: "tensor"},
{start: 1, name: "tensor", type: "tensor"}
],
attrs: [
{tfName: "element_dtype", name: "elementDType", type: "dtype"}
]
}
];
// Namespace object for the "control" op-mapper category: a prototype-less
// holder whose single own property `json` is the control-flow op table.
var control = Object.assign(Object.create(null), {json: json$2});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Op-mapper table for the "convolution" category (pooling, 1-D/2-D/3-D and
 * depthwise convolutions, their fused variants, Dilation2D). Each entry names
 * a TensorFlow op (`tfOpName`), its positional inputs (`start`, plus `end` for
 * variadic "tensors" params) and the node attributes it reads
 * (`tfName` -> `name`, with an optional `defaultValue`). Exposed below as
 * `convolution.json`.
 */
var json$3 = [
  {
    tfOpName: "AvgPool",
    category: "convolution",
    inputs: [
      {start: 0, name: "x", type: "tensor"}
    ],
    attrs: [
      {tfName: "strides", name: "strides", type: "number[]"},
      {tfName: "padding", name: "pad", type: "string"},
      {tfName: "data_format", name: "dataFormat", type: "string", notSupported: true},
      {tfName: "ksize", name: "kernelSize", type: "number[]"},
      {tfName: "T", name: "dtype", type: "dtype", notSupported: true}
    ]
  },
  {
    tfOpName: "MaxPool",
    category: "convolution",
    inputs: [
      {start: 0, name: "x", type: "tensor"}
    ],
    attrs: [
      {tfName: "strides", name: "strides", type: "number[]"},
      {tfName: "padding", name: "pad", type: "string"},
      {tfName: "data_format", name: "dataFormat", type: "string", notSupported: true},
      {tfName: "ksize", name: "kernelSize", type: "number[]"},
      {tfName: "T", name: "dtype", type: "dtype", notSupported: true}
    ]
  },
  {
    tfOpName: "MaxPoolWithArgmax",
    category: "convolution",
    inputs: [
      {start: 0, name: "x", type: "tensor"}
    ],
    attrs: [
      {tfName: "strides", name: "strides", type: "number[]"},
      {tfName: "padding", name: "pad", type: "string"},
      {tfName: "ksize", name: "kernelSize", type: "number[]"},
      {tfName: "include_batch_in_index", name: "includeBatchInIndex", type: "bool"},
      {tfName: "T", name: "dtype", type: "dtype", notSupported: true}
    ]
  },
  {
    tfOpName: "AvgPool3D",
    category: "convolution",
    inputs: [
      {start: 0, name: "x", type: "tensor"}
    ],
    attrs: [
      {tfName: "strides", name: "strides", type: "number[]"},
      {tfName: "padding", name: "pad", type: "string"},
      {tfName: "data_format", name: "dataFormat", type: "string", notSupported: true},
      {tfName: "ksize", name: "kernelSize", type: "number[]"},
      {tfName: "T", name: "dtype", type: "dtype", notSupported: true}
    ]
  },
  {
    tfOpName: "MaxPool3D",
    category: "convolution",
    inputs: [
      {start: 0, name: "x", type: "tensor"}
    ],
    attrs: [
      {tfName: "strides", name: "strides", type: "number[]"},
      {tfName: "padding", name: "pad", type: "string"},
      {tfName: "data_format", name: "dataFormat", type: "string", notSupported: true},
      {tfName: "ksize", name: "kernelSize", type: "number[]"},
      {tfName: "T", name: "dtype", type: "dtype", notSupported: true}
    ]
  },
  {
    tfOpName: "Conv1D",
    category: "convolution",
    inputs: [
      {start: 0, name: "x", type: "tensor"},
      {start: 1, name: "filter", type: "tensor"}
    ],
    attrs: [
      {tfName: "stride", name: "stride", type: "number"},
      {tfName: "padding", name: "pad", type: "string"},
      {tfName: "data_format", name: "dataFormat", type: "string", defaultValue: "NWC"},
      {tfName: "T", name: "dtype", type: "dtype", notSupported: true},
      {tfName: "dilation", name: "dilation", type: "number", defaultValue: 1}
    ]
  },
  {
    tfOpName: "Conv2D",
    category: "convolution",
    inputs: [
      {start: 0, name: "x", type: "tensor"},
      {start: 1, name: "filter", type: "tensor"}
    ],
    attrs: [
      {tfName: "T", name: "dtype", type: "dtype", notSupported: true},
      {tfName: "strides", name: "strides", type: "number[]"},
      {tfName: "padding", name: "pad", type: "string"},
      // TensorFlow spells this attr in snake_case, as the _FusedConv2D entry
      // below already does; the former camelCase key "useCudnnOnGpu" is not a
      // TensorFlow attr name and so could never match a node attribute.
      {tfName: "use_cudnn_on_gpu", name: "useCudnnOnGpu", type: "bool"},
      {tfName: "data_format", name: "dataFormat", type: "string", defaultValue: "NHWC"},
      {tfName: "explicit_paddings", name: "explicitPaddings", type: "number[]", defaultValue: []},
      {tfName: "dilations", name: "dilations", type: "number[]"}
    ]
  },
  {
    tfOpName: "_FusedConv2D",
    category: "convolution",
    inputs: [
      {start: 0, name: "x", type: "tensor"},
      {start: 1, name: "filter", type: "tensor"},
      {start: 2, end: 0, name: "args", type: "tensors"}
    ],
    attrs: [
      {tfName: "num_args", name: "numArgs", type: "number"},
      {tfName: "T", name: "dtype", type: "dtype", notSupported: true},
      {tfName: "strides", name: "strides", type: "number[]"},
      {tfName: "padding", name: "pad", type: "string"},
      {tfName: "explicit_paddings", name: "explicitPaddings", type: "number[]", defaultValue: []},
      {tfName: "use_cudnn_on_gpu", name: "useCudnnOnGpu", type: "bool", defaultValue: true},
      {tfName: "data_format", name: "dataFormat", type: "string", defaultValue: "NHWC"},
      {tfName: "dilations", name: "dilations", type: "number[]", defaultValue: [1, 1, 1, 1]},
      {tfName: "fused_ops", name: "fusedOps", type: "string[]", defaultValue: []},
      {tfName: "epsilon", name: "epsilon", type: "number", defaultValue: 1e-4}
    ]
  },
  {
    tfOpName: "Conv2DBackpropInput",
    category: "convolution",
    // Input positions follow TensorFlow's order: output shape, filter, gradient.
    inputs: [
      {start: 2, name: "x", type: "tensor"},
      {start: 1, name: "filter", type: "tensor"},
      {start: 0, name: "outputShape", type: "number[]"}
    ],
    attrs: [
      {tfName: "strides", name: "strides", type: "number[]"},
      {tfName: "padding", name: "pad", type: "string"},
      {tfName: "data_format", name: "dataFormat", type: "string", notSupported: true},
      {tfName: "explicit_paddings", name: "explicitPaddings", type: "number[]", defaultValue: []}
    ]
  },
  {
    tfOpName: "DepthwiseConv2d",
    category: "convolution",
    inputs: [
      {start: 0, name: "input", type: "tensor"},
      {start: 1, name: "filter", type: "tensor"}
    ],
    attrs: [
      {tfName: "strides", name: "strides", type: "number[]"},
      {tfName: "padding", name: "pad", type: "string"},
      {tfName: "data_format", name: "dataFormat", type: "string", defaultValue: "NHWC"},
      {tfName: "explicit_paddings", name: "explicitPaddings", type: "number[]", defaultValue: []},
      {tfName: "dilations", name: "dilations", type: "number[]"}
    ]
  },
  {
    tfOpName: "DepthwiseConv2dNative",
    category: "convolution",
    inputs: [
      {start: 0, name: "input", type: "tensor"},
      {start: 1, name: "filter", type: "tensor"}
    ],
    attrs: [
      {tfName: "strides", name: "strides", type: "number[]"},
      {tfName: "padding", name: "pad", type: "string"},
      {tfName: "data_format", name: "dataFormat", type: "string", defaultValue: "NHWC"},
      {tfName: "explicit_paddings", name: "explicitPaddings", type: "number[]", defaultValue: []},
      {tfName: "dilations", name: "dilations", type: "number[]"}
    ]
  },
  {
    tfOpName: "FusedDepthwiseConv2dNative",
    category: "convolution",
    inputs: [
      {start: 0, name: "x", type: "tensor"},
      {start: 1, name: "filter", type: "tensor"},
      {start: 2, end: 0, name: "args", type: "tensors"}
    ],
    attrs: [
      {tfName: "num_args", name: "numArgs", type: "number"},
      {tfName: "T", name: "dtype", type: "dtype", notSupported: true},
      {tfName: "strides", name: "strides", type: "number[]"},
      {tfName: "padding", name: "pad", type: "string"},
      {tfName: "data_format", name: "dataFormat", type: "string", defaultValue: "NHWC"},
      {tfName: "dilations", name: "dilations", type: "number[]", defaultValue: [1, 1, 1, 1]},
      {tfName: "fused_ops", name: "fusedOps", type: "string[]", defaultValue: []}
    ]
  },
  {
    tfOpName: "Conv3D",
    category: "convolution",
    inputs: [
      {start: 0, name: "x", type: "tensor"},
      {start: 1, name: "filter", type: "tensor"}
    ],
    attrs: [
      {tfName: "strides", name: "strides", type: "number[]"},
      {tfName: "padding", name: "pad", type: "string"},
      // TensorFlow's Conv3D registers data_format with default "NDHWC" (a
      // 5-D layout). The former "NHWC" default is a 4-D layout and does not
      // describe a Conv3D input.
      {tfName: "data_format", name: "dataFormat", type: "string", defaultValue: "NDHWC"},
      {tfName: "dilations", name: "dilations", type: "number[]"}
    ]
  },
  {
    tfOpName: "Dilation2D",
    category: "convolution",
    inputs: [
      {start: 0, name: "x", type: "tensor"},
      {start: 1, name: "filter", type: "tensor"}
    ],
    attrs: [
      {tfName: "strides", name: "strides", type: "number[]"},
      // TensorFlow calls this attr "rates"; it is exposed as "dilations".
      {tfName: "rates", name: "dilations", type: "number[]"},
      {tfName: "padding", name: "pad", type: "string"}
    ]
  }
];
// Prototype-less namespace object exposing the table as `convolution.json`.
var convolution = {
  __proto__: null,
  json: json$3
};
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Op-mapper table for the "creation" category (Fill, LinSpace, OneHot,
 * Ones/OnesLike, RandomUniform, Range, TruncatedNormal, Zeros/ZerosLike,
 * Multinomial). Each entry lists the op's positional inputs (`start`) and the
 * node attributes it reads (`tfName` -> `name`, with an optional
 * `defaultValue`). Exposed below as `creation.json`.
 */
var json$4 = [
  {
    tfOpName: "Fill",
    category: "creation",
    inputs: [
      {start: 0, name: "shape", type: "number[]"},
      {start: 1, name: "value", type: "number"}
    ],
    attrs: [{tfName: "T", name: "dtype", type: "dtype"}]
  },
  {
    tfOpName: "LinSpace",
    category: "creation",
    inputs: [
      {start: 0, name: "start", type: "number"},
      {start: 1, name: "stop", type: "number"},
      {start: 2, name: "num", type: "number"}
    ],
    attrs: [
      {tfName: "T", name: "dtype", type: "dtype", notSupported: true}
    ]
  },
  {
    tfOpName: "OneHot",
    category: "creation",
    inputs: [
      {start: 0, name: "indices", type: "tensor"},
      {start: 1, name: "depth", type: "number"},
      {start: 2, name: "onValue", type: "number", defaultValue: 1},
      {start: 3, name: "offValue", type: "number", defaultValue: 0}
    ],
    attrs: [
      {tfName: "axis", name: "axis", type: "number", notSupported: true},
      {tfName: "T", name: "dtype", type: "dtype", notSupported: true}
    ]
  },
  {
    tfOpName: "Ones",
    category: "creation",
    inputs: [
      {start: 0, name: "shape", type: "number[]"}
    ],
    attrs: [{tfName: "T", name: "dtype", type: "dtype"}]
  },
  {
    tfOpName: "OnesLike",
    category: "creation",
    inputs: [
      {start: 0, name: "x", type: "tensor"}
    ],
    // TensorFlow's OnesLike, like ZerosLike below, carries its element type in
    // the attr "T" and has no attr named "dtype", so the former
    // tfName "dtype" could never match a node attribute.
    attrs: [{tfName: "T", name: "dtype", type: "dtype"}]
  },
  {
    tfOpName: "RandomUniform",
    category: "creation",
    inputs: [
      {start: 0, name: "shape", type: "number[]"}
    ],
    attrs: [
      {tfName: "minval", name: "minval", type: "number", defaultValue: 0},
      {tfName: "maxval", name: "maxval", type: "number", defaultValue: 1},
      {tfName: "dtype", name: "dtype", type: "dtype"},
      {tfName: "seed", name: "seed", type: "number", defaultValue: 0},
      {tfName: "seed2", name: "seed2", type: "number", defaultValue: 0, notSupported: true},
      // NOTE(review): TensorFlow's "T" is a type attr but is declared "number"
      // here (and in TruncatedNormal); it is flagged notSupported — confirm
      // nothing reads it.
      {tfName: "T", name: "T", type: "number", notSupported: true}
    ]
  },
  {
    tfOpName: "Range",
    category: "creation",
    inputs: [
      {start: 0, name: "start", type: "number"},
      {start: 1, name: "stop", type: "number"},
      {start: 2, name: "step", type: "number", defaultValue: 0}
    ],
    attrs: [{tfName: "Tidx", name: "dtype", type: "dtype"}]
  },
  {
    tfOpName: "TruncatedNormal",
    category: "creation",
    inputs: [
      {start: 0, name: "shape", type: "number[]"}
    ],
    attrs: [
      // NOTE(review): TensorFlow's TruncatedNormal registers no "means" or
      // "stddev" attrs, so these presumably always resolve to their
      // defaultValue — verify before relying on them.
      {tfName: "means", name: "mean", type: "number", defaultValue: 0},
      {tfName: "stddev", name: "stdDev", type: "number", defaultValue: 1},
      {tfName: "seed", name: "seed", type: "number"},
      {tfName: "seed2", name: "seed2", type: "number", defaultValue: 0, notSupported: true},
      {tfName: "dtype", name: "dtype", type: "dtype"},
      {tfName: "T", name: "T", type: "number", notSupported: true}
    ]
  },
  {
    tfOpName: "Zeros",
    category: "creation",
    inputs: [
      {start: 0, name: "shape", type: "number[]"}
    ],
    attrs: [{tfName: "T", name: "dtype", type: "dtype"}]
  },
  {
    tfOpName: "ZerosLike",
    category: "creation",
    inputs: [
      {start: 0, name: "x", type: "tensor"}
    ],
    attrs: [{tfName: "T", name: "dtype", type: "dtype"}]
  },
  {
    tfOpName: "Multinomial",
    category: "creation",
    inputs: [
      {start: 0, name: "logits", type: "tensor"},
      {start: 1, name: "numSamples", type: "number"}
    ],
    attrs: [
      {tfName: "seed", name: "seed", type: "number"},
      {tfName: "seed2", name: "seed2", type: "number"},
      {tfName: "T", name: "dtype", type: "dtype"},
      {tfName: "output_dtype", name: "output_dtype", type: "dtype"}
    ]
  }
];
// Prototype-less namespace object exposing the table as `creation.json`.
var creation = {
  __proto__: null,
  json: json$4
};
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Op-mapper table for the "dynamic" category: the NonMaxSuppression family
 * (V2-V5) plus Where and ListDiff. Exposed below as `dynamic.json`.
 */
var json$5 = (() => {
  // Each NonMaxSuppression version takes a prefix of this input list; every
  // newer version appends one more scalar.
  const nmsInputSpecs = [
    ["boxes", "tensor"],
    ["scores", "tensor"],
    ["maxOutputSize", "number"],
    ["iouThreshold", "number"],
    ["scoreThreshold", "number"],
    ["softNmsSigma", "number"]
  ];
  const nmsInputs = (count) =>
    nmsInputSpecs.slice(0, count).map(([name, type], start) => ({start, name, type}));
  // Fresh object per call so entries never share nested attr objects.
  const unsupportedDtype = () => ({tfName: "T", name: "dtype", type: "dtype", notSupported: true});

  const nmsV4 = {
    tfOpName: "NonMaxSuppressionV4",
    category: "dynamic",
    inputs: nmsInputs(5),
    attrs: [
      unsupportedDtype(),
      {tfName: "T_threshold", name: "threshold", type: "dtype", notSupported: true},
      {tfName: "pad_to_max_output_size", name: "padToMaxOutputSize", type: "bool"}
    ]
  };
  const where = {
    tfOpName: "Where",
    category: "dynamic",
    inputs: [{start: 0, name: "condition", type: "tensor"}],
    attrs: [unsupportedDtype()]
  };
  const listDiff = {
    tfOpName: "ListDiff",
    category: "dynamic",
    inputs: [
      {start: 0, name: "x", type: "tensor"},
      {start: 1, name: "y", type: "tensor"}
    ],
    attrs: [unsupportedDtype()]
  };

  return [
    {tfOpName: "NonMaxSuppressionV2", category: "dynamic", inputs: nmsInputs(4)},
    {tfOpName: "NonMaxSuppressionV3", category: "dynamic", inputs: nmsInputs(5)},
    nmsV4,
    {tfOpName: "NonMaxSuppressionV5", category: "dynamic", inputs: nmsInputs(6)},
    where,
    listDiff
  ];
})();
// Prototype-less namespace object exposing the table as `dynamic.json`.
var dynamic = Object.assign(Object.create(null), {json: json$5});
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Op-mapper table for the "evaluation" category (TopKV2, Unique, UniqueV2).
 * Exposed below as `evaluation.json`.
 */
var json$6 = (() => {
  // Every op here reads its primary tensor from input 0 under the name "x".
  const tensorX = () => ({start: 0, name: "x", type: "tensor"});
  const topK = {
    tfOpName: "TopKV2",
    category: "evaluation",
    inputs: [tensorX(), {start: 1, name: "k", type: "number"}],
    attrs: [{tfName: "sorted", name: "sorted", type: "bool"}]
  };
  const unique = {tfOpName: "Unique", category: "evaluation", inputs: [tensorX()]};
  const uniqueV2 = {
    tfOpName: "UniqueV2",
    category: "evaluation",
    inputs: [tensorX(), {start: 1, name: "axis", type: "number"}]
  };
  return [topK, unique, uniqueV2];
})();
// Prototype-less namespace object exposing the table as `evaluation.json`.
var evaluation = Object.assign(Object.create(null), {json: json$6});
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Op-mapper table for the "graph" category (placeholders, Const, identity and
 * shape-inspection ops, Print, NoOp, StopGradient, FakeQuantWithMinMaxVars).
 * Exposed below as `graph.json`.
 */
var json$7 = (() => {
  const CATEGORY = "graph";
  // Ops with a single tensor input at position 0 named "x".
  const unary = (tfOpName) => ({
    tfOpName,
    category: CATEGORY,
    inputs: [{start: 0, name: "x", type: "tensor"}]
  });
  // Ops whose inputs are gathered into one "tensors" list named "x".
  const variadic = (tfOpName) => ({
    tfOpName,
    category: CATEGORY,
    inputs: [{start: 0, end: 0, name: "x", type: "tensors"}]
  });
  const shapeAndDtype = () => [
    {tfName: "shape", name: "shape", type: "shape"},
    {tfName: "dtype", name: "dtype", type: "dtype"}
  ];

  const placeholderWithDefault = {
    tfOpName: "PlaceholderWithDefault",
    category: CATEGORY,
    inputs: [{start: 0, name: "default", type: "tensor"}],
    attrs: shapeAndDtype()
  };
  const print = {
    tfOpName: "Print",
    category: CATEGORY,
    inputs: [
      {start: 0, name: "x", type: "tensor"},
      {start: 1, name: "data", type: "tensors"}
    ],
    attrs: [
      {tfName: "message", name: "message", type: "string"},
      {tfName: "first_n", name: "firstN", type: "number", notSupported: true},
      {tfName: "summarize", name: "summarize", type: "number", defaultValue: 3}
    ]
  };
  const fakeQuant = {
    tfOpName: "FakeQuantWithMinMaxVars",
    category: CATEGORY,
    inputs: [{start: 0, name: "x", type: "tensor"}],
    attrs: [
      {tfName: "min", name: "min", type: "number"},
      {tfName: "max", name: "max", type: "number"}
    ]
  };

  return [
    placeholderWithDefault,
    {tfOpName: "Placeholder", category: CATEGORY, attrs: shapeAndDtype()},
    {tfOpName: "Const", category: CATEGORY},
    unary("Identity"),
    variadic("IdentityN"),
    unary("Snapshot"),
    unary("Rank"),
    unary("Size"),
    unary("Shape"),
    variadic("ShapeN"),
    print,
    {tfOpName: "NoOp", category: CATEGORY, inputs: []},
    unary("StopGradient"),
    fakeQuant
  ];
})();
// Prototype-less namespace object exposing the table as `graph.json`.
var graph = Object.assign(Object.create(null), {json: json$7});
/**
 * Op-mapper table for the "hash_table" category: table constructors
 * (HashTable, HashTableV2) and the import/find lookups in both versions.
 * Exposed below as `hashTable.json`.
 */
var json$8 = (() => {
  const CATEGORY = "hash_table";
  // Attributes shared by both table constructors.
  const tableAttrs = () => [
    {tfName: "shared_name", name: "sharedName", type: "string"},
    {tfName: "use_node_name_sharing", name: "useNodeNameSharing", type: "bool"},
    {tfName: "key_dtype", name: "keyDType", type: "dtype"},
    {tfName: "value_dtype", name: "valueDType", type: "dtype"}
  ];
  // Attributes shared by every lookup op (both flagged notSupported).
  const lookupAttrs = () => [
    {tfName: "Tin", name: "tIn", type: "dtype", notSupported: true},
    {tfName: "Tout", name: "tOut", type: "dtype", notSupported: true}
  ];
  // Lookups take (tableHandle, keys, <third>); only the third name varies.
  const lookupInputs = (thirdName) => [
    {start: 0, name: "tableHandle", type: "tensor"},
    {start: 1, name: "keys", type: "tensor"},
    {start: 2, name: thirdName, type: "tensor"}
  ];
  const tableOp = (tfOpName) => ({tfOpName, category: CATEGORY, inputs: [], attrs: tableAttrs()});
  const lookupOp = (tfOpName, thirdName) => ({
    tfOpName,
    category: CATEGORY,
    inputs: lookupInputs(thirdName),
    attrs: lookupAttrs()
  });

  return [
    tableOp("HashTable"),
    tableOp("HashTableV2"),
    lookupOp("LookupTableImport", "values"),
    lookupOp("LookupTableImportV2", "values"),
    lookupOp("LookupTableFind", "defaultValue"),
    lookupOp("LookupTableFindV2", "defaultValue")
  ];
})();
// Prototype-less namespace object exposing the table as `hashTable.json`.
var hashTable = Object.assign(Object.create(null), {json: json$8});
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Op-mapper table for the "image" category (ResizeBilinear,
 * ResizeNearestNeighbor, CropAndResize). Exposed below as `image.json`.
 */
var json$9 = (() => {
  // Both resize ops share an identical (images, size) signature and attrs.
  const resizeOp = (tfOpName) => ({
    tfOpName,
    category: "image",
    inputs: [
      {start: 0, name: "images", type: "tensor"},
      {start: 1, name: "size", type: "number[]"}
    ],
    attrs: [
      {tfName: "align_corners", name: "alignCorners", type: "bool"},
      {tfName: "T", name: "dtype", type: "dtype", notSupported: true}
    ]
  });
  const cropInputSpecs = [
    ["image", "tensor"],
    ["boxes", "tensor"],
    ["boxInd", "tensor"],
    ["cropSize", "number[]"]
  ];
  const cropAndResize = {
    tfOpName: "CropAndResize",
    category: "image",
    inputs: cropInputSpecs.map(([name, type], start) => ({start, name, type})),
    attrs: [
      {tfName: "method", name: "method", type: "string"},
      {tfName: "extrapolation_value", name: "extrapolationValue", type: "number"}
    ]
  };
  return [resizeOp("ResizeBilinear"), resizeOp("ResizeNearestNeighbor"), cropAndResize];
})();
// Prototype-less namespace object exposing the table as `image.json`.
var image = Object.assign(Object.create(null), {json: json$9});
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Op-mapper table for the "logical" category: comparisons, logical
 * and/not/or, and Select/SelectV2. Every op takes only tensor inputs and
 * lists an unsupported "T" dtype attr. Exposed below as `logical.json`.
 */
var json$a = (() => {
  // Number tensor inputs positionally in the order their names are given.
  const tensorInputs = (names) => names.map((name, start) => ({start, name, type: "tensor"}));
  const logicalOp = (tfOpName, inputNames) => ({
    tfOpName,
    category: "logical",
    inputs: tensorInputs(inputNames),
    attrs: [{tfName: "T", name: "dtype", type: "dtype", notSupported: true}]
  });
  const binaryOps = ["Equal", "NotEqual", "Greater", "GreaterEqual", "Less", "LessEqual", "LogicalAnd"];
  return [
    ...binaryOps.map((tfOpName) => logicalOp(tfOpName, ["a", "b"])),
    logicalOp("LogicalNot", ["a"]),
    logicalOp("LogicalOr", ["a", "b"]),
    logicalOp("Select", ["condition", "a", "b"]),
    logicalOp("SelectV2", ["condition", "a", "b"])
  ];
})();
// Prototype-less namespace object exposing the table as `logical.json`.
var logical = Object.assign(Object.create(null), {json: json$a});
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Op-mapper table for the "matrices" category (_FusedMatMul, MatMul,
 * BatchMatMul, BatchMatMulV2, Transpose). Exposed below as `matrices.json`.
 */
var json$b = (() => {
  const CATEGORY = "matrices";
  const operandsAB = () => [
    {start: 0, name: "a", type: "tensor"},
    {start: 1, name: "b", type: "tensor"}
  ];
  // Boolean transpose flags all default to false.
  const flag = (tfName, name) => ({tfName, name, type: "bool", defaultValue: false});
  const unsupportedDtype = () => ({tfName: "T", name: "dtype", type: "dtype", notSupported: true});
  // BatchMatMul variants spell the transpose flags adj_x / adj_y.
  const batchMatMul = (tfOpName) => ({
    tfOpName,
    category: CATEGORY,
    inputs: operandsAB(),
    attrs: [flag("adj_x", "transposeA"), flag("adj_y", "transposeB"), unsupportedDtype()]
  });

  const fusedMatMul = {
    tfOpName: "_FusedMatMul",
    category: CATEGORY,
    inputs: [...operandsAB(), {start: 2, end: 0, name: "args", type: "tensors"}],
    attrs: [
      {tfName: "num_args", name: "numArgs", type: "number"},
      {tfName: "fused_ops", name: "fusedOps", type: "string[]", defaultValue: []},
      {tfName: "epsilon", name: "epsilon", type: "number", defaultValue: 1e-4},
      flag("transpose_a", "transposeA"),
      flag("transpose_b", "transposeB"),
      unsupportedDtype()
    ]
  };
  const matMul = {
    tfOpName: "MatMul",
    category: CATEGORY,
    inputs: operandsAB(),
    attrs: [flag("transpose_a", "transposeA"), flag("transpose_b", "transposeB"), unsupportedDtype()]
  };
  const transpose = {
    tfOpName: "Transpose",
    category: CATEGORY,
    inputs: [
      {start: 0, name: "x", type: "tensor"},
      {start: 1, name: "perm", type: "number[]"}
    ],
    attrs: [unsupportedDtype()]
  };

  return [fusedMatMul, matMul, batchMatMul("BatchMatMul"), batchMatMul("BatchMatMulV2"), transpose];
})();
// Prototype-less namespace object exposing the table as `matrices.json`.
var matrices = Object.assign(Object.create(null), {json: json$b});
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Op-mapper table for the "normalization" category (the FusedBatchNorm
 * family, LRN, Softmax, LogSoftmax, SparseToDense). Each entry lists the op's
 * positional inputs (`start`) and the node attributes it reads
 * (`tfName` -> `name`, with an optional `defaultValue`). Exposed below as
 * `normalization.json`.
 */
var json$c = (() => {
  // Registered default of the `epsilon` attr on TensorFlow's FusedBatchNorm,
  // FusedBatchNormV2 and FusedBatchNormV3 ops (`epsilon: float = 0.0001`).
  // The former 1e-3 disagreed with TensorFlow for any node that omits the
  // attr and presumably falls back to this defaultValue.
  const TF_FUSED_BATCH_NORM_EPSILON = 1e-4;
  // All three FusedBatchNorm versions share inputs and mapped attrs; build a
  // fresh entry per op so no nested objects are shared between entries.
  const fusedBatchNorm = (tfOpName) => ({
    tfOpName,
    category: "normalization",
    inputs: [
      {start: 0, name: "x", type: "tensor"},
      {start: 1, name: "scale", type: "tensor"},
      {start: 2, name: "offset", type: "tensor"},
      {start: 3, name: "mean", type: "tensor"},
      {start: 4, name: "variance", type: "tensor"}
    ],
    attrs: [
      {tfName: "epsilon", name: "epsilon", type: "number", defaultValue: TF_FUSED_BATCH_NORM_EPSILON},
      {tfName: "data_format", name: "dataFormat", type: "string", notSupported: true}
    ]
  });

  return [
    fusedBatchNorm("FusedBatchNorm"),
    fusedBatchNorm("FusedBatchNormV2"),
    fusedBatchNorm("FusedBatchNormV3"),
    {
      tfOpName: "LRN",
      category: "normalization",
      inputs: [
        {start: 0, name: "x", type: "tensor"}
      ],
      attrs: [
        {tfName: "depth_radius", name: "radius", type: "number", defaultValue: 5},
        {tfName: "bias", name: "bias", type: "number", defaultValue: 1},
        {tfName: "alpha", name: "alpha", type: "number", defaultValue: 1},
        {tfName: "beta", name: "beta", type: "number", defaultValue: 0.5}
      ]
    },
    {
      tfOpName: "Softmax",
      category: "normalization",
      inputs: [{start: 0, name: "x", type: "tensor"}]
    },
    {
      tfOpName: "LogSoftmax",
      category: "normalization",
      inputs: [{start: 0, name: "x", type: "tensor"}]
    },
    {
      tfOpName: "SparseToDense",
      category: "normalization",
      inputs: [
        {start: 0, name: "sparseIndices", type: "tensor"},
        {start: 1, name: "outputShape", type: "number[]"},
        {start: 2, name: "sparseValues", type: "tensor"},
        {start: 3, name: "defaultValue", type: "tensor"}
      ],
      attrs: [{
        tfName: "validate_indices",
        name: "validateIndices",
        type: "bool",
        defaultValue: true,
        notSupported: true
      }]
    }
  ];
})();
// Prototype-less namespace object exposing the table as `normalization.json`.
var normalization = {
  __proto__: null,
  json: json$c
};
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
// Op-mapper metadata for TF reduction ops. Each entry gives the TF op name
// (`tfOpName`), its `category`, the positional `inputs` (`start` index, param
// name, type) and the node `attrs` (TF attr name -> param name/type) that
// OperationMapper.mapNode turns into a node's inputParams/attrParams.
// `keep_dims` and Cumsum's `exclusive`/`reverse` declare no defaultValue, so a
// node that lacks those attrs gets `undefined` for them.
var json$d = [
{
tfOpName: "Max",
category: "reduction",
inputs: [
{start: 0, name: "x", type: "tensor"},
{start: 1, name: "axis", type: "number[]"}
],
attrs: [{tfName: "keep_dims", name: "keepDims", type: "bool"}]
},
{
tfOpName: "Mean",
category: "reduction",
inputs: [
{start: 0, name: "x", type: "tensor"},
{start: 1, name: "axis", type: "number[]"}
],
attrs: [{tfName: "keep_dims", name: "keepDims", type: "bool"}]
},
{
tfOpName: "Min",
category: "reduction",
inputs: [
{start: 0, name: "x", type: "tensor"},
{start: 1, name: "axis", type: "number[]"}
],
attrs: [{tfName: "keep_dims", name: "keepDims", type: "bool"}]
},
{
tfOpName: "Sum",
category: "reduction",
inputs: [
{start: 0, name: "x", type: "tensor"},
{start: 1, name: "axis", type: "number[]"}
],
attrs: [{tfName: "keep_dims", name: "keepDims", type: "bool"}]
},
{
tfOpName: "All",
category: "reduction",
inputs: [
{start: 0, name: "x", type: "tensor"},
{start: 1, name: "axis", type: "number[]"}
],
attrs: [{tfName: "keep_dims", name: "keepDims", type: "bool"}]
},
{
tfOpName: "Any",
category: "reduction",
inputs: [
{start: 0, name: "x", type: "tensor"},
{start: 1, name: "axis", type: "number[]"}
],
attrs: [{tfName: "keep_dims", name: "keepDims", type: "bool"}]
},
// ArgMax/ArgMin take a single (scalar) axis input and declare no attrs.
{
tfOpName: "ArgMax",
category: "reduction",
inputs: [
{start: 0, name: "x", type: "tensor"},
{start: 1, name: "axis", type: "number"}
]
},
{
tfOpName: "ArgMin",
category: "reduction",
inputs: [
{start: 0, name: "x", type: "tensor"},
{start: 1, name: "axis", type: "number"}
]
},
{
tfOpName: "Prod",
category: "reduction",
inputs: [
{start: 0, name: "x", type: "tensor"},
{start: 1, name: "axis", type: "number[]"}
],
attrs: [{tfName: "keep_dims", name: "keepDims", type: "bool"}]
},
{
tfOpName: "Cumsum",
category: "reduction",
inputs: [
{start: 0, name: "x", type: "tensor"},
{start: 1, name: "axis", type: "number"}
],
attrs: [
{tfName: "exclusive", name: "exclusive", type: "bool"},
{tfName: "reverse", name: "reverse", type: "bool"}
]
}
];
// Namespace-style wrapper for the reduction op table: a null-prototype object
// whose only own property, `json`, is the table OperationMapper reads.
var reduction = Object.create(null);
reduction.json = json$d;
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
// Op-mapper metadata for TF slice/join ops (same entry layout as the other op
// tables: tfOpName, category, positional inputs, attrs).
// A "tensors" input with an `end` appears to describe a variadic range of node
// inputs: ConcatV2 takes its tensors from index 0 up to `end: -1` with the axis
// at `start: -1`, while Concat takes the axis first (index 0) and the tensors
// from index 1. Negative indices presumably count from the end of the node's
// input list -- TODO confirm against getParamValue.
var json$e = [
{
tfOpName: "ConcatV2",
category: "slice_join",
inputs: [
{start: 0, end: -1, name: "tensors", type: "tensors"},
{start: -1, name: "axis", type: "number"}
],
attrs: [{tfName: "N", name: "n", type: "number", defaultValue: 2}]
},
{
tfOpName: "Concat",
category: "slice_join",
inputs: [
{start: 1, end: 0, name: "tensors", type: "tensors"},
{start: 0, name: "axis", type: "number"}
],
attrs: [{tfName: "N", name: "n", type: "number", defaultValue: 2}]
},
{
tfOpName: "GatherV2",
category: "slice_join",
inputs: [
{start: 0, name: "x", type: "tensor"},
{start: 1, name: "indices", type: "tensor"},
{start: 2, name: "axis", type: "number", defaultValue: 0}
]
},
{
tfOpName: "Gather",
category: "slice_join",
inputs: [
{start: 0, name: "x", type: "tensor"},
{start: 1, name: "indices", type: "tensor"}
],
attrs: [
{tfName: "axis", name: "axis", type: "number", defaultValue: 0},
{
tfName: "validate_indices",
name: "validateIndices",
type: "bool",
notSupported: true
}
]
},
// Reverse's boolean `dims` input is declared but marked notSupported.
{
tfOpName: "Reverse",
category: "slice_join",
inputs: [
{start: 0, name: "x", type: "tensor"},
{start: 1, name: "dims", type: "bool", notSupported: true}
]
},
{
tfOpName: "ReverseV2",
category: "slice_join",
inputs: [
{start: 0, name: "x", type: "tensor"},
{start: 1, name: "axis", type: "number[]"}
]
},
{
tfOpName: "Slice",
category: "slice_join",
inputs: [
{start: 0, name: "x", type: "tensor"},
{start: 1, name: "begin", type: "number[]"},
{start: 2, name: "size", type: "number[]"}
]
},
{
tfOpName: "StridedSlice",
category: "slice_join",
inputs: [
{start: 0, name: "x", type: "tensor"},
{start: 1, name: "begin", type: "number[]"},
{start: 2, name: "end", type: "number[]"},
{start: 3, name: "strides", type: "number[]"}
],
attrs: [
{
tfName: "begin_mask",
name: "beginMask",
type: "number",
defaultValue: 0
},
{
tfName: "end_mask",
name: "endMask",
type: "number",
defaultValue: 0
},
{
tfName: "new_axis_mask",
name: "newAxisMask",
type: "number",
defaultValue: 0
},
{
tfName: "ellipsis_mask",
name: "ellipsisMask",
type: "number",
defaultValue: 0
},
{
tfName: "shrink_axis_mask",
name: "shrinkAxisMask",
type: "number",
defaultValue: 0
}
]
},
{
tfOpName: "Pack",
category: "slice_join",
inputs: [
{start: 0, end: 0, name: "tensors", type: "tensors"}
],
attrs: [
{tfName: "axis", name: "axis", type: "number", defaultValue: 0}
]
},
{
tfOpName: "Unpack",
category: "slice_join",
inputs: [
{start: 0, name: "tensor", type: "tensor"}
],
attrs: [
{tfName: "axis", name: "axis", type: "number", defaultValue: 0},
{
tfName: "num",
name: "num",
type: "number",
defaultValue: 0,
notSupported: true
}
]
},
{
tfOpName: "Tile",
category: "slice_join",
inputs: [
{start: 0, name: "x", type: "tensor"},
{start: 1, name: "reps", type: "number[]"}
]
},
// Split takes the axis as input 0 and the tensor as input 1 (the reverse of SplitV).
{
tfOpName: "Split",
category: "slice_join",
inputs: [
{start: 0, name: "axis", type: "number", defaultValue: 0},
{start: 1, name: "x", type: "tensor"}
],
attrs: [{
tfName: "num_split",
name: "numOrSizeSplits",
type: "number",
defaultValue: 1
}]
},
{
tfOpName: "SplitV",
category: "slice_join",
inputs: [
{start: 0, name: "x", type: "tensor"},
{start: 1, name: "numOrSizeSplits", type: "number[]"},
{start: 2, name: "axis", type: "number", defaultValue: 0}
]
},
{
tfOpName: "ScatterNd",
category: "slice_join",
inputs: [
{start: 0, name: "indices", type: "tensor"},
{start: 1, name: "values", type: "tensor"},
{start: 2, name: "shape", type: "number[]"}
]
},
{
tfOpName: "GatherNd",
category: "slice_join",
inputs: [
{start: 0, name: "x", type: "tensor"},
{start: 1, name: "indices", type: "tensor"}
]
},
// SparseToDense is also listed in the normalization table (there with
// validateIndices defaulting to true). OperationMapper registers the tables in
// order and a later entry overwrites an earlier one with the same tfOpName; since
// sliceJoin comes after normalization, this entry (default false) is the one used.
{
tfOpName: "SparseToDense",
category: "slice_join",
inputs: [
{start: 0, name: "sparseIndices", type: "tensor"},
{start: 1, name: "outputShape", type: "number[]"},
{start: 2, name: "sparseValues", type: "tensor"},
{start: 3, name: "defaultValue", type: "tensor"}
],
attrs: [{
tfName: "validate_indices",
name: "validateIndices",
type: "bool",
defaultValue: false,
notSupported: true
}]
}
];
// Namespace-style wrapper for the slice/join op table: a null-prototype object
// whose only own property, `json`, is the table OperationMapper reads.
var sliceJoin = Object.create(null);
sliceJoin.json = json$e;
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
// Op-mapper metadata for TF spectral (FFT family) ops. Each entry maps the TF op
// name to its category and positional inputs; RFFT/IRFFT declare an
// `fft_length` input that is marked notSupported.
var json$f = [
{
tfOpName: "FFT",
category: "spectral",
inputs: [{start: 0, name: "x", type: "tensor"}]
},
{
tfOpName: "IFFT",
category: "spectral",
inputs: [{start: 0, name: "x", type: "tensor"}]
},
{
tfOpName: "RFFT",
category: "spectral",
inputs: [
{start: 0, name: "x", type: "tensor"},
{
start: 1,
name: "fft_length",
type: "number",
notSupported: true
}
]
},
{
tfOpName: "IRFFT",
category: "spectral",
inputs: [
{start: 0, name: "x", type: "tensor"},
{
start: 1,
name: "fft_length",
type: "number",
notSupported: true
}
]
}
];
// Namespace-style wrapper for the spectral op table: a null-prototype object
// whose only own property, `json`, is the table OperationMapper reads.
var spectral = Object.create(null);
spectral.json = json$f;
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
// Op-mapper metadata for TF transformation ops (cast, pad, reshape, squeeze,
// space/batch/depth rearrangement, broadcast). Same entry layout as the other
// op tables: tfOpName, category, positional inputs, attrs.
var json$g = [
// Cast: only the destination dtype (DstT) is mapped; SrcT is marked notSupported.
{
tfOpName: "Cast",
category: "transformation",
inputs: [
{start: 0, name: "x", type: "tensor"}
],
attrs: [
{
tfName: "SrcT",
name: "sdtype",
type: "dtype",
notSupported: true
},
{tfName: "DstT", name: "dtype", type: "dtype"}
]
},
{
tfOpName: "ExpandDims",
category: "transformation",
inputs: [
{start: 0, name: "x", type: "tensor"},
{start: 1, name: "axis", type: "number"}
]
},
{
tfOpName: "MirrorPad",
category: "transformation",
inputs: [
{start: 0, name: "x", type: "tensor"},
{start: 1, name: "padding", type: "number[]"}
],
attrs: [{tfName: "mode", name: "mode", type: "string"}]
},
// Pad reads its fill value from an attr (default 0); PadV2 takes it as input 2.
{
tfOpName: "Pad",
category: "transformation",
inputs: [
{start: 0, name: "x", type: "tensor"},
{start: 1, name: "padding", type: "number[]"}
],
attrs: [{
tfName: "constant_value",
name: "constantValue",
type: "number",
defaultValue: 0
}]
},
{
tfOpName: "PadV2",
category: "transformation",
inputs: [
{start: 0, name: "x", type: "tensor"},
{start: 1, name: "padding", type: "number[]"},
{
start: 2,
name: "constantValue",
type: "number",
defaultValue: 0
}
]
},
{
tfOpName: "Reshape",
category: "transformation",
inputs: [
{start: 0, name: "x", type: "tensor"},
{start: 1, name: "shape", type: "number[]"}
]
},
// Squeeze reads `axis`; when that lookup yields undefined, OperationMapper.mapNode
// retries with the older attr name `squeeze_dims` (tfDeprecatedName).
{
tfOpName: "Squeeze",
category: "transformation",
inputs: [
{start: 0, name: "x", type: "tensor"}
],
attrs: [{
tfName: "axis",
tfDeprecatedName: "squeeze_dims",
name: "axis",
type: "number[]"
}]
},
{
tfOpName: "SpaceToBatchND",
category: "transformation",
inputs: [
{start: 0, name: "x", type: "tensor"},
{start: 1, name: "blockShape", type: "number[]"},
{start: 2, name: "paddings", type: "number[]"}
]
},
{
tfOpName: "BatchToSpaceND",
category: "transformation",
inputs: [
{start: 0, name: "x", type: "tensor"},
{start: 1, name: "blockShape", type: "number[]"},
{start: 2, name: "crops", type: "number[]"}
]
},
{
tfOpName: "DepthToSpace",
category: "transformation",
inputs: [
{start: 0, name: "x", type: "tensor"}
],
attrs: [
{tfName: "block_size", name: "blockSize", type: "number"},
{tfName: "data_format", name: "dataFormat", type: "string"}
]
},
{
tfOpName: "BroadcastTo",
category: "transformation",
inputs: [
{start: 0, name: "x", type: "tensor"},
{start: 1, name: "shape", type: "number[]"}
],
attrs: []
}
];
// Namespace-style wrapper for the transformation op table: a null-prototype
// object whose only own property, `json`, is the table OperationMapper reads.
var transformation = Object.create(null);
transformation.json = json$g;
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
// Converts TensorFlow GraphDef / FunctionDef JSON into the converter's internal
// graph representation, using the per-category op tables (each module's `json`)
// to decide how every node's inputs and attrs are read.
var OperationMapper = function() {
// Builds `opMappers`, a dictionary from tfOpName to its table entry, merged from
// every category table listed below. Entries are assigned in array order, so a
// later table overrides an earlier one for the same tfOpName (SparseToDense is in
// both `normalization` and `sliceJoin`; the `sliceJoin` entry is what ends up here).
function OperationMapper2() {
var ops = [
arithmetic,
basicMath,
control,
convolution,
creation,
dynamic,
evaluation,
logical,
image,
graph,
matrices,
normalization,
reduction,
sliceJoin,
spectral,
transformation,
hashTable
];
var mappersJson = [].concat.apply([], __spread(ops.map(function(op) {
return op.json;
})));
this.opMappers = mappersJson.reduce(function(map, mapper) {
map[mapper.tfOpName] = mapper;
return map;
}, {});
}
// Lazily created, cached instance: the first read constructs one and stores it on `_instance`.
Object.defineProperty(OperationMapper2, "Instance", {
get: function() {
return this._instance || (this._instance = new this());
},
enumerable: true,
configurable: true
});
// Converts a graph object (`graph2.node`, optional `graph2.library.function`) into
// {nodes, inputs, outputs, weights, placeholders, signature, functions}, plus
// `initNodes` when any exist. `signature` optionally carries `inputs`/`outputs`
// entries keyed by signature key, each holding the tensor `name`.
OperationMapper2.prototype.transformGraph = function(graph2, signature) {
var _this2 = this;
if (signature === void 0) {
signature = {};
}
var tfNodes = graph2.node;
var placeholders = [];
var weights = [];
var initNodes = [];
// Map every node and classify it: ops whose name starts with "Placeholder" are
// placeholders, "Const" ops are weights, and any other node without inputs is an init node.
var nodes = tfNodes.reduce(function(map, node) {
map[node.name] = _this2.mapNode(node);
if (node.op.startsWith("Placeholder")) {
placeholders.push(map[node.name]);
} else if (node.op === "Const") {
weights.push(map[node.name]);
} else if (node.input == null || node.input.length === 0) {
initNodes.push(map[node.name]);
}
return map;
}, {});
var inputs = [];
var outputs = [];
var inputNodeNameToKey = {};
var outputNodeNameToKey = {};
if (signature != null) {
inputNodeNameToKey = this.mapSignatureEntries(signature.inputs);
outputNodeNameToKey = this.mapSignatureEntries(signature.outputs);
}
var allNodes = Object.keys(nodes);
// Wire up edges using only the node-name element of getNodeNameAndIndex's result.
// NOTE(review): an input naming a node that is not in the graph leaves
// `nodes[nodeName]` undefined, so the `.children.push` below throws a TypeError.
allNodes.forEach(function(key) {
var node = nodes[key];
node.inputNames.forEach(function(name) {
var _a = __read(getNodeNameAndIndex(name), 1), nodeName = _a[0];
node.inputs.push(nodes[nodeName]);
nodes[nodeName].children.push(node);
});
});
// With no signature outputs, every node without consumers is an output; otherwise
// outputs come from the signature (tagged with `signatureKey`), and signature
// entries that name unknown nodes are silently skipped.
if (Object.keys(outputNodeNameToKey).length === 0) {
allNodes.forEach(function(key) {
var node = nodes[key];
if (node.children.length === 0) {
outputs.push(node);
}
});
} else {
Object.keys(outputNodeNameToKey).forEach(function(name) {
var _a = __read(getNodeNameAndIndex(name), 1), nodeName = _a[0];
var node = nodes[nodeName];
if (node != null) {
node.signatureKey = outputNodeNameToKey[name];
outputs.push(node);
}
});
}
// Inputs likewise come from the signature when it lists any, else from the placeholders.
if (Object.keys(inputNodeNameToKey).length > 0) {
Object.keys(inputNodeNameToKey).forEach(function(name) {
var _a = __read(getNodeNameAndIndex(name), 1), nodeName = _a[0];
var node = nodes[nodeName];
if (node) {
node.signatureKey = inputNodeNameToKey[name];
inputs.push(node);
}
});
} else {
inputs = placeholders;
}
// Library functions are mapped separately and keyed by their signature name.
var functions = {};
if (graph2.library != null && graph2.library.function != null) {
functions = graph2.library.function.reduce(function(functions2, func) {
functions2[func.signature.name] = _this2.mapFunction(func);
return functions2;
}, {});
}
var result = {nodes, inputs, outputs, weights, placeholders, signature, functions};
if (initNodes.length > 0) {
result.initNodes = initNodes;
}
return result;
};
// Inverts signature entries {key: {name, ...}} into {name: key}; null/undefined entries yield {}.
OperationMapper2.prototype.mapSignatureEntries = function(entries) {
return Object.keys(entries || {}).reduce(function(prev, curr) {
prev[entries[curr].name] = curr;
return prev;
}, {});
};
// Converts one raw TF node into the internal node shape. Ops registered through
// getRegisteredOp take precedence over the built-in tables; an op with no mapper
// at all gets an empty one (no category, inputParams or attrParams).
// Side effect: sets `node.attr = {}` on the caller's node when it has no attrs.
OperationMapper2.prototype.mapNode = function(node) {
var mapper = getRegisteredOp(node.op) || this.opMappers[node.op] || {};
if (node.attr == null) {
node.attr = {};
}
var newNode = {
name: node.name,
op: node.op,
category: mapper.category,
// Drop a leading "^" from input names (presumably TF control-dependency inputs -- TODO confirm).
inputNames: (node.input || []).map(function(input) {
return input.startsWith("^") ? input.substr(1) : input;
}),
inputs: [],
children: [],
inputParams: {},
attrParams: {},
rawAttrs: node.attr
};
// Record only the declared input positions; resolving them to tensors happens elsewhere.
if (mapper.inputs != null) {
newNode.inputParams = mapper.inputs.reduce(function(map, param) {
map[param.name] = {
type: param.type,
inputIndexStart: param.start,
inputIndexEnd: param.end
};
return map;
}, {});
}
// Read each declared attr by its TF name, retrying with tfDeprecatedName when the
// primary lookup yields undefined. Unknown param types throw.
if (mapper.attrs != null) {
newNode.attrParams = mapper.attrs.reduce(function(map, param) {
var type = param.type;
var value = void 0;
switch (param.type) {
case "string":
value = getStringParam(node.attr, param.tfName, param.defaultValue);
if (value === void 0 && !!param.tfDeprecatedName) {
value = getStringParam(node.attr, param.tfDeprecatedName, param.defaultValue);
}
break;
case "string[]":
value = getStringArrayParam(node.attr, param.tfName, param.defaultValue);
if (value === void 0 && !!param.tfDeprecatedName) {
value = getStringArrayParam(node.attr, param.tfDeprecatedName, param.defaultValue);
}
break;
case "number":
// NOTE(review): with `param.defaultValue || 0` as the default, getNumberParam
// always returns a number (possibly NaN), never undefined, so the
// tfDeprecatedName retry below is never taken for "number" attrs. Any falsy
// defaultValue also collapses to 0 here.
value = getNumberParam(node.attr, param.tfName, param.defaultValue || 0);
if (value === void 0 && !!param.tfDeprecatedName) {
value = getNumberParam(node.attr, param.tfDeprecatedName, param.defaultValue);
}
break;
case "number[]":
value = getNumericArrayParam(node.attr, param.tfName, param.defaultValue);
if (value === void 0 && !!param.tfDeprecatedName) {
value = getNumericArrayParam(node.attr, param.tfDeprecatedName, param.defaultValue);
}
break;
case "bool":
value = getBoolParam(node.attr, param.tfName, param.defaultValue);
if (value === void 0 && !!param.tfDeprecatedName) {
value = getBoolParam(node.attr, param.tfDeprecatedName, param.defaultValue);
}
break;
case "bool[]":
value = getBoolArrayParam(node.attr, param.tfName, param.defaultValue);
if (value === void 0 && !!param.tfDeprecatedName) {
value = getBoolArrayParam(node.attr, param.tfDeprecatedName, param.defaultValue);
}
break;
case "shape":
value = getTensorShapeParam(node.attr, param.tfName, param.defaultValue);
if (value === void 0 && !!param.tfDeprecatedName) {
value = getTensorShapeParam(node.attr, param.tfDeprecatedName, param.defaultValue);
}
break;
case "shape[]":
value = getTensorShapeArrayParam(node.attr, param.tfName, param.defaultValue);
if (value === void 0 && !!param.tfDeprecatedName) {
value = getTensorShapeArrayParam(node.attr, param.tfDeprecatedName, param.defaultValue);
}
break;
case "dtype":
value = getDtypeParam(node.attr, param.tfName, param.defaultValue);
if (value === void 0 && !!param.tfDeprecatedName) {
value = getDtypeParam(node.attr, param.tfDeprecatedName, param.defaultValue);
}
break;
case "dtype[]":
value = getDtypeArrayParam(node.attr, param.tfName, param.defaultValue);
if (value === void 0 && !!param.tfDeprecatedName) {
value = getDtypeArrayParam(node.attr, param.tfDeprecatedName, param.defaultValue);
}
break;
case "func":
value = getFuncParam(node.attr, param.tfName, param.defaultValue);
if (value === void 0 && !!param.tfDeprecatedName) {
value = getFuncParam(node.attr, param.tfDeprecatedName, param.defaultValue);
}
break;
// Tensor-typed attrs are not resolved here; their value stays undefined.
case "tensor":
case "tensors":
break;
default:
throw new Error("Unsupported param type: " + param.type + " for op: " + node.op);
}
map[param.name] = {value, type};
return map;
}, {});
}
return newNode;
};
// Converts a FunctionDef (`nodeDef`, `signature.inputArg`/`outputArg`, `ret`) into a
// graph object. Each input arg becomes a synthetic "Placeholder" node whose dtype
// comes from parseDtypeParam(arg.type).
// NOTE(review): `placeholders` is returned but never populated in this method.
OperationMapper2.prototype.mapFunction = function(functionDef) {
var _this2 = this;
var tfNodes = functionDef.nodeDef;
var placeholders = [];
var weights = [];
var nodes = {};
if (tfNodes != null) {
nodes = tfNodes.reduce(function(map, node) {
map[node.name] = _this2.mapNode(node);
if (node.op === "Const") {
weights.push(map[node.name]);
}
return map;
}, {});
}
var inputs = [];
var outputs = [];
functionDef.signature.inputArg.forEach(function(arg) {
var _a = __read(getNodeNameAndIndex(arg.name), 1), nodeName = _a[0];
var node = {
name: nodeName,
op: "Placeholder",
inputs: [],
inputNames: [],
category: "graph",
inputParams: {},
attrParams: {dtype: {value: parseDtypeParam(arg.type), type: "dtype"}},
children: []
};
node.signatureKey = arg.name;
inputs.push(node);
nodes[nodeName] = node;
});
var allNodes = Object.keys(nodes);
allNodes.forEach(function(key) {
var node = nodes[key];
node.inputNames.forEach(function(name) {
var _a = __read(getNodeNameAndIndex(name), 1), nodeName = _a[0];
node.inputs.push(nodes[nodeName]);
nodes[nodeName].children.push(node);
});
});
// Each output arg is resolved through `ret` to a node name and index; the index
// becomes that node's `defaultOutput`. Unresolvable outputs are skipped.
var returnNodeMap = functionDef.ret;
functionDef.signature.outputArg.forEach(function(output) {
var _a = __read(getNodeNameAndIndex(returnNodeMap[output.name]), 2), nodeName = _a[0], index = _a[1];
var node = nodes[nodeName];
if (node != null) {
node.defaultOutput = index;
outputs.push(node);
}
});
var signature = this.mapArgsToSignature(functionDef);
return {nodes, inputs, outputs, weights, placeholders, signature};
};
// Builds a signature for a function: its name plus tensor infos keyed by arg name;
// output names are translated through `functionDef.ret`.
OperationMapper2.prototype.mapArgsToSignature = function(functionDef) {
var _this2 = this;
return {
methodName: functionDef.signature.name,
inputs: functionDef.signature.inputArg.reduce(function(map, arg) {
map[arg.name] = _this2.mapArgToTensorInfo(arg);
return map;
}, {}),
outputs: functionDef.signature.outputArg.reduce(function(map, arg) {
map[arg.name] = _this2.mapArgToTensorInfo(arg, functionDef.ret);
return map;
}, {})
};
};
// Returns {name, dtype} for an arg; `name` is remapped through `nameMap` when given.
// `dtype` is the raw `arg.type`, not passed through parseDtypeParam.
OperationMapper2.prototype.mapArgToTensorInfo = function(arg, nameMap) {
var name = arg.name;
if (nameMap != null) {
name = nameMap[name];
}
return {name, dtype: arg.type};
};
return OperationMapper2;
}();
/**
 * Decodes base64 text, preferring the environment's `atob()` and falling back to
 * Node's `Buffer` when `atob` is unavailable.
 * NOTE(review): `atob` yields one character per decoded byte while
 * `Buffer#toString()` decodes UTF-8, so the two paths can differ for non-ASCII payloads.
 * @param {string} text base64-encoded input
 * @returns {string} decoded string
 * @throws {Error} when neither atob() nor Buffer exists in this environment
 */
function decodeBase64(text) {
  const global2 = tfOps.env().global;
  if (typeof global2.atob !== "undefined") {
    return global2.atob(text);
  }
  if (typeof Buffer !== "undefined") {
    // Buffer.from replaces the deprecated `new Buffer(...)` constructor (Node DEP0005).
    return Buffer.from(text, "base64").toString();
  }
  throw new Error("Unable to decode base64 in this environment. Missing built-in atob() or Buffer()");
}
/**
 * Turns a raw string attr into text: an array is read as char codes, anything
 * else is treated as base64 and passed to decodeBase64.
 * @param s char-code array or base64 string
 * @param keepCase when falsy the result is lower-cased
 * @returns {string}
 */
function parseStringParam(s, keepCase) {
  let text;
  if (Array.isArray(s)) {
    text = String.fromCharCode(...s);
  } else {
    text = decodeBase64(s);
  }
  return keepCase ? text : text.toLowerCase();
}
/**
 * Reads the string attr `name` (its `.s` field) via parseStringParam.
 * @param attrs raw attr dictionary
 * @param name attr key
 * @param def value returned when the attr is null/undefined
 * @param keepCase preserve case when true (defaults to false, i.e. lower-case)
 */
function getStringParam(attrs, name, def, keepCase) {
  const param = attrs[name];
  if (param == null) {
    return def;
  }
  const preserveCase = keepCase === void 0 ? false : keepCase;
  return parseStringParam(param.s, preserveCase);
}
/**
 * Reads the boolean attr `name` (its `.b` field); returns `def` when the attr is absent/falsy.
 */
function getBoolParam(attrs, name, def) {
  const param = attrs[name];
  if (!param) {
    return def;
  }
  return param.b;
}
/**
 * Reads a numeric attr, preferring the integer field `i`, then the float field
 * `f`, then `def`. Non-number values (e.g. int64 encoded as a string) go through
 * parseInt, so a missing attr with an undefined `def` yields NaN.
 */
function getNumberParam(attrs, name, def) {
  const attr = attrs[name] || {};
  let raw = def;
  if (attr["i"] != null) {
    raw = attr["i"];
  } else if (attr["f"] != null) {
    raw = attr["f"];
  }
  if (typeof raw === "number") {
    return raw;
  }
  return parseInt(raw, 10);
}
/**
 * Maps a TF DataType (numeric code or enum key string such as "DT_FLOAT") onto a
 * tfjs dtype name. Floats and doubles become "float32"; INT32/INT64/INT8/UINT8
 * become "int32"; anything unhandled returns null.
 */
function parseDtypeParam(value) {
  const code = typeof value === "string" ? DataType[value] : value;
  if (code === DataType.DT_FLOAT || code === DataType.DT_DOUBLE) {
    return "float32";
  }
  if (code === DataType.DT_INT32 || code === DataType.DT_INT64 || code === DataType.DT_INT8 || code === DataType.DT_UINT8) {
    return "int32";
  }
  if (code === DataType.DT_BOOL) {
    return "bool";
  }
  if (code === DataType.DT_STRING) {
    return "string";
  }
  return null;
}
/**
 * Returns the function name referenced by attr `name` (`.func.name`), or `def`
 * when the attr or its `func` field is missing.
 */
function getFuncParam(attrs, name, def) {
  const param = attrs[name];
  const func = param && param.func;
  return func ? func.name : def;
}
/**
 * Reads the dtype attr `name` (`.type`) through parseDtypeParam; a missing or
 * falsy `type` returns `def`.
 */
function getDtypeParam(attrs, name, def) {
  const param = attrs[name];
  const type = param && param.type;
  return type ? parseDtypeParam(type) : def;
}
/**
 * Reads a list-of-dtypes attr (`.list.type`), converting each entry with
 * parseDtypeParam; returns `def` when any level is missing.
 */
function getDtypeArrayParam(attrs, name, def) {
  const param = attrs[name];
  const types = param && param.list && param.list.type;
  if (!types) {
    return def;
  }
  return types.map((entry) => parseDtypeParam(entry));
}
/**
 * Converts a TF shape object into an array of dimension sizes.
 * @param shape object with optional `unknownRank` flag and `dim` list of `{size}`
 * @returns {number[]|undefined} undefined for unknown rank, [] when no dims are
 *   listed, otherwise one number per dim (string sizes are parsed as base-10 ints)
 */
function parseTensorShapeParam(shape) {
  if (shape.unknownRank) {
    return void 0;
  }
  if (shape.dim == null) {
    return [];
  }
  return shape.dim.map((dim) => {
    // proto3 JSON omits zero-valued scalar fields, so a dimension of size 0 can
    // arrive as `{}`; previously that produced NaN via parseInt(undefined).
    if (dim.size == null) {
      return 0;
    }
    return typeof dim.size === "number" ? dim.size : parseInt(dim.size, 10);
  });
}
/**
 * Reads a shape attr (`.shape`) through parseTensorShapeParam; returns `def`
 * when the attr or its shape is missing.
 */
function getTensorShapeParam(attrs, name, def) {
  const param = attrs[name];
  const shape = param && param.shape;
  return shape ? parseTensorShapeParam(shape) : def;
}
/**
 * Reads a numeric list attr: the float list `list.f` when it is non-empty,
 * otherwise the integer list `list.i` (or [] when neither exists). Non-number
 * entries (e.g. int64 encoded as strings) are parsed as base-10 ints.
 * @returns {number[]|*} the parsed list, or `def` when the attr is missing or
 *   is not list-shaped (consistent with the other *ArrayParam getters, which
 *   also guard `param.list`; previously this threw a TypeError)
 */
function getNumericArrayParam(attrs, name, def) {
  const param = attrs[name];
  if (!param || !param.list) {
    return def;
  }
  const list = param.list;
  const values = list.f && list.f.length ? list.f : list.i;
  return (values || []).map((v) => (typeof v === "number" ? v : parseInt(v, 10)));
}
/**
 * Reads a string list attr (`.list.s`), decoding each entry with
 * parseStringParam; returns `def` when any level is missing.
 * @param keepCase preserve case when true (defaults to false, i.e. lower-case)
 */
function getStringArrayParam(attrs, name, def, keepCase) {
  const preserveCase = keepCase === void 0 ? false : keepCase;
  const param = attrs[name];
  const strings = param && param.list && param.list.s;
  if (!strings) {
    return def;
  }
  return strings.map((entry) => parseStringParam(entry, preserveCase));
}
/**
 * Reads a list-of-shapes attr (`.list.shape`), converting each entry with
 * parseTensorShapeParam; returns `def` when any level is missing.
 */
function getTensorShapeArrayParam(attrs, name, def) {
  const param = attrs[name];
  const shapes = param && param.list && param.list.shape;
  if (!shapes) {
    return def;
  }
  return shapes.map((shape) => parseTensorShapeParam(shape));
}
/**
 * Returns the raw boolean list (`.list.b`, same array reference) of attr `name`,
 * or `def` when any level is missing.
 */
function getBoolArrayParam(attrs, name, def) {
  const param = attrs[name];
  const bools = param && param.list && param.list.b;
  return bools ? bools : def;
}
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * View of a graph node handed to op executors. It exposes the node's input
 * tensors (`inputs`, resolved eagerly) and its raw attrs converted to plain JS
 * values (`attrs`).
 */
var NodeValueImpl = function() {
  /**
   * @param node graph node with `inputNames` and optional `rawAttrs`
   * @param tensorMap tensors available for lookup (passed through to getTensor)
   * @param context execution context (passed through to getTensor)
   */
  function NodeValueImpl2(node, tensorMap, context) {
    var _this2 = this;
    this.node = node;
    this.tensorMap = tensorMap;
    this.context = context;
    this.inputs = [];
    this.attrs = {};
    this.inputs = node.inputNames.map(function(name) {
      return _this2.getInput(name);
    });
    if (node.rawAttrs != null) {
      // Convert every raw attr once, without a default.
      this.attrs = Object.keys(node.rawAttrs).reduce(function(attrs, key) {
        attrs[key] = _this2.getAttr(key);
        return attrs;
      }, {});
    }
  }
  /** Looks up an input tensor by name via getTensor. */
  NodeValueImpl2.prototype.getInput = function(name) {
    return getTensor(name, this.tensorMap, this.context);
  };
  /**
   * Converts raw attr `name` to a JS value, picking the getter by which field of
   * the raw attr is populated.
   * @param name attr key
   * @param defaultValue returned when the attr is missing (including a node
   *   without `rawAttrs`) or when none of the known fields is set
   */
  NodeValueImpl2.prototype.getAttr = function(name, defaultValue) {
    var rawAttrs = this.node.rawAttrs;
    var value = rawAttrs != null ? rawAttrs[name] : void 0;
    // Previously a missing attr threw a TypeError here instead of honouring defaultValue.
    if (value == null) {
      return defaultValue;
    }
    if (value.tensor != null) {
      // NOTE(review): the tensor is looked up by the attr name, not via value.tensor -- confirm intent.
      return getTensor(name, this.tensorMap, this.context);
    }
    if (value.i != null || value.f != null) {
      return getNumberParam(this.node.rawAttrs, name, defaultValue);
    }
    if (value.s != null) {
      return getStringParam(this.node.rawAttrs, name, defaultValue);
    }
    if (value.b != null) {
      return getBoolParam(this.node.rawAttrs, name, defaultValue);
    }
    if (value.shape != null) {
      return getTensorShapeParam(this.node.rawAttrs, name, defaultValue);
    }
    if (value.type != null) {
      return getDtypeParam(this.node.rawAttrs, name, defaultValue);
    }
    if (value.list != null) {
      if (value.list.i != null || value.list.f != null) {
        return getNumericArrayParam(this.node.rawAttrs, name, defaultValue);
      }
      if (value.list.s != null) {
        return getStringArrayParam(this.node.rawAttrs, name, defaultValue);
      }
      if (value.list.shape != null) {
        return getTensorShapeArrayParam(this.node.rawAttrs, name, defaultValue);
      }
      if (value.list.b != null) {
        return getBoolArrayParam(this.node.rawAttrs, name, defaultValue);
      }
      if (value.list.type != null) {
        return getDtypeArrayParam(this.node.rawAttrs, name, defaultValue);
      }
    }
    return defaultValue;
  };
  return NodeValueImpl2;
}();
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Executes an arithmetic TF node: AddN over its "tensors" param, or one of the
 * binary element-wise ops below over params "a" and "b".
 * @returns a one-element array holding the op's output
 * @throws TypeError for any op name not handled here
 */
var executeOp = (() => {
  // TF op name -> tfOps method taking (a, b). Null prototype so inherited keys
  // such as "toString" never match.
  const binaryOps = {
    __proto__: null,
    BiasAdd: "add",
    AddV2: "add",
    Add: "add",
    FloorMod: "mod",
    Mod: "mod",
    Mul: "mul",
    RealDiv: "div",
    Div: "div",
    DivNoNan: "divNoNan",
    FloorDiv: "floorDiv",
    Sub: "sub",
    Minimum: "minimum",
    Maximum: "maximum",
    Pow: "pow",
    SquaredDifference: "squaredDifference"
  };
  return (node, tensorMap, context) => {
    const param = (paramName) => getParamValue(paramName, node, tensorMap, context);
    if (node.op === "AddN") {
      return [tfOps.addN(param("tensors"))];
    }
    const method = binaryOps[node.op];
    if (method === void 0) {
      throw TypeError("Node type " + node.op + " is not implemented");
    }
    return [tfOps[method](param("a"), param("b"))];
  };
})();
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Executes a basic-math TF node (unary element-wise ops, Atan2, Complex,
 * clipping, Prod, LeakyRelu/Prelu, Rsqrt) by forwarding its params to tfOps.
 * @returns a one-element array holding the op's output
 * @throws TypeError for any op name not handled here
 */
var executeOp$1 = (() => {
  // TF op name -> [tfOps method, ...param names passed in this order].
  // Null prototype so inherited keys such as "toString" never match.
  const signatures = {
    __proto__: null,
    Abs: ["abs", "x"],
    ComplexAbs: ["abs", "x"],
    Acos: ["acos", "x"],
    Acosh: ["acosh", "x"],
    Asin: ["asin", "x"],
    Asinh: ["asinh", "x"],
    Atan: ["atan", "x"],
    Atan2: ["atan2", "x", "y"],
    Atanh: ["atanh", "x"],
    Ceil: ["ceil", "x"],
    Complex: ["complex", "real", "imag"],
    Cos: ["cos", "x"],
    Cosh: ["cosh", "x"],
    Elu: ["elu", "x"],
    Erf: ["erf", "x"],
    Exp: ["exp", "x"],
    Expm1: ["expm1", "x"],
    Floor: ["floor", "x"],
    Log: ["log", "x"],
    Log1p: ["log1p", "x"],
    Imag: ["imag", "x"],
    Neg: ["neg", "x"],
    Reciprocal: ["reciprocal", "x"],
    Real: ["real", "x"],
    Relu: ["relu", "x"],
    Round: ["round", "x"],
    Selu: ["selu", "x"],
    Sigmoid: ["sigmoid", "x"],
    Sin: ["sin", "x"],
    Sign: ["sign", "x"],
    Sinh: ["sinh", "x"],
    Softplus: ["softplus", "x"],
    Sqrt: ["sqrt", "x"],
    Square: ["square", "x"],
    Tanh: ["tanh", "x"],
    Tan: ["tan", "x"],
    Relu6: ["clipByValue", "x", "clipValueMin", "clipValueMax"],
    ClipByValue: ["clipByValue", "x", "clipValueMin", "clipValueMax"],
    // NOTE(review): reads param "axes", while the reduction table's Prod entry
    // names that input "axis"; kept unchanged here.
    Prod: ["prod", "x", "axes"],
    LeakyRelu: ["leakyRelu", "x", "alpha"],
    Prelu: ["prelu", "x", "alpha"]
  };
  return (node, tensorMap, context) => {
    // Rsqrt reads its first raw input directly rather than a named param.
    if (node.op === "Rsqrt") {
      return [tfOps.rsqrt(getTensor(node.inputNames[0], tensorMap, context))];
    }
    const signature = signatures[node.op];
    if (signature === void 0) {
      throw TypeError("Node type " + node.op + " is not implemented");
    }
    const [method, ...paramNames] = signature;
    const args = paramNames.map((paramName) => getParamValue(paramName, node, tensorMap, context));
    return [tfOps[method](...args)];
  };
})();
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Asserts, through `tfOps.util.assert`, that two shapes are compatible:
 * same rank, and each dimension pair equal unless either side is -1.
 *
 * @param {number[]} shapeA - First shape.
 * @param {number[]} shapeB - Second shape.
 * @param {string} [errorMessagePrefix=""] - Text placed before the failure message.
 */
function assertShapesMatchAllowUndefinedSize(shapeA, shapeB, errorMessagePrefix) {
  const prefix = errorMessagePrefix === void 0 ? "" : errorMessagePrefix;
  const compatible = shapesEqualAllowUndefinedSize(shapeA, shapeB);
  // The message is built lazily; it is only needed when the assertion fails.
  tfOps.util.assert(compatible, () => `${prefix} Shapes ${shapeA} and ${shapeB} must match`);
}
/**
 * Returns true when `n1` and `n2` have the same length and every dimension
 * pair is either equal or contains -1 (an unknown size) on at least one side.
 *
 * @param {number[]} n1
 * @param {number[]} n2
 * @returns {boolean}
 */
function shapesEqualAllowUndefinedSize(n1, n2) {
  if (n1.length !== n2.length) {
    return false;
  }
  let dim = 0;
  while (dim < n1.length) {
    const a = n1[dim];
    const b = n2[dim];
    const isWildcard = a === -1 || b === -1;
    if (!isWildcard && a !== b) {
      return false;
    }
    dim += 1;
  }
  return true;
}
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Host-side store backing the TensorFlow TensorArray graph ops (created by the
 * "TensorArrayV3" branch of the control-flow executor and looked up via `id`).
 * Each slot is a bookkeeping record `{tensor, written, read, cleared}`.
 */
var TensorArray = function() {
  /**
   * @param {string} name - Used in error messages.
   * @param {string} dtype - dtype every written tensor must have.
   * @param {number} maxSize - Capacity enforced when `dynamicSize` is false.
   * @param {number[]} elementShape - Expected element shape (-1 matches any
   *   size). When null or empty it is replaced by the shape of the tensor
   *   written while the array is still empty.
   * @param {boolean} identicalElementShapes - Stored; not consulted in this class.
   * @param {boolean} dynamicSize - When true, writes may go at or past `maxSize`.
   * @param {boolean} clearAfterRead - When true, each slot can be read only once.
   */
  function TensorArray2(name, dtype, maxSize, elementShape, identicalElementShapes, dynamicSize, clearAfterRead) {
    this.name = name;
    this.dtype = dtype;
    this.maxSize = maxSize;
    this.elementShape = elementShape;
    this.identicalElementShapes = identicalElementShapes;
    this.dynamicSize = dynamicSize;
    this.clearAfterRead = clearAfterRead;
    this.tensors = [];
    this.closed_ = false;
    // Scalar whose tensor id identifies this array; tf.keep() exempts it from
    // tidy() cleanup until clearAndClose() disposes it.
    this.idTensor = tfOps.scalar(0);
    tfOps.keep(this.idTensor);
  }
  /** Identifier of this array (the id of its id tensor). */
  Object.defineProperty(TensorArray2.prototype, "id", {
    get: function() {
      return this.idTensor.id;
    },
    enumerable: true,
    configurable: true
  });
  /** Whether clearAndClose() has been called. */
  Object.defineProperty(TensorArray2.prototype, "closed", {
    get: function() {
      return this.closed_;
    },
    enumerable: true,
    configurable: true
  });
  /**
   * Disposes every stored tensor whose id is not in `keepIds`, empties the
   * array, marks it closed and disposes the id tensor.
   * @param {Set<number>} [keepIds] - Set-like (uses `.has`) of tensor ids to keep alive.
   */
  TensorArray2.prototype.clearAndClose = function(keepIds) {
    this.tensors.forEach(function(tensor) {
      if (keepIds == null || !keepIds.has(tensor.tensor.id)) {
        tensor.tensor.dispose();
      }
    });
    this.tensors = [];
    this.closed_ = true;
    this.idTensor.dispose();
  };
  /** Number of slots, including unwritten holes left by writes past the end. */
  TensorArray2.prototype.size = function() {
    return this.tensors.length;
  };
  /**
   * Returns the tensor stored at `index`.
   * @throws {Error} if the array is closed, `index` is out of range, the slot
   *   was never written, or it was already cleared by a read with `clearAfterRead`.
   */
  TensorArray2.prototype.read = function(index) {
    if (this.closed_) {
      throw new Error("TensorArray " + this.name + " has already been closed.");
    }
    if (index < 0 || index >= this.size()) {
      throw new Error("Tried to read from index " + index + ", but array size is: " + this.size());
    }
    var tensorWithState = this.tensors[index];
    // With dynamicSize, writing past the end leaves holes below the written
    // index; report them explicitly instead of failing on `undefined.cleared`.
    if (tensorWithState == null) {
      throw new Error("TensorArray " + this.name + ": Could not read index " + index + " because it has not been written.");
    }
    if (tensorWithState.cleared) {
      throw new Error("TensorArray " + this.name + ": Could not read index " + index + " twice because it was cleared after a previous read (perhaps try setting clear_after_read = false?).");
    }
    if (this.clearAfterRead) {
      tensorWithState.cleared = true;
    }
    tensorWithState.read = true;
    return tensorWithState.tensor;
  };
  /** Reads each of `indices` in order; see read(). */
  TensorArray2.prototype.readMany = function(indices) {
    var _this2 = this;
    return indices.map(function(index) {
      return _this2.read(index);
    });
  };
  /**
   * Stores `tensor` at `index` and keeps it (tf.keep).
   * @throws {Error} if closed, `index` is negative or beyond a fixed-size
   *   array, the dtype differs, or the slot was already read or written. Also
   *   fails the shape assertion when the shape is incompatible with elementShape.
   */
  TensorArray2.prototype.write = function(index, tensor) {
    if (this.closed_) {
      throw new Error("TensorArray " + this.name + " has already been closed.");
    }
    if (index < 0 || !this.dynamicSize && index >= this.maxSize) {
      throw new Error("Tried to write to index " + index + ", but array is not resizeable and size is: " + this.maxSize);
    }
    var t = this.tensors[index] || {};
    if (tensor.dtype !== this.dtype) {
      throw new Error("TensorArray " + this.name + ": Could not write to TensorArray index " + index + ",\n because the value dtype is " + tensor.dtype + ", but TensorArray dtype is " + this.dtype + ".");
    }
    // Adopt the first written tensor's shape when no element shape was given.
    if (this.size() === 0 && (this.elementShape == null || this.elementShape.length === 0)) {
      this.elementShape = tensor.shape;
    }
    assertShapesMatchAllowUndefinedSize(this.elementShape, tensor.shape, "TensorArray " + this.name + ": Could not write to TensorArray index " + index + ".");
    if (t.read) {
      throw new Error("TensorArray " + this.name + ": Could not write to TensorArray index " + index + ", because it has already been read.");
    }
    if (t.written) {
      throw new Error("TensorArray " + this.name + ": Could not write to TensorArray index " + index + ", because it has already been written.");
    }
    t.tensor = tensor;
    tfOps.keep(tensor);
    t.written = true;
    this.tensors[index] = t;
  };
  /** Writes `tensors[k]` to `indices[k]` for every k; both arrays must have equal length. */
  TensorArray2.prototype.writeMany = function(indices, tensors) {
    var _this2 = this;
    if (indices.length !== tensors.length) {
      throw new Error("TensorArray " + this.name + ": could not write multiple tensors," + ("because the index size: " + indices.length + " is not the same as tensors size: " + tensors.length + "."));
    }
    indices.forEach(function(i, index) {
      return _this2.write(i, tensors[index]);
    });
  };
  /**
   * Stacks the tensors at `indices` (default: every slot) along a new axis 0.
   * Only the first size() entries of `indices` are used. With nothing to
   * gather, returns an empty tensor of shape [0, ...elementShape].
   */
  TensorArray2.prototype.gather = function(indices, dtype) {
    if (!!dtype && dtype !== this.dtype) {
      throw new Error("TensorArray dtype is " + this.dtype + " but gather requested dtype " + dtype);
    }
    if (!indices) {
      indices = [];
      for (var i = 0; i < this.size(); i++) {
        indices.push(i);
      }
    } else {
      indices = indices.slice(0, this.size());
    }
    if (indices.length === 0) {
      return tfOps.tensor([], [0].concat(this.elementShape));
    }
    var tensors = this.readMany(indices);
    assertShapesMatchAllowUndefinedSize(this.elementShape, tensors[0].shape, "TensorArray shape mismatch: ");
    return tfOps.stack(tensors, 0);
  };
  /**
   * Concatenates every stored tensor along axis 0; an empty array yields an
   * empty tensor of shape [0, ...elementShape].
   */
  TensorArray2.prototype.concat = function(dtype) {
    if (!!dtype && dtype !== this.dtype) {
      throw new Error("TensorArray dtype is " + this.dtype + " but concat requested dtype " + dtype);
    }
    if (this.size() === 0) {
      return tfOps.tensor([], [0].concat(this.elementShape));
    }
    var indices = [];
    for (var i = 0; i < this.size(); i++) {
      indices.push(i);
    }
    var tensors = this.readMany(indices);
    assertShapesMatchAllowUndefinedSize(this.elementShape, tensors[0].shape, "TensorArray shape mismatch: tensor array shape (" + this.elementShape + ") vs first tensor shape (" + tensors[0].shape + ")");
    return tfOps.concat(tensors, 0);
  };
  /** Unstacks `tensor` along axis 0 and writes row k to `indices[k]`. */
  TensorArray2.prototype.scatter = function(indices, tensor) {
    if (tensor.dtype !== this.dtype) {
      throw new Error("TensorArray dtype is " + this.dtype + " but tensor has dtype " + tensor.dtype);
    }
    if (indices.length !== tensor.shape[0]) {
      throw new Error("Expected len(indices) == tensor.shape[0], but saw: " + indices.length + " vs. " + tensor.shape[0]);
    }
    var maxIndex = Math.max.apply(Math, __spread(indices));
    if (!this.dynamicSize && maxIndex >= this.maxSize) {
      throw new Error("Max index must be < array size (" + maxIndex + " vs. " + this.maxSize + ")");
    }
    this.writeMany(indices, tfOps.unstack(tensor, 0));
  };
  /**
   * Splits `tensor` along axis 0 into consecutive pieces of `length[i]` rows,
   * reshapes each piece to elementShape and writes piece i to index i.
   */
  TensorArray2.prototype.split = function(length, tensor) {
    var _this2 = this;
    if (tensor.dtype !== this.dtype) {
      throw new Error("TensorArray dtype is " + this.dtype + " but tensor has dtype " + tensor.dtype);
    }
    var totalLength = 0;
    var cumulativeLengths = length.map(function(len) {
      totalLength += len;
      return totalLength;
    });
    if (totalLength !== tensor.shape[0]) {
      throw new Error("Expected sum of lengths to be equal to\n tensor.shape[0], but sum of lengths is\n " + totalLength + ", and tensor's shape is: " + tensor.shape);
    }
    if (!this.dynamicSize && length.length !== this.maxSize) {
      throw new Error("TensorArray's size is not equal to the size of lengths (" + this.maxSize + " vs. " + length.length + "), and the TensorArray is not marked as dynamically resizeable");
    }
    var elementPerRow = totalLength === 0 ? 0 : tensor.size / totalLength;
    var tensors = [];
    tfOps.tidy(function() {
      tensor = tfOps.reshape(tensor, [1, totalLength, elementPerRow]);
      for (var i2 = 0; i2 < length.length; ++i2) {
        var previousLength = i2 === 0 ? 0 : cumulativeLengths[i2 - 1];
        var indices_1 = [0, previousLength, 0];
        var sizes = [1, length[i2], elementPerRow];
        tensors[i2] = tfOps.reshape(tfOps.slice(tensor, indices_1, sizes), _this2.elementShape);
      }
      return tensors;
    });
    var indices = [];
    for (var i = 0; i < length.length; i++) {
      indices[i] = i;
    }
    this.writeMany(indices, tensors);
  };
  return TensorArray2;
}();
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Host-side store backing the TensorFlow TensorList graph ops. Holds a plain
 * array of kept element tensors plus an `idTensor` whose id registers the list
 * in the execution context.
 */
var TensorList = function() {
  /**
   * @param {Tensor[]} tensors - Initial elements; each is dtype/shape checked and kept.
   * @param {number[]} elementShape - Element shape; -1 entries match any size.
   * @param {string} elementDtype - dtype every element must have.
   * @param {number} [maxNumElements=-1] - Capacity; -1 means unbounded.
   */
  function TensorList2(tensors, elementShape, elementDtype, maxNumElements) {
    if (maxNumElements === void 0) {
      maxNumElements = -1;
    }
    this.tensors = tensors;
    this.elementShape = elementShape;
    this.elementDtype = elementDtype;
    if (tensors != null) {
      tensors.forEach(function(tensor) {
        if (elementDtype !== tensor.dtype) {
          throw new Error("Invalid data types; op elements " + elementDtype + ", but list elements " + tensor.dtype);
        }
        assertShapesMatchAllowUndefinedSize(elementShape, tensor.shape, "TensorList shape mismatch: ");
        tfOps.keep(tensor);
      });
    }
    // Scalar whose tensor id identifies this list; kept so tidy() leaves it alone.
    this.idTensor = tfOps.scalar(0);
    this.maxNumElements = maxNumElements;
    tfOps.keep(this.idTensor);
  }
  /** Identifier of this list (the id of its id tensor). */
  Object.defineProperty(TensorList2.prototype, "id", {
    get: function() {
      return this.idTensor.id;
    },
    enumerable: true,
    configurable: true
  });
  /**
   * Shallow copy: a new list (with its own id tensor) over the same element
   * tensors, preserving element shape, dtype and capacity.
   */
  TensorList2.prototype.copy = function() {
    return new TensorList2(__spread(this.tensors), this.elementShape, this.elementDtype, this.maxNumElements);
  };
  /**
   * Disposes every element whose id is not in `keepIds`, empties the list and
   * disposes the id tensor.
   * @param {Set<number>} [keepIds] - Set-like (uses `.has`) of tensor ids to keep alive.
   */
  TensorList2.prototype.clearAndClose = function(keepIds) {
    this.tensors.forEach(function(tensor) {
      if (keepIds == null || !keepIds.has(tensor.id)) {
        tensor.dispose();
      }
    });
    this.tensors.length = 0;
    this.idTensor.dispose();
  };
  /** Number of element slots. */
  TensorList2.prototype.size = function() {
    return this.tensors.length;
  };
  /**
   * Reshapes every element to `elementShape` and stacks them along a new axis 0.
   * @param {number} [numElements=-1] - Expected element count; -1 skips the check.
   */
  TensorList2.prototype.stack = function(elementShape, elementDtype, numElements) {
    var _this2 = this;
    if (numElements === void 0) {
      numElements = -1;
    }
    if (elementDtype !== this.elementDtype) {
      throw new Error("Invalid data types; op elements " + elementDtype + ", but list elements " + this.elementDtype);
    }
    if (numElements !== -1 && this.tensors.length !== numElements) {
      throw new Error("Operation expected a list with " + numElements + " elements but got a list with " + this.tensors.length + " elements.");
    }
    assertShapesMatchAllowUndefinedSize(elementShape, this.elementShape, "TensorList shape mismatch: ");
    return tfOps.tidy(function() {
      var reshapedTensors = _this2.tensors.map(function(tensor) {
        return tfOps.reshape(tensor, elementShape);
      });
      return tfOps.stack(reshapedTensors, 0);
    });
  };
  /** Removes the last element and returns it reshaped to `elementShape`. */
  TensorList2.prototype.popBack = function(elementShape, elementDtype) {
    if (elementDtype !== this.elementDtype) {
      throw new Error("Invalid data types; op elements " + elementDtype + ", but list elements " + this.elementDtype);
    }
    if (this.size() === 0) {
      throw new Error("Trying to pop from an empty list.");
    }
    var tensor = this.tensors.pop();
    assertShapesMatchAllowUndefinedSize(tensor.shape, elementShape, "TensorList shape mismatch: ");
    return tfOps.reshape(tensor, elementShape);
  };
  /** Appends `tensor` (kept) unless the list is at capacity. */
  TensorList2.prototype.pushBack = function(tensor) {
    if (tensor.dtype !== this.elementDtype) {
      throw new Error("Invalid data types; op elements " + tensor.dtype + ", but list elements " + this.elementDtype);
    }
    assertShapesMatchAllowUndefinedSize(tensor.shape, this.elementShape, "TensorList shape mismatch: ");
    if (this.maxNumElements === this.size()) {
      throw new Error("Trying to push element into a full list.");
    }
    tfOps.keep(tensor);
    this.tensors.push(tensor);
  };
  /**
   * Sets the number of slots to `size` (growing leaves empty holes).
   * NOTE(review): elements dropped by shrinking are not disposed here.
   */
  TensorList2.prototype.resize = function(size) {
    if (size < 0) {
      throw new Error("TensorListResize expects size to be non-negative. Got: " + size);
    }
    if (this.maxNumElements !== -1 && size > this.maxNumElements) {
      throw new Error("TensorListResize input size " + size + " is greater maxNumElement " + this.maxNumElements + ".");
    }
    this.tensors.length = size;
  };
  /**
   * Returns the element at `elementIndex` after checking dtype and shape.
   * @throws {Error} if the index is outside [0, size()) or the slot is empty.
   */
  TensorList2.prototype.getItem = function(elementIndex, elementShape, elementDtype) {
    if (elementDtype !== this.elementDtype) {
      throw new Error("Invalid data types; op elements " + elementDtype + ", but list elements " + this.elementDtype);
    }
    // Valid indices are 0..length-1; index === length is out of range.
    if (elementIndex < 0 || elementIndex >= this.tensors.length) {
      throw new Error("Trying to access element " + elementIndex + " in a list with " + this.tensors.length + " elements.");
    }
    if (this.tensors[elementIndex] == null) {
      throw new Error("element at index " + elementIndex + " is null.");
    }
    assertShapesMatchAllowUndefinedSize(this.tensors[elementIndex].shape, elementShape, "TensorList shape mismatch: ");
    return this.tensors[elementIndex];
  };
  /** Stores `tensor` (kept) at `elementIndex`, subject to the capacity limit. */
  TensorList2.prototype.setItem = function(elementIndex, tensor) {
    if (tensor.dtype !== this.elementDtype) {
      throw new Error("Invalid data types; op elements " + tensor.dtype + ", but list elements " + this.elementDtype);
    }
    if (elementIndex < 0 || this.maxNumElements !== -1 && elementIndex >= this.maxNumElements) {
      throw new Error("Trying to set element " + elementIndex + " in a list with max " + this.maxNumElements + " elements.");
    }
    assertShapesMatchAllowUndefinedSize(this.elementShape, tensor.shape, "TensorList shape mismatch: ");
    tfOps.keep(tensor);
    this.tensors[elementIndex] = tensor;
  };
  /**
   * Stacks the elements at `indices` (only the first size() entries are used),
   * each reshaped to `elementShape`; empty yields shape [0, ...elementShape].
   */
  TensorList2.prototype.gather = function(indices, elementDtype, elementShape) {
    var _this2 = this;
    if (elementDtype !== this.elementDtype) {
      throw new Error("Invalid data types; op elements " + elementDtype + ", but list elements " + this.elementDtype);
    }
    assertShapesMatchAllowUndefinedSize(this.elementShape, elementShape, "TensorList shape mismatch: ");
    indices = indices.slice(0, this.size());
    if (indices.length === 0) {
      return tfOps.tensor([], [0].concat(this.elementShape));
    }
    return tfOps.tidy(function() {
      var tensors = indices.map(function(i) {
        return tfOps.reshape(_this2.tensors[i], elementShape);
      });
      return tfOps.stack(tensors, 0);
    });
  };
  /** Concatenates all elements (reshaped to `elementShape`) along axis 0. */
  TensorList2.prototype.concat = function(elementDtype, elementShape) {
    var _this2 = this;
    if (!!elementDtype && elementDtype !== this.elementDtype) {
      throw new Error("TensorList dtype is " + this.elementDtype + " but concat requested dtype " + elementDtype);
    }
    assertShapesMatchAllowUndefinedSize(this.elementShape, elementShape, "TensorList shape mismatch: ");
    if (this.size() === 0) {
      return tfOps.tensor([], [0].concat(this.elementShape));
    }
    return tfOps.tidy(function() {
      var tensors = _this2.tensors.map(function(t) {
        return tfOps.reshape(t, elementShape);
      });
      return tfOps.concat(tensors, 0);
    });
  };
  return TensorList2;
}();
/**
 * Builds a TensorList by unstacking `tensor` along its first axis.
 *
 * @param tensor - Source tensor; must be at least rank 1.
 * @param {number[]} elementShape - Expected element shape (tensor.shape without axis 0).
 * @param {string} elementDtype - Must equal `tensor.dtype`.
 * @returns {TensorList}
 */
function fromTensor(tensor, elementShape, elementDtype) {
  const sourceShape = tensor.shape;
  if (sourceShape.length < 1) {
    throw new Error("Tensor must be at least a vector, but saw shape: " + sourceShape);
  }
  const sourceDtype = tensor.dtype;
  if (sourceDtype !== elementDtype) {
    throw new Error("Invalid data types; op elements " + sourceDtype + ", but list elements " + elementDtype);
  }
  const rowShape = sourceShape.slice(1);
  assertShapesMatchAllowUndefinedSize(rowShape, elementShape, "TensorList shape mismatch: ");
  return new TensorList(tfOps.unstack(tensor), elementShape, sourceDtype);
}
/**
 * Creates an empty TensorList (TensorListReserve); `numElements` is passed on
 * as the list's maximum element count.
 */
function reserve(elementShape, elementDtype, numElements) {
  const noTensors = [];
  const list = new TensorList(noTensors, elementShape, elementDtype, numElements);
  return list;
}
/**
 * Builds a TensorList (TensorListScatter / TensorListScatterV2) by unstacking
 * `tensor` along axis 0 and placing row k at list position `indices[k]`.
 *
 * @param tensor - Source tensor; `tensor.shape[0]` must equal `indices.length`.
 * @param {number[]} indices - Destination position for each row.
 * @param {number[]} elementShape - Element shape of the new list.
 * @param {number} [numElements] - Maximum list size; null, undefined or -1
 *   means unbounded.
 * @returns {TensorList}
 */
function scatter(tensor, indices, elementShape, numElements) {
  if (indices.length !== tensor.shape[0]) {
    throw new Error("Expected len(indices) == tensor.shape[0], but saw: " + indices.length + " vs. " + tensor.shape[0]);
  }
  // TensorList only defaults an *undefined* capacity to -1; map null to -1 too,
  // otherwise setItem's `index >= null` check would reject every index.
  const maxNumElements = numElements == null ? -1 : numElements;
  const maxIndex = Math.max(...indices);
  if (maxNumElements !== -1 && maxIndex >= maxNumElements) {
    throw new Error("Max index must be < array size (" + maxIndex + " vs. " + numElements + ")");
  }
  const list = new TensorList([], elementShape, tensor.dtype, maxNumElements);
  const rows = tfOps.unstack(tensor, 0);
  indices.forEach((destination, row) => {
    list.setItem(destination, rows[row]);
  });
  return list;
}
/**
 * Builds a TensorList (TensorListSplit) by cutting `tensor` along axis 0 into
 * consecutive pieces of `length[i]` rows, each reshaped to `elementShape`.
 *
 * @param tensor - Source tensor; `tensor.shape[0]` must equal the sum of `length`.
 * @param {number[]} length - Row count of each piece.
 * @param {number[]} elementShape - Shape each piece is reshaped to.
 * @returns {TensorList} List with capacity `length.length`.
 */
function split(tensor, length, elementShape) {
  let totalLength = 0;
  const cumulativeLengths = length.map((len) => {
    totalLength += len;
    return totalLength;
  });
  if (totalLength !== tensor.shape[0]) {
    throw new Error("Expected sum of lengths to be equal to\n tensor.shape[0], but sum of lengths is\n " + totalLength + ", and tensor's shape is: " + tensor.shape);
  }
  const elementPerRow = totalLength === 0 ? 0 : tensor.size / totalLength;
  // Take the dtype from the input itself rather than from the intermediate
  // reshaped tensor, which is disposed inside tidy() below.
  const dtype = tensor.dtype;
  const pieces = tfOps.tidy(() => {
    const out = [];
    const rows = tfOps.reshape(tensor, [1, totalLength, elementPerRow]);
    for (let i = 0; i < length.length; ++i) {
      const start = i === 0 ? 0 : cumulativeLengths[i - 1];
      const begin = [0, start, 0];
      const sizes = [1, length[i], elementPerRow];
      out[i] = tfOps.reshape(tfOps.slice(rows, begin, sizes), elementShape);
    }
    rows.dispose();
    return out;
  });
  const list = new TensorList([], elementShape, dtype, length.length);
  for (let i = 0; i < pieces.length; i++) {
    list.setItem(i, pieces[i]);
  }
  return list;
}
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
// Placeholder `this` handed to __awaiter by the transpiled async function below.
var _this = void 0;
/**
 * Executes one control-flow, TensorArray or TensorList graph node and resolves
 * to its output tensors (or `void 0` for a Merge with no available input).
 *
 * This is transpiled async code (tslib-style __awaiter/__generator). Each
 * numbered `case` of `_b.label` is a resume point. Returned tuples follow the
 * tslib __generator convention: [2, v] returns v, [3, L] jumps to label L,
 * [4, p] awaits p (its value is read with `_b.sent()` at the next label), and
 * [5, g] delegates to generator g. Do not reorder statements or labels.
 *
 * @param node - Graph node; `node.op` selects the branch.
 * @param tensorMap - Passed through to getParamValue/getTensor.
 * @param context - Execution context: `functionMap`, frame bookkeeping
 *   (enterFrame/exitFrame/nextIteration) and the TensorArray/TensorList
 *   registries looked up by id-tensor id.
 * @throws {TypeError} when `node.op` is not handled here.
 */
var executeOp$2 = function(node, tensorMap, context) {
return __awaiter(_this, void 0, void 0, function() {
var _a, thenFunc, elseFunc, cond, args, condValue, bodyFunc, condFunc, args, condResult, argIds_1, condValue, result, _loop_1, pred, pred, data, inputName, data, frameId, data, data, data, size, dtype, elementShape, dynamicSize, clearAfterRead, identicalElementShapes, name_1, tensorArray, id, index, writeTensor, writeTensorArray, readId, readIndex, readTensorArray, gatherId, gatherIndices, gatherDtype, gatherTensorArray, scatterId, scatterIndices, scatterTensor, scatterTensorArray, concatId, concatTensorArray, concatDtype, splitId, splitTensor, lengths, splitTensorArray, sizeId, sizeTensorArray, closeId, closeTensorArray, idTensor, index, writeTensor, tensorList, idTensor, readIndex, elementShape, elementDType, tensorList, scatterIndices, scatterTensor, elementShape, numElements, tensorList, elementShape, elementDtype, numElements, tensorList, gatherId, gatherIndices, elementShape, elementDtype, tensorList, idTensor, elementShape, elementDtype, numElements, tensorList, tensor, elementShape, elementDtype, tensorList, concatId, tensorList, concatDtype, elementShape, idTensor, writeTensor, tensorList, idTensor, elementShape, elementDType, tensorList, splitTensor, elementShape, lengths, tensorList;
return __generator(this, function(_b) {
switch (_b.label) {
case 0:
// Dispatch: jump to the label that implements node.op (label 36 = unsupported).
_a = node.op;
switch (_a) {
case "If":
return [3, 1];
case "StatelessIf":
return [3, 1];
case "While":
return [3, 3];
case "StatelessWhile":
return [3, 3];
case "LoopCond":
return [3, 9];
case "Switch":
return [3, 10];
case "Merge":
return [3, 12];
case "Enter":
return [3, 13];
case "Exit":
return [3, 14];
case "NextIteration":
return [3, 15];
case "TensorArrayV3":
return [3, 16];
case "TensorArrayWriteV3":
return [3, 17];
case "TensorArrayReadV3":
return [3, 18];
case "TensorArrayGatherV3":
return [3, 19];
case "TensorArrayScatterV3":
return [3, 20];
case "TensorArrayConcatV3":
return [3, 21];
case "TensorArraySplitV3":
return [3, 22];
case "TensorArraySizeV3":
return [3, 23];
case "TensorArrayCloseV3":
return [3, 24];
case "TensorListSetItem":
return [3, 25];
case "TensorListGetItem":
return [3, 26];
case "TensorListScatterV2":
return [3, 27];
case "TensorListScatter":
return [3, 27];
case "TensorListReserve":
return [3, 28];
case "TensorListGather":
return [3, 29];
case "TensorListStack":
return [3, 30];
case "TensorListFromTensor":
return [3, 31];
case "TensorListConcat":
return [3, 32];
case "TensorListPushBack":
return [3, 33];
case "TensorListPopBack":
return [3, 34];
case "TensorListSplit":
return [3, 35];
}
return [3, 36];
case 1:
// If / StatelessIf: read branch function names, condition and args, then await the condition's data.
thenFunc = getParamValue("thenBranch", node, tensorMap, context);
elseFunc = getParamValue("elseBranch", node, tensorMap, context);
cond = getParamValue("cond", node, tensorMap, context);
args = getParamValue("args", node, tensorMap, context);
return [4, cond.data()];
case 2:
// Run the selected branch with args, sharing this context's TensorArray/TensorList maps.
condValue = _b.sent();
if (condValue[0]) {
return [2, context.functionMap[thenFunc].executeFunctionAsync(args, context.tensorArrayMap, context.tensorListMap)];
} else {
return [2, context.functionMap[elseFunc].executeFunctionAsync(args, context.tensorArrayMap, context.tensorListMap)];
}
case 3:
// While / StatelessWhile: evaluate the condition function once on the initial args.
bodyFunc = getParamValue("body", node, tensorMap, context);
condFunc = getParamValue("cond", node, tensorMap, context);
args = getParamValue("args", node, tensorMap, context);
return [4, context.functionMap[condFunc].executeFunctionAsync(args, context.tensorArrayMap, context.tensorListMap)];
case 4:
// Remember the input arg ids so they are never disposed by the cleanup below.
condResult = _b.sent();
argIds_1 = args.map(function(tensor2) {
return tensor2.id;
});
return [4, condResult[0].data()];
case 5:
condValue = _b.sent();
// Dispose condition outputs that are neither kept nor one of the input args.
condResult.forEach(function(tensor2) {
if (!tensor2.kept && argIds_1.indexOf(tensor2.id) === -1) {
tensor2.dispose();
}
});
result = args;
// One loop iteration: run the body on `result`, dispose the previous
// iteration's tensors that are not kept/args/new results, then re-evaluate
// the condition (updating `condValue`) and dispose its unneeded outputs.
_loop_1 = function() {
var origResult, resultIds, condResult_1;
return __generator(this, function(_a2) {
switch (_a2.label) {
case 0:
origResult = result;
return [4, context.functionMap[bodyFunc].executeFunctionAsync(result, context.tensorArrayMap, context.tensorListMap)];
case 1:
result = _a2.sent();
resultIds = result.map(function(tensor2) {
return tensor2.id;
});
origResult.forEach(function(tensor2) {
if (!tensor2.kept && argIds_1.indexOf(tensor2.id) === -1 && resultIds.indexOf(tensor2.id) === -1) {
tensor2.dispose();
}
});
return [4, context.functionMap[condFunc].executeFunctionAsync(result, context.tensorArrayMap, context.tensorListMap)];
case 2:
condResult_1 = _a2.sent();
return [4, condResult_1[0].data()];
case 3:
condValue = _a2.sent();
condResult_1.forEach(function(tensor2) {
if (!tensor2.kept && argIds_1.indexOf(tensor2.id) === -1 && resultIds.indexOf(tensor2.id) === -1) {
tensor2.dispose();
}
});
return [2];
}
});
};
_b.label = 6;
case 6:
// Loop head: keep iterating while the first condition value is truthy.
if (!condValue[0])
return [3, 8];
return [5, _loop_1()];
case 7:
_b.sent();
return [3, 6];
case 8:
// Loop finished: the last body outputs are the While node's outputs.
return [2, result];
case 9: {
// LoopCond: forward a clone of the predicate.
pred = getParamValue("pred", node, tensorMap, context);
return [2, [cloneTensor(pred)]];
}
case 10:
// Switch: clone `data` unless it is kept, then await the predicate.
pred = getParamValue("pred", node, tensorMap, context);
data = getParamValue("data", node, tensorMap, context);
if (!data.kept) {
data = cloneTensor(data);
}
return [4, pred.data()];
case 11:
// Truthy predicate -> data on output 1; otherwise on output 0.
return [2, _b.sent()[0] ? [void 0, data] : [data, void 0]];
case 12: {
// Merge: forward a clone of the first input that resolves to a tensor, if any.
inputName = node.inputNames.find(function(name) {
return getTensor(name, tensorMap, context) !== void 0;
});
if (inputName) {
data = getTensor(inputName, tensorMap, context);
return [2, [cloneTensor(data)]];
}
return [2, void 0];
}
case 13: {
// Enter: enter the named frame and forward a clone of the tensor.
frameId = getParamValue("frameName", node, tensorMap, context);
data = getParamValue("tensor", node, tensorMap, context);
context.enterFrame(frameId);
return [2, [cloneTensor(data)]];
}
case 14: {
// Exit: leave the current frame and forward a clone of the tensor.
data = getParamValue("tensor", node, tensorMap, context);
context.exitFrame();
return [2, [cloneTensor(data)]];
}
case 15: {
// NextIteration: advance the context's iteration and forward a clone.
data = getParamValue("tensor", node, tensorMap, context);
context.nextIteration();
return [2, [cloneTensor(data)]];
}
case 16: {
// TensorArrayV3: create and register a TensorArray.
// NOTE(review): the second output scalar(1) is presumably TF's "flow" value — confirm.
size = getParamValue("size", node, tensorMap, context);
dtype = getParamValue("dtype", node, tensorMap, context);
elementShape = getParamValue("elementShape", node, tensorMap, context);
dynamicSize = getParamValue("dynamicSize", node, tensorMap, context);
clearAfterRead = getParamValue("clearAfterRead", node, tensorMap, context);
identicalElementShapes = getParamValue("identicalElementShapes", node, tensorMap, context);
name_1 = getParamValue("name", node, tensorMap, context);
tensorArray = new TensorArray(name_1, dtype, size, elementShape, identicalElementShapes, dynamicSize, clearAfterRead);
context.addTensorArray(tensorArray);
return [2, [tensorArray.idTensor, tfOps.scalar(1)]];
}
case 17: {
// TensorArrayWriteV3: write `tensor` at `index`; output is the array's id tensor.
id = getParamValue("tensorArrayId", node, tensorMap, context);
index = getParamValue("index", node, tensorMap, context);
writeTensor = getParamValue("tensor", node, tensorMap, context);
writeTensorArray = context.getTensorArray(id.id);
writeTensorArray.write(index, writeTensor);
return [2, [writeTensorArray.idTensor]];
}
case 18: {
// TensorArrayReadV3: return the element at `index`.
readId = getParamValue("tensorArrayId", node, tensorMap, context);
readIndex = getParamValue("index", node, tensorMap, context);
readTensorArray = context.getTensorArray(readId.id);
return [2, [readTensorArray.read(readIndex)]];
}
case 19: {
// TensorArrayGatherV3: stack the elements at `indices`.
gatherId = getParamValue("tensorArrayId", node, tensorMap, context);
gatherIndices = getParamValue("indices", node, tensorMap, context);
gatherDtype = getParamValue("dtype", node, tensorMap, context);
gatherTensorArray = context.getTensorArray(gatherId.id);
return [2, [gatherTensorArray.gather(gatherIndices, gatherDtype)]];
}
case 20: {
// TensorArrayScatterV3: write rows of `tensor` to `indices`.
scatterId = getParamValue("tensorArrayId", node, tensorMap, context);
scatterIndices = getParamValue("indices", node, tensorMap, context);
scatterTensor = getParamValue("tensor", node, tensorMap, context);
scatterTensorArray = context.getTensorArray(scatterId.id);
scatterTensorArray.scatter(scatterIndices, scatterTensor);
return [2, [scatterTensorArray.idTensor]];
}
case 21: {
// TensorArrayConcatV3: concatenate all elements along axis 0.
concatId = getParamValue("tensorArrayId", node, tensorMap, context);
concatTensorArray = context.getTensorArray(concatId.id);
concatDtype = getParamValue("dtype", node, tensorMap, context);
return [2, [concatTensorArray.concat(concatDtype)]];
}
case 22: {
// TensorArraySplitV3: split `tensor` by `lengths` into the array.
splitId = getParamValue("tensorArrayId", node, tensorMap, context);
splitTensor = getParamValue("tensor", node, tensorMap, context);
lengths = getParamValue("lengths", node, tensorMap, context);
splitTensorArray = context.getTensorArray(splitId.id);
splitTensorArray.split(lengths, splitTensor);
return [2, [splitTensorArray.idTensor]];
}
case 23: {
// TensorArraySizeV3: element count as an int32 scalar.
sizeId = getParamValue("tensorArrayId", node, tensorMap, context);
sizeTensorArray = context.getTensorArray(sizeId.id);
return [2, [tfOps.scalar(sizeTensorArray.size(), "int32")]];
}
case 24: {
// TensorArrayCloseV3: dispose the array's tensors and close it.
// NOTE(review): clearAndClose() also disposes idTensor, so the returned tensor
// is already disposed — confirm downstream consumers only rely on its id.
closeId = getParamValue("tensorArrayId", node, tensorMap, context);
closeTensorArray = context.getTensorArray(closeId.id);
closeTensorArray.clearAndClose();
return [2, [closeTensorArray.idTensor]];
}
case 25: {
// TensorListSetItem: store `tensor` at `index`; output is the list's id tensor.
idTensor = getParamValue("tensorListId", node, tensorMap, context);
index = getParamValue("index", node, tensorMap, context);
writeTensor = getParamValue("tensor", node, tensorMap, context);
tensorList = context.getTensorList(idTensor.id);
tensorList.setItem(index, writeTensor);
return [2, [tensorList.idTensor]];
}
case 26: {
// TensorListGetItem: return the element at `index`.
idTensor = getParamValue("tensorListId", node, tensorMap, context);
readIndex = getParamValue("index", node, tensorMap, context);
elementShape = getParamValue("elementShape", node, tensorMap, context);
elementDType = getParamValue("elementDType", node, tensorMap, context);
tensorList = context.getTensorList(idTensor.id);
return [2, [tensorList.getItem(readIndex, elementShape, elementDType)]];
}
case 27: {
// TensorListScatter / TensorListScatterV2: build and register a new list.
scatterIndices = getParamValue("indices", node, tensorMap, context);
scatterTensor = getParamValue("tensor", node, tensorMap, context);
elementShape = getParamValue("elementShape", node, tensorMap, context);
numElements = getParamValue("numElements", node, tensorMap, context);
tensorList = scatter(scatterTensor, scatterIndices, elementShape, numElements);
context.addTensorList(tensorList);
return [2, [tensorList.idTensor]];
}
case 28: {
// TensorListReserve: register an empty list with capacity numElements.
elementShape = getParamValue("elementShape", node, tensorMap, context);
elementDtype = getParamValue("elementDType", node, tensorMap, context);
numElements = getParamValue("numElements", node, tensorMap, context);
tensorList = reserve(elementShape, elementDtype, numElements);
context.addTensorList(tensorList);
return [2, [tensorList.idTensor]];
}
case 29: {
// TensorListGather: stack the elements at `indices`.
gatherId = getParamValue("tensorListId", node, tensorMap, context);
gatherIndices = getParamValue("indices", node, tensorMap, context);
elementShape = getParamValue("elementShape", node, tensorMap, context);
elementDtype = getParamValue("elementDType", node, tensorMap, context);
tensorList = context.getTensorList(gatherId.id);
return [2, [tensorList.gather(gatherIndices, elementDtype, elementShape)]];
}
case 30: {
// TensorListStack: stack all elements.
idTensor = getParamValue("tensorListId", node, tensorMap, context);
elementShape = getParamValue("elementShape", node, tensorMap, context);
elementDtype = getParamValue("elementDType", node, tensorMap, context);
numElements = getParamValue("numElements", node, tensorMap, context);
tensorList = context.getTensorList(idTensor.id);
return [2, [tensorList.stack(elementShape, elementDtype, numElements)]];
}
case 31: {
// TensorListFromTensor: unstack `tensor` into a new registered list.
tensor = getParamValue("tensor", node, tensorMap, context);
elementShape = getParamValue("elementShape", node, tensorMap, context);
elementDtype = getParamValue("elementDType", node, tensorMap, context);
tensorList = fromTensor(tensor, elementShape, elementDtype);
context.addTensorList(tensorList);
return [2, [tensorList.idTensor]];
}
case 32: {
// TensorListConcat: concatenate all elements along axis 0.
concatId = getParamValue("tensorListId", node, tensorMap, context);
tensorList = context.getTensorList(concatId.id);
concatDtype = getParamValue("dtype", node, tensorMap, context);
elementShape = getParamValue("elementShape", node, tensorMap, context);
return [2, [tensorList.concat(concatDtype, elementShape)]];
}
case 33: {
// TensorListPushBack: append `tensor`.
idTensor = getParamValue("tensorListId", node, tensorMap, context);
writeTensor = getParamValue("tensor", node, tensorMap, context);
tensorList = context.getTensorList(idTensor.id);
tensorList.pushBack(writeTensor);
return [2, [tensorList.idTensor]];
}
case 34: {
// TensorListPopBack: remove and return the last element.
idTensor = getParamValue("tensorListId", node, tensorMap, context);
elementShape = getParamValue("elementShape", node, tensorMap, context);
elementDType = getParamValue("elementDType", node, tensorMap, context);
tensorList = context.getTensorList(idTensor.id);
return [2, [tensorList.popBack(elementShape, elementDType)]];
}
case 35: {
// TensorListSplit: split `tensor` by `lengths` into a new registered list.
splitTensor = getParamValue("tensor", node, tensorMap, context);
elementShape = getParamValue("elementShape", node, tensorMap, context);
lengths = getParamValue("lengths", node, tensorMap, context);
tensorList = split(splitTensor, lengths, elementShape);
context.addTensorList(tensorList);
return [2, [tensorList.idTensor]];
}
case 36:
// No branch above handles this op.
throw TypeError("Node type " + node.op + " is not implemented");
}
});
});
};
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Collects and validates the parameters shared by fused Conv2D and fused
 * DepthwiseConv2D nodes.
 *
 * `fusedOps` is read as [extraOp, activationFunc]. With "biasadd", exactly one
 * extra arg (bias) is required, or two (bias, alpha) when the activation is
 * "prelu". "fusedbatchnorm" is rejected.
 *
 * @returns {{stride, pad, dataFormat, dilations, biasArg, preluArg, activationFunc}}
 *   `dataFormat` is upper-cased; biasArg/preluArg are the first two `args`.
 */
function fusedConvAndDepthWiseParams(node, tensorMap, context) {
  const param = (name) => getParamValue(name, node, tensorMap, context);
  const fused = __read(param("fusedOps"), 2);
  const extraOp = fused[0];
  const activationFunc = fused[1];
  const numArgs = param("numArgs");
  if (extraOp === "biasadd") {
    const wantsPrelu = activationFunc === "prelu";
    if (wantsPrelu && numArgs !== 2) {
      throw new Error("FusedConv2d and DepthwiseConv2d with BiasAdd and Prelu must have two extra arguments: bias and alpha.");
    }
    if (!wantsPrelu && numArgs !== 1) {
      throw new Error("FusedConv2d and DepthwiseConv2d with BiasAdd must have one extra argument: bias.");
    }
  }
  if (extraOp === "fusedbatchnorm") {
    throw new Error("FusedConv2d and DepthwiseConv2d with FusedBatchNorm is not supported.");
  }
  const stride = param("strides");
  const pad = getPadding(node, tensorMap, context);
  const dataFormat = param("dataFormat").toUpperCase();
  const dilations = param("dilations");
  const extraArgs = __read(param("args"), 2);
  return {
    stride,
    pad,
    dataFormat,
    dilations,
    biasArg: extraArgs[0],
    preluArg: extraArgs[1],
    activationFunc
  };
}
/**
 * Executes convolution, pooling and dilation graph nodes.
 *
 * Stride, dilation and kernel-size attributes are read as whole lists. Only
 * their inner entries are forwarded to the tfOps call: indices 1..2 for the
 * 2-D ops and indices 1..3 for the 3-D ops.
 *
 * @param {Object} node - Graph node; `node.op` selects the operation.
 * @param {Object} tensorMap - Tensors computed so far.
 * @param {Object} context - Execution context passed through to getParamValue.
 * @returns {Array} The op's output tensor(s).
 * @throws {TypeError} If `node.op` is not handled by this executor.
 */
var executeOp$3 = function(node, tensorMap, context) {
  const param = (name) => getParamValue(name, node, tensorMap, context);
  const padding = () => getPadding(node, tensorMap, context);
  const inner2 = (list) => [list[1], list[2]];
  const inner3 = (list) => [list[1], list[2], list[3]];
  // Shared body of the two fused convolutions; `opName` is a key of tfOps.fused.
  // The name is used (not a function reference) so the call keeps its receiver.
  const runFused = (opName) => {
    const { stride, pad, dataFormat, dilations, biasArg, preluArg, activationFunc } = fusedConvAndDepthWiseParams(node, tensorMap, context);
    return [tfOps.fused[opName]({
      x: param("x"),
      filter: param("filter"),
      strides: inner2(stride),
      pad,
      dataFormat,
      dilations: inner2(dilations),
      bias: biasArg,
      activation: activationFunc,
      preluActivationWeights: preluArg
    })];
  };
  const op = node.op;
  switch (op) {
    case "Conv1D": {
      const stride = param("stride");
      const pad = param("pad");
      const dataFormat = param("dataFormat").toUpperCase();
      const dilation = param("dilation");
      const x = param("x");
      const filter = param("filter");
      return [tfOps.conv1d(x, filter, stride, pad, dataFormat, dilation)];
    }
    case "Conv2D": {
      const strides = param("strides");
      const pad = padding();
      const dataFormat = param("dataFormat").toUpperCase();
      const dilations = param("dilations");
      const x = param("x");
      const filter = param("filter");
      return [tfOps.conv2d(x, filter, inner2(strides), pad, dataFormat, inner2(dilations))];
    }
    case "_FusedConv2D":
      return runFused("conv2d");
    case "FusedDepthwiseConv2dNative":
      return runFused("depthwiseConv2d");
    case "Conv2DBackpropInput":
    case "Conv2dTranspose": {
      const outputShape = param("outputShape");
      const strides = param("strides");
      const pad = padding();
      const x = param("x");
      const filter = param("filter");
      return [tfOps.conv2dTranspose(x, filter, outputShape, inner2(strides), pad)];
    }
    case "DepthwiseConv2dNative":
    case "DepthwiseConv2d": {
      const strides = param("strides");
      const pad = padding();
      const dilations = param("dilations");
      const dataFormat = param("dataFormat").toUpperCase();
      const input = param("input");
      const filter = param("filter");
      return [tfOps.depthwiseConv2d(input, filter, inner2(strides), pad, dataFormat, inner2(dilations))];
    }
    case "Conv3D": {
      const strides = param("strides");
      const pad = param("pad");
      const dataFormat = param("dataFormat").toUpperCase();
      const dilations = param("dilations");
      const x = param("x");
      const filter = param("filter");
      return [tfOps.conv3d(x, filter, inner3(strides), pad, dataFormat, inner3(dilations))];
    }
    case "AvgPool":
    case "MaxPool": {
      const strides = param("strides");
      const pad = param("pad");
      const kernelSize = param("kernelSize");
      const pool = op === "AvgPool" ? "avgPool" : "maxPool";
      return [tfOps[pool](param("x"), inner2(kernelSize), inner2(strides), pad)];
    }
    case "MaxPoolWithArgmax": {
      const strides = param("strides");
      const pad = param("pad");
      const kernelSize = param("kernelSize");
      const includeBatchInIndex = param("includeBatchInIndex");
      const { result, indexes } = tfOps.maxPoolWithArgmax(param("x"), inner2(kernelSize), inner2(strides), pad, includeBatchInIndex);
      return [result, indexes];
    }
    case "AvgPool3D":
    case "MaxPool3D": {
      const strides = param("strides");
      const pad = param("pad");
      const kernelSize = param("kernelSize");
      const pool = op === "AvgPool3D" ? "avgPool3d" : "maxPool3d";
      return [tfOps[pool](param("x"), inner3(kernelSize), inner3(strides), pad)];
    }
    case "Dilation2D": {
      const strides = param("strides");
      const pad = param("pad");
      const dilations = param("dilations");
      const x = param("x");
      const filter = param("filter");
      return [tfOps.dilation2d(x, filter, inner2(strides), pad, inner2(dilations), "NHWC")];
    }
    default:
      throw TypeError("Node type " + node.op + " is not implemented");
  }
};
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Executes tensor-creation graph nodes: Fill, LinSpace, Multinomial, OneHot,
 * Ones/OnesLike, RandomUniform, Range, TruncatedNormal and Zeros/ZerosLike.
 *
 * @param {Object} node - Graph node; `node.op` selects the operation.
 * @param {Object} tensorMap - Tensors computed so far.
 * @param {Object} context - Execution context passed through to getParamValue.
 * @returns {Array} A one-element list holding the created tensor.
 * @throws {TypeError} If `node.op` is not handled by this executor.
 */
var executeOp$4 = function(node, tensorMap, context) {
  // Reads the named parameters in the order given.
  const read = (...names) => names.map((name) => getParamValue(name, node, tensorMap, context));
  switch (node.op) {
    case "Fill": {
      const [shape, dtype, value] = read("shape", "dtype", "value");
      return [tfOps.fill(shape, value, dtype)];
    }
    case "LinSpace":
      return [tfOps.linspace(...read("start", "stop", "num"))];
    case "Multinomial":
      return [tfOps.multinomial(...read("logits", "numSamples", "seed"))];
    case "OneHot":
      return [tfOps.oneHot(...read("indices", "depth", "onValue", "offValue"))];
    case "Ones":
      return [tfOps.ones(...read("shape", "dtype"))];
    case "OnesLike":
      return [tfOps.onesLike(...read("x"))];
    case "RandomUniform":
      return [tfOps.randomUniform(...read("shape", "minval", "maxval", "dtype"))];
    case "Range":
      return [tfOps.range(...read("start", "stop", "step", "dtype"))];
    case "TruncatedNormal": {
      // tfOps.truncatedNormal takes dtype before seed.
      const [shape, mean, stdDev, seed, dtype] = read("shape", "mean", "stdDev", "seed", "dtype");
      return [tfOps.truncatedNormal(shape, mean, stdDev, dtype, seed)];
    }
    case "Zeros":
      return [tfOps.zeros(...read("shape", "dtype"))];
    case "ZerosLike":
      return [tfOps.zerosLike(...read("x"))];
    default:
      throw TypeError("Node type " + node.op + " is not implemented");
  }
};
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
// Receiver passed as the `thisArg` of __awaiter by executeOp$5; always undefined.
var _this$1 = void 0;
/**
 * Reads the inputs shared by the NonMaxSuppression executors.
 *
 * @param {Object} node - Graph node being executed.
 * @param {Object} tensorMap - Tensors computed so far.
 * @param {Object} context - Execution context passed through to getParamValue.
 * @returns {{boxes: *, scores: *, maxOutputSize: *, iouThreshold: *,
 *   scoreThreshold: *, softNmsSigma: *}} One entry per parameter name, read in
 *   that order.
 */
function nmsParams(node, tensorMap, context) {
  const names = ["boxes", "scores", "maxOutputSize", "iouThreshold", "scoreThreshold", "softNmsSigma"];
  return names.reduce((params, name) => {
    params[name] = getParamValue(name, node, tensorMap, context);
    return params;
  }, {});
}
// Executes the asynchronous dynamic graph nodes handled here:
// NonMaxSuppressionV2/V3/V4/V5, Where and ListDiff.
//
// The body is a tslib __awaiter/__generator state machine (down-levelled
// async/await). Inside the generator:
// - `[3, n]` jumps to label n.
// - `[4, p]` awaits p and resumes at the next label, where the settled value
//   is read with `_e.sent()`.
// - `[2, v]` finishes with v.
//
// @param node - graph node; `node.op` selects the operation.
// @param tensorMap - tensors computed so far.
// @param context - execution context passed through to getParamValue.
// @returns Promise of the op's output tensor list.
// @throws TypeError for ops not handled here (surfaced through the promise).
var executeOp$5 = function(node, tensorMap, context) {
return __awaiter(_this$1, void 0, void 0, function() {
var _a, _b, boxes, scores, maxOutputSize, iouThreshold, scoreThreshold, softNmsSigma, result, _c, boxes, scores, maxOutputSize, iouThreshold, scoreThreshold, padToMaxOutputSize, result, _d, boxes, scores, maxOutputSize, iouThreshold, scoreThreshold, condition, result;
return __generator(this, function(_e) {
switch (_e.label) {
case 0:
// Dispatch on the op name; V3 and V2 share the same handler (label 5).
_a = node.op;
switch (_a) {
case "NonMaxSuppressionV5":
return [3, 1];
case "NonMaxSuppressionV4":
return [3, 3];
case "NonMaxSuppressionV3":
return [3, 5];
case "NonMaxSuppressionV2":
return [3, 5];
case "Where":
return [3, 7];
case "ListDiff":
return [3, 9];
}
return [3, 10];
case 1:
// V5: NMS with a soft-NMS sigma; output is [selectedIndices, selectedScores].
_b = nmsParams(node, tensorMap, context), boxes = _b.boxes, scores = _b.scores, maxOutputSize = _b.maxOutputSize, iouThreshold = _b.iouThreshold, scoreThreshold = _b.scoreThreshold, softNmsSigma = _b.softNmsSigma;
return [4, tfOps.image.nonMaxSuppressionWithScoreAsync(boxes, scores, maxOutputSize, iouThreshold, scoreThreshold, softNmsSigma)];
case 2:
result = _e.sent();
return [2, [result.selectedIndices, result.selectedScores]];
case 3:
// V4: padded NMS; output is [selectedIndices, validOutputs].
_c = nmsParams(node, tensorMap, context), boxes = _c.boxes, scores = _c.scores, maxOutputSize = _c.maxOutputSize, iouThreshold = _c.iouThreshold, scoreThreshold = _c.scoreThreshold;
padToMaxOutputSize = getParamValue("padToMaxOutputSize", node, tensorMap, context);
return [4, tfOps.image.nonMaxSuppressionPaddedAsync(boxes, scores, maxOutputSize, iouThreshold, scoreThreshold, padToMaxOutputSize)];
case 4:
result = _e.sent();
return [2, [result.selectedIndices, result.validOutputs]];
case 5:
// V2/V3: plain NMS (softNmsSigma is not used); output is a one-element list.
_d = nmsParams(node, tensorMap, context), boxes = _d.boxes, scores = _d.scores, maxOutputSize = _d.maxOutputSize, iouThreshold = _d.iouThreshold, scoreThreshold = _d.scoreThreshold;
return [4, tfOps.image.nonMaxSuppressionAsync(boxes, scores, maxOutputSize, iouThreshold, scoreThreshold)];
case 6:
return [2, [_e.sent()]];
case 7:
// Where: cast the condition to bool and await the async computation.
// NOTE(review): if whereAsync rejects, the temporary `condition` tensor is
// never disposed (the dispose below only runs on success).
condition = tfOps.cast(getParamValue("condition", node, tensorMap, context), "bool");
return [4, tfOps.whereAsync(condition)];
case 8:
result = [_e.sent()];
// Release the temporary cast tensor created above.
condition.dispose();
return [2, result];
case 9: {
// ListDiff: the promise from setdiff1dAsync is returned as-is; the outer
// promise adopts its settled value (the op's output list).
return [2, tfOps.setdiff1dAsync(getParamValue("x", node, tensorMap, context), getParamValue("y", node, tensorMap, context))];
}
case 10:
throw TypeError("Node type " + node.op + " is not implemented");
}
});
});
};
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Executes evaluation graph nodes: TopKV2, Unique and UniqueV2.
 *
 * @param {Object} node - Graph node; `node.op` selects the operation.
 * @param {Object} tensorMap - Tensors computed so far.
 * @param {Object} context - Execution context passed through to getParamValue.
 * @returns {Array} `[values, indices]` taken from the tfOps result.
 * @throws {TypeError} If `node.op` is not handled by this executor.
 */
var executeOp$6 = function(node, tensorMap, context) {
  const param = (name) => getParamValue(name, node, tensorMap, context);
  let output;
  switch (node.op) {
    case "TopKV2": {
      const x = param("x");
      const k = param("k");
      const sorted = param("sorted");
      output = tfOps.topk(x, k, sorted);
      break;
    }
    case "Unique":
      output = tfOps.unique(param("x"));
      break;
    case "UniqueV2": {
      const x = param("x");
      const axis = param("axis");
      output = tfOps.unique(x, axis);
      break;
    }
    default:
      throw TypeError("Node type " + node.op + " is not implemented");
  }
  return [output.values, output.indices];
};
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Executes graph-control nodes: Const, placeholders, identity-style copies,
 * Shape/ShapeN/Size/Rank queries, NoOp and Print.
 *
 * @param {Object} node - Graph node; `node.op` selects the operation.
 * @param {Object} tensorMap - Tensors computed so far, keyed by node name.
 * @param {Object} context - Execution context passed through to getParamValue.
 * @returns {Array} The op's output tensor(s). For `Const` this is
 *   `tensorMap[node.name]` as stored.
 * @throws {TypeError} If `node.op` is not handled by this executor.
 */
var executeOp$7 = function(node, tensorMap, context) {
  switch (node.op) {
    case "Const": {
      return tensorMap[node.name];
    }
    case "PlaceholderWithDefault": {
      const def = getParamValue("default", node, tensorMap, context);
      // A tensor available for this node name wins; otherwise use the default input.
      return [getTensor(node.name, tensorMap, context) || def];
    }
    case "Placeholder":
      return [getTensor(node.name, tensorMap, context)];
    case "Identity":
    case "StopGradient":
    case "FakeQuantWithMinMaxVars": {
      // All three are executed as a clone of the input `x`.
      const data_1 = getParamValue("x", node, tensorMap, context);
      return [cloneTensor(data_1)];
    }
    case "IdentityN":
      return getParamValue("x", node, tensorMap, context).map(function(t) {
        return cloneTensor(t);
      });
    case "Snapshot": {
      const snapshot = getParamValue("x", node, tensorMap, context);
      return [cloneTensor(snapshot)];
    }
    case "Shape":
      return [tfOps.tensor1d(getParamValue("x", node, tensorMap, context).shape, "int32")];
    case "ShapeN":
      // One shape vector per input, int32 like `Shape` above. Without an
      // explicit dtype, tensor1d infers float32 from the plain numbers. That
      // would not match the int32 shape tensors consumed downstream.
      return getParamValue("x", node, tensorMap, context).map(function(t) {
        return tfOps.tensor1d(t.shape, "int32");
      });
    case "Size":
      return [tfOps.scalar(getParamValue("x", node, tensorMap, context).size, "int32")];
    case "Rank":
      return [tfOps.scalar(getParamValue("x", node, tensorMap, context).rank, "int32")];
    case "NoOp":
      return [tfOps.scalar(1)];
    case "Print": {
      const input = getParamValue("x", node, tensorMap, context);
      const data = getParamValue("data", node, tensorMap, context);
      const message = getParamValue("message", node, tensorMap, context);
      const summarize = getParamValue("summarize", node, tensorMap, context);
      console.warn("The graph has a tf.print() operation, usually used for debugging, which slows down performance.");
      console.log(message);
      // Log at most `summarize` leading values of each data tensor (synchronous read).
      for (let i = 0; i < data.length; i++) {
        console.log(Array.prototype.slice.call(data[i].dataSync()).slice(0, summarize));
      }
      // Print passes its input through unchanged.
      return [input];
    }
    default:
      throw TypeError("Node type " + node.op + " is not implemented");
  }
};
// Lookup table backing the HashTable/LookupTable graph executors.
// Entries are kept in a Map from the plain values read out of a key tensor's
// data() to value tensors obtained by unstacking the imported values tensor.
// `handle` is a scalar tensor; its id is the table's `id`.
var HashTable = function() {
// @param keyDType - dtype required of key tensors passed to import/find.
// @param valueDType - dtype required of value / default-value tensors.
function HashTable2(keyDType, valueDType) {
this.keyDType = keyDType;
this.valueDType = valueDType;
this.handle = tfOps.scalar(0);
this.tensorMap = new Map();
// Keep the handle so it is not disposed by an enclosing tidy() scope.
tfOps.keep(this.handle);
}
// The table's id is the id of its handle tensor.
Object.defineProperty(HashTable2.prototype, "id", {
get: function() {
return this.handle.id;
},
enumerable: true,
configurable: true
});
// Disposes every stored value tensor, empties the table, and disposes the handle.
HashTable2.prototype.clearAndClose = function() {
this.tensorMap.forEach(function(value) {
return value.dispose();
});
this.tensorMap.clear();
this.handle.dispose();
};
// Number of stored entries.
HashTable2.prototype.size = function() {
return this.tensorMap.size;
};
// Replaces the table contents: key i of `keys` maps to slice i of `values`
// (unstacked). Resolves to the handle tensor. Generator state machine: label 0
// validates dtypes and awaits keys.data(); label 1 does the replacement.
// Throws if the dtypes do not match or the key/value counts differ.
// NOTE(review): existing entries are disposed and cleared before the count
// assertion runs, so a mismatched import leaves the table empty.
HashTable2.prototype.import = function(keys, values) {
return __awaiter(this, void 0, void 0, function() {
var $keys;
var _this2 = this;
return __generator(this, function(_a) {
switch (_a.label) {
case 0:
this.checkKeyAndValueTensor(keys, values);
return [4, keys.data()];
case 1:
$keys = _a.sent();
this.tensorMap.forEach(function(value) {
return value.dispose();
});
this.tensorMap.clear();
return [2, tfOps.tidy(function() {
var $values = tfOps.unstack(values);
var keysLength = $keys.length;
var valuesLength = $values.length;
tfOps.util.assert(keysLength === valuesLength, function() {
return "The number of elements doesn't match, keys has " + (keysLength + " elements, the values has " + valuesLength + " ") + "elements.";
});
for (var i = 0; i < keysLength; i++) {
var key = $keys[i];
var value = $values[i];
// Stored tensors must outlive this tidy() scope.
tfOps.keep(value);
_this2.tensorMap.set(key, value);
}
return _this2.handle;
})];
}
});
});
};
// Looks up every key of `keys` and resolves to the stacked per-key results.
// Keys that are not present use `defaultValue`. Throws if the dtypes do not match.
HashTable2.prototype.find = function(keys, defaultValue) {
return __awaiter(this, void 0, void 0, function() {
var $keys;
var _this2 = this;
return __generator(this, function(_a) {
switch (_a.label) {
case 0:
this.checkKeyAndValueTensor(keys, defaultValue);
return [4, keys.data()];
case 1:
$keys = _a.sent();
return [2, tfOps.tidy(function() {
var result = [];
for (var i = 0; i < $keys.length; i++) {
var key = $keys[i];
var value = _this2.findWithDefault(key, defaultValue);
result.push(value);
}
return tfOps.stack(result);
})];
}
});
});
};
// Stored tensor for `key`, or `defaultValue` when the lookup yields null/undefined.
HashTable2.prototype.findWithDefault = function(key, defaultValue) {
var result = this.tensorMap.get(key);
return result != null ? result : defaultValue;
};
// Throws unless key.dtype equals keyDType and value.dtype equals valueDType.
HashTable2.prototype.checkKeyAndValueTensor = function(key, value) {
if (key.dtype !== this.keyDType) {
throw new Error("Expect key dtype " + this.keyDType + ", but got " + ("" + key.dtype));
}
if (value.dtype !== this.valueDType) {
throw new Error("Expect value dtype " + this.valueDType + ", but got " + ("" + value.dtype));
}
};
return HashTable2;
}();
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
// Receiver passed as the `thisArg` of __awaiter by executeOp$8; always undefined.
var _this$2 = void 0;
// Executes hash-table graph nodes against tables held by `resourceManager`:
// HashTable/HashTableV2, LookupTableImport/LookupTableImportV2 and
// LookupTableFind/LookupTableFindV2.
//
// The body is a tslib __awaiter/__generator state machine. `[3, n]` jumps to
// label n, `[4, p]` awaits p (the result is read with `_b.sent()` at the next
// label), and `[2, v]` finishes with v.
//
// @returns Promise of the op's output tensor list.
// @throws TypeError for ops not handled here (surfaced through the promise).
var executeOp$8 = function(node, tensorMap, context, resourceManager) {
return __awaiter(_this$2, void 0, void 0, function() {
var _a, keyDType, valueDType, hashTable2, handle, keys, values, hashTable2, handle, keys, defaultValue, hashTable2;
return __generator(this, function(_b) {
switch (_b.label) {
case 0:
// Dispatch on the op name; V1 and V2 variants share a handler.
_a = node.op;
switch (_a) {
case "HashTable":
return [3, 1];
case "HashTableV2":
return [3, 1];
case "LookupTableImport":
return [3, 2];
case "LookupTableImportV2":
return [3, 2];
case "LookupTableFind":
return [3, 4];
case "LookupTableFindV2":
return [3, 4];
}
return [3, 6];
case 1: {
// Create a table, register it under the node's name, and output its handle.
keyDType = getParamValue("keyDType", node, tensorMap, context);
valueDType = getParamValue("valueDType", node, tensorMap, context);
hashTable2 = new HashTable(keyDType, valueDType);
resourceManager.addHashTable(node.name, hashTable2);
return [2, [hashTable2.handle]];
}
case 2:
// Import: find the table by its handle's id and replace its contents.
handle = getParamValue("tableHandle", node, tensorMap, context, resourceManager);
keys = getParamValue("keys", node, tensorMap, context);
values = getParamValue("values", node, tensorMap, context);
hashTable2 = resourceManager.getHashTableById(handle.id);
return [4, hashTable2.import(keys, values)];
case 3:
return [2, [_b.sent()]];
case 4:
// Find: look up each key, falling back to defaultValue for missing keys.
handle = getParamValue("tableHandle", node, tensorMap, context, resourceManager);
keys = getParamValue("keys", node, tensorMap, context);
defaultValue = getParamValue("defaultValue", node, tensorMap, context);
hashTable2 = resourceManager.getHashTableById(handle.id);
return [4, hashTable2.find(keys, defaultValue)];
case 5:
return [2, [_b.sent()]];
case 6:
throw TypeError("Node type " + node.op + " is not implemented");
}
});
});
};
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Executes image graph nodes: ResizeBilinear, ResizeNearestNeighbor and
 * CropAndResize.
 *
 * Only the first two entries of the `size` attribute are forwarded to the
 * resize ops.
 *
 * @param {Object} node - Graph node; `node.op` selects the operation.
 * @param {Object} tensorMap - Tensors computed so far.
 * @param {Object} context - Execution context passed through to getParamValue.
 * @returns {Array} A one-element list holding the resulting tensor.
 * @throws {TypeError} If `node.op` is not handled by this executor.
 */
var executeOp$9 = function(node, tensorMap, context) {
  // Reads the named parameters in the order given.
  const read = (...names) => names.map((name) => getParamValue(name, node, tensorMap, context));
  const op = node.op;
  if (op === "ResizeBilinear" || op === "ResizeNearestNeighbor") {
    const [images, size, alignCorners] = read("images", "size", "alignCorners");
    const resizeMethod = op === "ResizeBilinear" ? "resizeBilinear" : "resizeNearestNeighbor";
    return [tfOps.image[resizeMethod](images, [size[0], size[1]], alignCorners)];
  }
  if (op === "CropAndResize") {
    return [tfOps.image.cropAndResize(...read("image", "boxes", "boxInd", "cropSize", "method", "extrapolationValue"))];
  }
  throw TypeError("Node type " + node.op + " is not implemented");
};
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Executes comparison and logical graph nodes: Equal, NotEqual, Greater,
 * GreaterEqual, Less, LessEqual, LogicalAnd, LogicalOr, LogicalNot and
 * Select/SelectV2 (mapped to tfOps.where).
 *
 * @param {Object} node - Graph node; `node.op` selects the operation.
 * @param {Object} tensorMap - Tensors computed so far.
 * @param {Object} context - Execution context passed through to getParamValue.
 * @returns {Array} A one-element list holding the resulting tensor.
 * @throws {TypeError} If `node.op` is not handled by this executor.
 */
var executeOp$a = (function() {
  // Graph op name -> tfOps method applied to the node's "a" and "b" inputs.
  const BINARY_OPS = {
    Equal: "equal",
    NotEqual: "notEqual",
    Greater: "greater",
    GreaterEqual: "greaterEqual",
    Less: "less",
    LessEqual: "lessEqual",
    LogicalAnd: "logicalAnd",
    LogicalOr: "logicalOr"
  };
  const hasOwn = Object.prototype.hasOwnProperty;
  return function(node, tensorMap, context) {
    const param = (name) => getParamValue(name, node, tensorMap, context);
    const op = node.op;
    // Own-property check so inherited names such as "toString" are not treated as ops.
    if (typeof op === "string" && hasOwn.call(BINARY_OPS, op)) {
      const a = param("a");
      const b = param("b");
      return [tfOps[BINARY_OPS[op]](a, b)];
    }
    if (op === "LogicalNot") {
      return [tfOps.logicalNot(param("a"))];
    }
    if (op === "Select" || op === "SelectV2") {
      const condition = param("condition");
      const a = param("a");
      const b = param("b");
      return [tfOps.where(condition, a, b)];
    }
    throw TypeError("Node type " + node.op + " is not implemented");
  };
})();
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Executes matrix graph nodes: BatchMatMul/BatchMatMulV2/MatMul, Transpose
 * and _FusedMatMul.
 *
 * For `_FusedMatMul`, the `fusedOps` attribute is read as
 * `[extraOp, activationFunc]`. With `biasadd`, `numArgs` must be 2 for a
 * `prelu` activation and 1 otherwise.
 *
 * @param {Object} node - Graph node; `node.op` selects the operation.
 * @param {Object} tensorMap - Tensors computed so far.
 * @param {Object} context - Execution context passed through to getParamValue.
 * @returns {Array} A one-element list holding the resulting tensor.
 * @throws {TypeError} If `node.op` is not handled by this executor.
 * @throws {Error} If a fused BiasAdd has the wrong number of extra arguments.
 */
var executeOp$b = function(node, tensorMap, context) {
  const param = (name) => getParamValue(name, node, tensorMap, context);
  const op = node.op;
  if (op === "BatchMatMul" || op === "BatchMatMulV2" || op === "MatMul") {
    const a = param("a");
    const b = param("b");
    const transposeA = param("transposeA");
    const transposeB = param("transposeB");
    return [tfOps.matMul(a, b, transposeA, transposeB)];
  }
  if (op === "Transpose") {
    const x = param("x");
    const perm = param("perm");
    return [tfOps.transpose(x, perm)];
  }
  if (op !== "_FusedMatMul") {
    throw TypeError("Node type " + node.op + " is not implemented");
  }
  const [extraOp, activationFunc] = __read(param("fusedOps"), 2);
  const numArgs = param("numArgs");
  if (extraOp === "biasadd") {
    const usesPrelu = activationFunc === "prelu";
    // PReLU needs its alpha tensor on top of the bias.
    if (numArgs !== (usesPrelu ? 2 : 1)) {
      throw new Error(usesPrelu
        ? "Fused MatMul with BiasAdd and Prelu must have two extra arguments: bias and alpha."
        : "Fused MatMul with BiasAdd must have one extra argument: bias.");
    }
  }
  const [biasArg, preluArg] = __read(param("args"), 2);
  return [tfOps.fused.matMul({
    a: param("a"),
    b: param("b"),
    transposeA: param("transposeA"),
    transposeB: param("transposeB"),
    bias: biasArg,
    activation: activationFunc,
    preluActivationWeights: preluArg
  })];
};
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Executes normalization graph nodes: FusedBatchNorm (V1/V2/V3), LRN,
 * Softmax, LogSoftmax and SparseToDense.
 *
 * @param {Object} node - Graph node; `node.op` selects the operation.
 * @param {Object} tensorMap - Tensors computed so far.
 * @param {Object} context - Execution context passed through to getParamValue.
 * @returns {Array} A one-element list holding the resulting tensor.
 * @throws {TypeError} If `node.op` is not handled by this executor.
 */
var executeOp$c = function(node, tensorMap, context) {
  // Reads the named parameters in the order given; spread straight into tfOps calls.
  const read = (...names) => names.map((name) => getParamValue(name, node, tensorMap, context));
  switch (node.op) {
    case "FusedBatchNorm":
    case "FusedBatchNormV2":
    case "FusedBatchNormV3":
      return [tfOps.batchNorm(...read("x", "mean", "variance", "offset", "scale", "epsilon"))];
    case "LRN":
      return [tfOps.localResponseNormalization(...read("x", "radius", "bias", "alpha", "beta"))];
    case "Softmax":
      return [tfOps.softmax(...read("x"))];
    case "LogSoftmax":
      return [tfOps.logSoftmax(...read("x"))];
    case "SparseToDense":
      return [tfOps.sparseToDense(...read("sparseIndices", "outputShape", "sparseValues", "defaultValue"))];
    default:
      throw TypeError("Node type " + node.op + " is not implemented");
  }
};
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Executes "reduction" category graph ops.
 * Parameters are resolved via getParamValue in the order axis, (keepDims | exclusive, reverse), x.
 * @returns a single-element array holding the op result.
 * @throws TypeError for an op name this executor does not implement.
 */
var executeOp$d = function(node, tensorMap, context) {
  const param = (paramName) => getParamValue(paramName, node, tensorMap, context);
  // Reductions with signature (x, axis, keepDims): graph op name -> tfOps method name.
  const keepDimsReducers = {
    Max: "max",
    Mean: "mean",
    Min: "min",
    Sum: "sum",
    All: "all",
    Any: "any",
    Prod: "prod"
  };
  const opName = node.op;
  if (Object.prototype.hasOwnProperty.call(keepDimsReducers, opName)) {
    const reduceAxis = param("axis");
    const keep = param("keepDims");
    return [tfOps[keepDimsReducers[opName]](param("x"), reduceAxis, keep)];
  }
  if (opName === "ArgMax" || opName === "ArgMin") {
    const argAxis = param("axis");
    const method = opName === "ArgMax" ? "argMax" : "argMin";
    return [tfOps[method](param("x"), argAxis)];
  }
  if (opName === "Cumsum") {
    const cumAxis = param("axis");
    const isExclusive = param("exclusive");
    const isReverse = param("reverse");
    return [tfOps.cumsum(param("x"), cumAxis, isExclusive, isReverse)];
  }
  throw TypeError("Node type " + node.op + " is not implemented");
};
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Executes "slice_join" category graph ops (concat, gather, reverse, slicing, pack/unpack, tile, split,
 * scatter/gather-ND, sparse-to-dense). Parameters are resolved via getParamValue.
 * Most ops return a single-element array; Unpack and Split/SplitV return the tensor list from tfOps directly.
 * @throws TypeError for an op name this executor does not implement.
 */
var executeOp$e = function(node, tensorMap, context) {
  const param = (paramName) => getParamValue(paramName, node, tensorMap, context);
  switch (node.op) {
    case "ConcatV2":
    case "Concat": {
      const count = param("n");
      const concatAxis = param("axis");
      // Only the first `n` entries of "tensors" take part in the concatenation.
      const pieces = param("tensors").slice(0, count);
      return [tfOps.concat(pieces, concatAxis)];
    }
    case "GatherV2":
    case "Gather": {
      const gatherAxis = param("axis");
      const source = param("x");
      const idx = param("indices");
      return [tfOps.gather(source, tfOps.cast(idx, "int32"), gatherAxis)];
    }
    case "ReverseV2":
    case "Reverse": {
      const reverseAxis = param("axis");
      const source = param("x");
      return [tfOps.reverse(source, reverseAxis)];
    }
    case "Slice": {
      const start = param("begin");
      const extent = param("size");
      return [tfOps.slice(param("x"), start, extent)];
    }
    case "StridedSlice": {
      const [begin, end, strides, beginMask, endMask, ellipsisMask, newAxisMask, shrinkAxisMask, source] = [
        "begin",
        "end",
        "strides",
        "beginMask",
        "endMask",
        "ellipsisMask",
        "newAxisMask",
        "shrinkAxisMask",
        "x"
      ].map((paramName) => param(paramName));
      return [tfOps.stridedSlice(source, begin, end, strides, beginMask, endMask, ellipsisMask, newAxisMask, shrinkAxisMask)];
    }
    case "Pack": {
      return tfOps.tidy(() => {
        const stackAxis = param("axis");
        const parts = param("tensors");
        const refShape = parts[0].shape;
        const refSqueezed = tfOps.squeeze(parts[0]).shape;
        // Tensors whose shape differs from the first one are accepted only if they squeeze to the same shape;
        // they are then reshaped to the first tensor's shape before stacking.
        const aligned = parts.map((part) => {
          if (tfOps.util.arraysEqual(part.shape, refShape)) {
            return part;
          }
          if (!tfOps.util.arraysEqual(tfOps.squeeze(part).shape, refSqueezed)) {
            throw new Error("the input tensors shape does not match");
          }
          return tfOps.reshape(part, refShape);
        });
        return [tfOps.stack(aligned, stackAxis)];
      });
    }
    case "Unpack": {
      const unstackAxis = param("axis");
      const packed = param("tensor");
      return tfOps.unstack(packed, unstackAxis);
    }
    case "Tile": {
      const reps = param("reps");
      return [tfOps.tile(param("x"), reps)];
    }
    case "Split":
    case "SplitV": {
      const splitAxis = param("axis");
      const splits = param("numOrSizeSplits");
      const source = param("x");
      return tfOps.split(source, splits, splitAxis);
    }
    case "ScatterNd": {
      const idx = param("indices");
      const updates = param("values");
      const targetShape = param("shape");
      return [tfOps.scatterND(idx, updates, targetShape)];
    }
    case "GatherNd": {
      const source = param("x");
      const idx = param("indices");
      return [tfOps.gatherND(source, idx)];
    }
    case "SparseToDense": {
      const idx = param("sparseIndices");
      const denseShape = param("outputShape");
      const values = param("sparseValues");
      const fill = param("defaultValue");
      // The default value is cast to the sparse values' dtype when the two differ.
      const fillValue = values.dtype === fill.dtype ? fill : tfOps.cast(fill, values.dtype);
      return [tfOps.sparseToDense(idx, values, denseShape, fillValue)];
    }
    default:
      throw TypeError("Node type " + node.op + " is not implemented");
  }
};
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Executes "spectral" category graph ops (FFT, IFFT, RFFT, IRFFT), each of which takes only the "x" input.
 * @returns a single-element array holding the transform result.
 * @throws TypeError for an op name this executor does not implement.
 */
var executeOp$f = function(node, tensorMap, context) {
  // Graph op name -> tfOps method name.
  const spectralOps = {FFT: "fft", IFFT: "ifft", RFFT: "rfft", IRFFT: "irfft"};
  if (!Object.prototype.hasOwnProperty.call(spectralOps, node.op)) {
    throw TypeError("Node type " + node.op + " is not implemented");
  }
  const method = spectralOps[node.op];
  return [tfOps[method](getParamValue("x", node, tensorMap, context))];
};
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Executes "transformation" category graph ops (cast, expand/squeeze dims, reshape, padding,
 * space/batch/depth rearrangements, broadcast). Parameters are resolved via getParamValue.
 * @returns a single-element array holding the op result.
 * @throws TypeError for an op name this executor does not implement.
 */
var executeOp$g = function(node, tensorMap, context) {
  const param = (paramName) => getParamValue(paramName, node, tensorMap, context);
  const op = node.op;
  switch (op) {
    case "Cast":
      return [tfOps.cast(param("x"), param("dtype"))];
    case "ExpandDims":
    case "Squeeze": {
      const dims = param("axis");
      const method = op === "ExpandDims" ? "expandDims" : "squeeze";
      return [tfOps[method](param("x"), dims)];
    }
    case "Reshape":
      return [tfOps.reshape(param("x"), param("shape"))];
    case "MirrorPad":
      return [tfOps.mirrorPad(param("x"), param("padding"), param("mode"))];
    case "PadV2":
    case "Pad":
      return [tfOps.pad(param("x"), param("padding"), param("constantValue"))];
    case "SpaceToBatchND": {
      const block = param("blockShape");
      const pads = param("paddings");
      return [tfOps.spaceToBatchND(param("x"), block, pads)];
    }
    case "BatchToSpaceND": {
      const block = param("blockShape");
      const cropAmounts = param("crops");
      return [tfOps.batchToSpaceND(param("x"), block, cropAmounts)];
    }
    case "DepthToSpace": {
      const size = param("blockSize");
      // tfOps.depthToSpace receives the data format upper-cased (e.g. "nhwc" -> "NHWC").
      const format = param("dataFormat").toUpperCase();
      return [tfOps.depthToSpace(param("x"), size, format)];
    }
    case "BroadcastTo":
      return [tfOps.broadcastTo(param("x"), param("shape"))];
    default:
      throw TypeError("Node type " + node.op + " is not implemented");
  }
};
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Runs a graph node by dispatching on `node.category` to the matching per-category executor.
 * Most categories execute inside tfOps.tidy. The "control", "dynamic", "hash_table" and "custom" categories
 * are called directly; "hash_table" additionally receives `resourceManager`, and "custom" uses the executor
 * registered for `node.op` (via getRegisteredOp).
 * @returns the executor's result normalized to an array with [].concat; if the executor returned a promise,
 *   a promise of that normalized array.
 * @throws TypeError if a custom op is not registered or the category is unknown.
 */
function executeOp$h(node, tensorMap, context, resourceManager) {
  const runTidy = (executor) => tfOps.tidy(() => executor(node, tensorMap, context));
  let value;
  switch (node.category) {
    case "arithmetic":
      value = runTidy(executeOp);
      break;
    case "basic_math":
      value = runTidy(executeOp$1);
      break;
    case "control":
      value = executeOp$2(node, tensorMap, context);
      break;
    case "convolution":
      value = runTidy(executeOp$3);
      break;
    case "creation":
      value = runTidy(executeOp$4);
      break;
    case "dynamic":
      value = executeOp$5(node, tensorMap, context);
      break;
    case "evaluation":
      value = runTidy(executeOp$6);
      break;
    case "image":
      value = runTidy(executeOp$9);
      break;
    case "graph":
      value = runTidy(executeOp$7);
      break;
    case "logical":
      value = runTidy(executeOp$a);
      break;
    case "matrices":
      value = runTidy(executeOp$b);
      break;
    case "normalization":
      value = runTidy(executeOp$c);
      break;
    case "reduction":
      value = runTidy(executeOp$d);
      break;
    case "slice_join":
      value = runTidy(executeOp$e);
      break;
    case "spectral":
      value = runTidy(executeOp$f);
      break;
    case "transformation":
      value = runTidy(executeOp$g);
      break;
    case "hash_table":
      value = executeOp$8(node, tensorMap, context, resourceManager);
      break;
    case "custom": {
      const opMapper = getRegisteredOp(node.op);
      if (!(opMapper && opMapper.customExecutor)) {
        throw TypeError("Custom op " + node.op + " is not registered.");
      }
      value = opMapper.customExecutor(new NodeValueImpl(node, tensorMap, context));
      break;
    }
    default:
      throw TypeError("Unknown op '" + node.op + "'. File an issue at https://github.com/tensorflow/tfjs/issues so we can add it, or register a custom execution with tf.registerOp()");
  }
  if (tfOps.util.isPromise(value)) {
    return value.then((data) => [].concat(data));
  }
  return [].concat(value);
}
var ExecutionContext = function() {
  /**
   * Per-run execution state for a graph.
   * Holds lookup maps for weights, tensor arrays, tensor lists and function executors.
   * Also holds the stack of control-flow frames (root frame first) and a cached list of context ids derived
   * from that stack (longest prefix first, always ending with "").
   * Any map argument passed as undefined defaults to a fresh empty object.
   */
  function ExecutionContext2(weightMap, tensorArrayMap, tensorListMap, functionMap) {
    this.weightMap = weightMap === void 0 ? {} : weightMap;
    this.tensorArrayMap = tensorArrayMap === void 0 ? {} : tensorArrayMap;
    this.tensorListMap = tensorListMap === void 0 ? {} : tensorListMap;
    this.functionMap = functionMap === void 0 ? {} : functionMap;
    this.rootContext = {id: 0, frameName: "", iterationId: 0};
    this.contexts = [this.rootContext];
    this.lastId = 0;
    this.generateCurrentContextIds();
  }
  const proto = ExecutionContext2.prototype;
  /** Creates the record for a newly entered frame, starting at iteration 0. */
  proto.newFrame = function(id, frameName) {
    return {id, frameName, iterationId: 0};
  };
  Object.defineProperties(proto, {
    // The frame stack. Assigning a different array regenerates the cached context ids.
    currentContext: {
      get() {
        return this.contexts;
      },
      set(contexts) {
        if (this.contexts === contexts) {
          return;
        }
        this.contexts = contexts;
        this.generateCurrentContextIds();
      },
      enumerable: true,
      configurable: true
    },
    // First entry of the cached id list (the id of the whole frame stack).
    currentContextId: {
      get() {
        return this._currentContextIds[0];
      },
      enumerable: true,
      configurable: true
    },
    // The cached id list itself (same array instance that enterFrame/exitFrame/nextIteration mutate).
    currentContextIds: {
      get() {
        return this._currentContextIds;
      },
      enumerable: true,
      configurable: true
    }
  });
  /** Rebuilds `_currentContextIds`: ids of the frame-stack prefixes of length >= 2, longest first, then "". */
  proto.generateCurrentContextIds = function() {
    const ids = [];
    for (let end = this.contexts.length; end > 1; end--) {
      ids.push(this.contextIdforContexts(this.contexts.slice(0, end)));
    }
    ids.push("");
    this._currentContextIds = ids;
  };
  /**
   * Joins the frames into a "/"-separated id. A frame with id 0 and iteration 0 contributes "";
   * every other frame contributes "<frameName>-<iterationId>". A falsy `contexts` yields "".
   */
  proto.contextIdforContexts = function(contexts) {
    if (!contexts) {
      return "";
    }
    return contexts.map(function(frame) {
      const isPristineRoot = frame.id === 0 && frame.iterationId === 0;
      return isPristineRoot ? "" : frame.frameName + "-" + frame.iterationId;
    }).join("/");
  };
  /** Pushes a new frame named `frameId` (copy-on-write of the stack) and prepends its id. */
  proto.enterFrame = function(frameId) {
    if (!this.contexts) {
      return;
    }
    this.lastId++;
    const entered = this.newFrame(this.lastId, frameId);
    this.contexts = this.contexts.concat([entered]);
    this._currentContextIds.unshift(this.contextIdforContexts(this.contexts));
  };
  /** Pops the innermost frame (copy-on-write) and drops its id; the root frame can never be popped. */
  proto.exitFrame = function() {
    if (!(this.contexts && this.contexts.length > 1)) {
      throw new Error("Cannot exit frame, the context is empty");
    }
    this.contexts = this.contexts.slice(0, -1);
    this.currentContextIds.shift();
  };
  /** Replaces the innermost frame with a copy whose iterationId is incremented and whose id is fresh. */
  proto.nextIteration = function() {
    if (!(this.contexts && this.contexts.length > 0)) {
      throw new Error("Cannot increase frame iteration, the context is empty");
    }
    this.contexts = this.contexts.slice();
    this.lastId++;
    const lastIndex = this.contexts.length - 1;
    const advanced = Object.assign({}, this.contexts[lastIndex]);
    advanced.iterationId += 1;
    advanced.id = this.lastId;
    this.contexts[lastIndex] = advanced;
    this._currentContextIds[0] = this.contextIdforContexts(this.contexts);
  };
  proto.getWeight = function(name) {
    return this.weightMap[name];
  };
  proto.addTensorArray = function(tensorArray) {
    this.tensorArrayMap[tensorArray.id] = tensorArray;
  };
  proto.getTensorArray = function(id) {
    return this.tensorArrayMap[id];
  };
  proto.addTensorList = function(tensorList) {
    this.tensorListMap[tensorList.id] = tensorList;
  };
  proto.getTensorList = function(id) {
    return this.tensorListMap[id];
  };
  /** Calls clearAndClose(keepIds) on every registered tensor array, then on every tensor list. */
  proto.dispose = function(keepIds) {
    for (const arrayKey in this.tensorArrayMap) {
      this.tensorArrayMap[arrayKey].clearAndClose(keepIds);
    }
    for (const listKey in this.tensorListMap) {
      this.tensorListMap[listKey].clearAndClose(keepIds);
    }
  };
  return ExecutionContext2;
}();
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Walks the graph backwards from `outputs` and collects every node needed to compute them.
 * The walk stops at weights (entries of `weightMap`), provided inputs (keys of `inputs`, parsed with
 * parseNodeName) and init nodes.
 * @returns {inputs, outputs, usedNodes, missingInputs, dynamicNode, syncInputs}:
 *   - usedNodes: names of visited nodes.
 *   - missingInputs: visited nodes that have no inputs and are not weights/inputs/init nodes.
 *   - dynamicNode: the first visited control-flow/dynamic-shape/hash-table node (or null).
 *   - syncInputs: names of that node's children already visited at the time it was found (or null).
 */
function getExecutionSubgraph(inputs, outputs, weightMap, initNodes) {
  const usedNodes = new Set();
  const missingInputs = [];
  let dynamicNode = null;
  let syncInputs = null;
  const seen = new Set();
  const toNodeName = (name) => parseNodeName(name)[0];
  const inputNodeNames = Object.keys(inputs).map((name) => toNodeName(name));
  const initNodeNames = initNodes == null ? [] : initNodes.map((initNode) => toNodeName(initNode.name));
  const frontier = __spread(outputs);
  while (frontier.length > 0) {
    const node = frontier.pop();
    const needsAsync = isControlFlow(node) || isDynamicShape(node) || isHashTable(node);
    if (needsAsync && dynamicNode == null) {
      dynamicNode = node;
      syncInputs = node.children.map((child) => child.name).filter((name) => usedNodes.has(name));
    }
    usedNodes.add(node.name);
    const isBoundary = weightMap[node.name] != null ||
      inputNodeNames.indexOf(node.name) !== -1 ||
      initNodeNames.indexOf(node.name) !== -1;
    if (isBoundary) {
      continue;
    }
    if (node.inputs.length === 0) {
      missingInputs.push(node.name);
      continue;
    }
    node.inputs.forEach((input) => {
      if (!seen.has(input.name)) {
        seen.add(input.name);
        frontier.push(input);
      }
    });
  }
  return {inputs, outputs, usedNodes, missingInputs, dynamicNode, syncInputs};
}
/**
 * Orders the nodes recorded in `executionInfo.usedNodes` so each node comes after all of its inputs.
 * Traversal starts from the used input nodes, graph weights and init nodes. A child is scheduled once all of
 * its inputs have been visited. Nodes present in `weightMap` are visited but left out of the result.
 * @returns the ordered list of graph nodes to execute.
 */
function getNodesInTopologicalOrder(graph2, weightMap, executionInfo) {
  const {usedNodes, inputs} = executionInfo;
  const isUsed = (node) => usedNodes.has(node.name);
  const inputNodes = Object.keys(inputs)
    .map((name) => parseNodeName(name)[0])
    .map((name) => graph2.nodes[name]);
  const initNodes = graph2.initNodes;
  const frontier = inputNodes.filter(isUsed).concat(graph2.weights.filter(isUsed));
  if (initNodes != null) {
    initNodes.filter(isUsed).forEach((initNode) => frontier.push(initNode));
  }
  const visited = new Set();
  const ordered = [];
  while (frontier.length > 0) {
    const current = frontier.pop();
    visited.add(current.name);
    if (!weightMap[current.name]) {
      ordered.push(current);
    }
    current.children.forEach((child) => {
      const isReady = !visited.has(child.name) &&
        usedNodes.has(child.name) &&
        child.inputs.every((input) => visited.has(input.name));
      if (isReady) {
        frontier.push(child);
      }
    });
  }
  return ordered;
}
/**
 * Control-flow op names. getExecutionSubgraph reports the first such node as `dynamicNode`, and compile()
 * then rejects synchronous execution, asking for model.executeAsync().
 *
 * "If" is TensorFlow's stateful conditional, the counterpart of "StatelessIf". This list previously
 * contained only a lowercase "if", which never matches an "If" node (isControlFlow compares op names
 * exactly). "If" is therefore added; "if" is kept so nothing relying on the old entry changes.
 */
var CONTROL_FLOW_OPS = [
  "Switch",
  "Merge",
  "Enter",
  "Exit",
  "NextIteration",
  "StatelessIf",
  "StatelessWhile",
  "if",
  "If",
  "While"
];
/**
 * Ops that getExecutionSubgraph also treats as requiring model.executeAsync().
 * Presumably this is because their output shape depends on tensor values.
 */
var DYNAMIC_SHAPE_OPS = [
  "NonMaxSuppressionV2",
  "NonMaxSuppressionV3",
  "NonMaxSuppressionV5",
  "Where"
];
/** Hash-table ops; getExecutionSubgraph also treats these as requiring model.executeAsync(). */
var HASH_TABLE_OPS = [
  "HashTable",
  "HashTableV2",
  "LookupTableImport",
  "LookupTableImportV2",
  "LookupTableFind",
  "LookupTableFindV2"
];
/** True when `node.op` is listed in CONTROL_FLOW_OPS. */
function isControlFlow(node) {
  return CONTROL_FLOW_OPS.includes(node.op);
}
/** True when `node.op` is listed in DYNAMIC_SHAPE_OPS. */
function isDynamicShape(node) {
  return DYNAMIC_SHAPE_OPS.includes(node.op);
}
/** True when `node.op` is listed in HASH_TABLE_OPS. */
function isHashTable(node) {
  return HASH_TABLE_OPS.includes(node.op);
}
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var GraphExecutor = function() {
/**
 * Executes a converted graph.
 * @param graph2 graph description; reads outputs, inputs, initNodes, signature and, when present, functions.
 * @param parent executor of the enclosing graph when this instance runs a function sub-graph. The
 *   weightIds / weightMap / functionExecutorMap accessors read through the parent in that case.
 * One child executor (parented to this one) is created per entry of `graph2.functions`.
 */
function GraphExecutor2(graph2, parent) {
  this.graph = graph2;
  this.parent = parent;
  // Compilation key -> topologically ordered nodes, filled lazily by execute().
  this.compiledMap = new Map();
  this._weightMap = {};
  this.SEPERATOR = ",";
  // Keep an empty library when graph2.functions is null/undefined. The `functions` accessor calls
  // Object.keys(this._functions), which would otherwise throw for graphs without functions.
  this._functions = {};
  this._functionExecutorMap = {};
  this._outputs = graph2.outputs;
  this._inputs = graph2.inputs;
  this._initNodes = graph2.initNodes;
  this._signature = graph2.signature;
  if (graph2.functions != null) {
    this._functions = graph2.functions;
    Object.keys(graph2.functions).forEach((name) => {
      this._functionExecutorMap[name] = new GraphExecutor2(graph2.functions[name], this);
    });
  }
}
// Accessors on GraphExecutor2. An executor with a `parent` (a function sub-graph executor) reads weights,
// weight ids and function executors from its parent instead of its own fields.
Object.defineProperties(GraphExecutor2.prototype, {
  weightIds: {
    get() {
      return this.parent ? this.parent.weightIds : this._weightIds;
    },
    enumerable: true,
    configurable: true
  },
  functionExecutorMap: {
    get() {
      return this.parent ? this.parent.functionExecutorMap : this._functionExecutorMap;
    },
    enumerable: true,
    configurable: true
  },
  weightMap: {
    get() {
      return this.parent ? this.parent.weightMap : this._weightMap;
    },
    // Stores the map and caches the flattened ids of all weight tensors in `_weightIds`.
    set(weightMap) {
      const idsPerWeight = Object.keys(weightMap).map((key) => weightMap[key].map((tensor) => tensor.id));
      this._weightIds = [].concat(...idsPerWeight);
      this._weightMap = weightMap;
    },
    enumerable: true,
    configurable: true
  },
  // Write-only: stores the resource manager later handed to hash-table ops.
  resourceManager: {
    set(resourceManager) {
      this._resourceManager = resourceManager;
    },
    enumerable: true,
    configurable: true
  },
  // {name, shape, dtype} for each graph input; shape/dtype are undefined when the attr is absent.
  inputs: {
    get() {
      return this._inputs.map((node) => {
        const shapeAttr = node.attrParams["shape"];
        const dtypeAttr = node.attrParams["dtype"];
        return {
          name: node.name,
          shape: shapeAttr ? shapeAttr.value : void 0,
          dtype: dtypeAttr ? dtypeAttr.value : void 0
        };
      });
    },
    enumerable: true,
    configurable: true
  },
  // {name, shape, dtype} for each graph output; shape/dtype are undefined when the attr is absent.
  outputs: {
    get() {
      return this._outputs.map((node) => {
        const shapeAttr = node.attrParams["shape"];
        const dtypeAttr = node.attrParams["dtype"];
        return {
          name: node.name,
          shape: shapeAttr ? shapeAttr.value : void 0,
          dtype: dtypeAttr ? dtypeAttr.value : void 0
        };
      });
    },
    enumerable: true,
    configurable: true
  },
  // Input names, preferring the signature key over the node name.
  inputNodes: {
    get() {
      return this._inputs.map((node) => node.signatureKey || node.name);
    },
    enumerable: true,
    configurable: true
  },
  // Output names (signature key preferred), suffixed ":<defaultOutput>" when defaultOutput is truthy.
  outputNodes: {
    get() {
      return this._outputs.map((node) => {
        const key = node.signatureKey || node.name;
        return node.defaultOutput ? `${key}:${node.defaultOutput}` : key;
      });
    },
    enumerable: true,
    configurable: true
  },
  // Map of function name -> that function graph's signature.
  functions: {
    get() {
      const signatures = {};
      for (const key of Object.keys(this._functions)) {
        signatures[key] = this._functions[key].signature;
      }
      return signatures;
    },
    enumerable: true,
    configurable: true
  }
});
/**
 * Builds the cache key for compiledMap: sorted input node names and sorted output node names, each joined
 * by this.SEPERATOR, with "--" between the two groups.
 */
GraphExecutor2.prototype.getCompilationKey = function(inputs, outputs) {
  const joinSortedNames = (nodes) => nodes.map((node) => node.name).sort().join(this.SEPERATOR);
  return joinSortedNames(inputs) + "--" + joinSortedNames(outputs);
};
/**
 * Computes the execution order for the synchronous execute() path.
 * @param inputs map of input tensor names to tensors (only its keys are used to locate the feeds).
 * @param outputs output graph nodes to compute.
 * @returns nodes in execution order (see getNodesInTopologicalOrder).
 * @throws Error if the required subgraph contains a control-flow, dynamic-shape or hash-table node
 *   (those need model.executeAsync()).
 * @throws Error if some required node has no inputs and is not a weight, a provided input or an init node.
 */
GraphExecutor2.prototype.compile = function(inputs, outputs) {
  const executionInfo = getExecutionSubgraph(inputs, outputs, this.weightMap, this._initNodes);
  const {missingInputs, dynamicNode, syncInputs} = executionInfo;
  if (dynamicNode != null) {
    throw new Error(`This execution contains the node '${dynamicNode.name}', which has the dynamic op '${dynamicNode.op}'. Please use model.executeAsync() instead. Alternatively, to avoid the dynamic ops, specify the inputs [${syncInputs}]`);
  }
  if (missingInputs.length > 0) {
    const outNames = outputs.map((outputNode) => outputNode.name);
    const inNames = Object.keys(inputs);
    throw new Error(`Cannot compute the outputs [${outNames}] from the provided inputs [${inNames}]. Missing the following inputs: [${missingInputs}]`);
  }
  return getNodesInTopologicalOrder(this.graph, this.weightMap, executionInfo);
};
/**
 * Synchronously executes the graph and returns the tensors for the requested outputs.
 * Steps:
 * 1. Map and validate inputs/outputs (mapInputs, checkInputs, checkInputShapeAndType, mapOutputs, checkOutputs).
 * 2. Obtain the node execution order, cached in compiledMap under getCompilationKey(inputNodes, outputNodes).
 * 3. Run every node in that order inside tfOps.tidy.
 * If no output names are given, the graph's own outputs (this._outputs) are used to compile.
 * @throws Error from compile() (dynamic ops / missing inputs), or if any op's result is a promise
 *   (use model.executeAsync() for such graphs).
 */
GraphExecutor2.prototype.execute = function(inputs, outputs) {
var _this2 = this;
inputs = this.mapInputs(inputs);
// Sorted input names; used to build inputNodes for the compilation key.
var names = Object.keys(inputs).sort();
this.checkInputs(inputs);
this.checkInputShapeAndType(inputs);
outputs = this.mapOutputs(outputs);
this.checkOutputs(outputs);
var inputNodes = names.map(function(name) {
return _this2.graph.nodes[parseNodeName(name)[0]];
});
var outputNodeNames = outputs.map(function(name) {
return parseNodeName(name)[0];
});
var outputNodes = outputNodeNames.map(function(name) {
return _this2.graph.nodes[name];
});
if (outputNodes.length === 0) {
outputNodes = this._outputs;
}
// Reuse a previously computed node order for the same input/output node sets.
var compilationKey = this.getCompilationKey(inputNodes, outputNodes);
var orderedNodes = this.compiledMap.get(compilationKey);
if (orderedNodes == null) {
orderedNodes = this.compile(inputs, outputNodes);
this.compiledMap.set(compilationKey, orderedNodes);
}
var tensorArrayMap = {};
var tensorListMap = {};
return tfOps.tidy(function() {
var context = new ExecutionContext(_this2.weightMap, tensorArrayMap, tensorListMap, _this2.functionExecutorMap);
// Node name -> array of output tensors; seeded with the weights, then with each feed placed at the
// output index parsed from its name.
var tensorsMap = __assign({}, _this2.weightMap);
Object.keys(inputs).forEach(function(name) {
var _a = __read(parseNodeName(name), 2), nodeName = _a[0], index = _a[1];
var tensors2 = [];
tensors2[index] = inputs[name];
tensorsMap[nodeName] = tensors2;
});
// Ids of weight and input tensors; checkTensorForDisposal and context.dispose never release these.
var tensorsToKeep = _this2.getFrozenTensorIds(tensorsMap);
// Tensor id -> remaining number of consumers (maintained by checkTensorForDisposal).
var intermediateTensorConsumerCount = {};
for (var i = 0; i < orderedNodes.length; i++) {
var node = orderedNodes[i];
// Nodes that already have tensors (weights, feeds) are not re-executed.
if (!tensorsMap[node.name]) {
var tensors = executeOp$h(node, tensorsMap, context, _this2._resourceManager);
if (tfOps.util.isPromise(tensors)) {
throw new Error("The execution of the op '" + node.op + "' returned a promise. Please use model.executeAsync() instead.");
}
tensorsMap[node.name] = tensors;
_this2.checkTensorForDisposal(node.name, node, tensorsMap, context, tensorsToKeep, outputNodeNames, intermediateTensorConsumerCount);
}
}
// Only the top-level executor (no parent) clears the tensor arrays/lists held by the context.
if (_this2.parent == null) {
context.dispose(tensorsToKeep);
}
return outputs.map(function(name) {
return getTensor(name, tensorsMap, context);
});
});
};
/**
 * Collects the ids of every tensor in `tensorMap` (node name -> array of tensors) into a Set.
 * Insertion order follows the map's keys, then array order.
 * An empty slot in a tensor array contributes `undefined` to the Set.
 */
GraphExecutor2.prototype.getFrozenTensorIds = function(tensorMap) {
  const ids = new Set();
  for (const key of Object.keys(tensorMap)) {
    for (const id of tensorMap[key].map((tensor) => tensor.id)) {
      ids.add(id);
    }
  }
  return ids;
};
/**
 * Reference-counts intermediate tensors after `node` has executed, and disposes inputs that have no
 * remaining consumers.
 * - Nothing is done for control-category nodes or for nodes named in `outputNames`.
 * - Each non-null tensor produced by the node gains `node.children.length` pending consumers.
 * - Each tensor read from a non-control input loses one consumer. It is disposed when its count was 1,
 *   unless its id is in `tensorsToKeep`.
 * @param intermediateTensorConsumerCount tensor id -> pending consumer count; mutated in place.
 */
GraphExecutor2.prototype.checkTensorForDisposal = function(nodeName, node, tensorMap, context, tensorsToKeep, outputNames, intermediateTensorConsumerCount) {
  if (node.category === "control" || outputNames.indexOf(nodeName) !== -1) {
    return;
  }
  const counts = intermediateTensorConsumerCount;
  tensorMap[nodeName].forEach((produced) => {
    if (produced != null) {
      counts[produced.id] = (counts[produced.id] || 0) + node.children.length;
    }
  });
  node.inputs.forEach((input) => {
    if (input.category === "control") {
      return;
    }
    const consumed = getTensorsForCurrentContenxt(input.name, tensorMap, context);
    if (consumed == null) {
      return;
    }
    consumed.forEach((tensor) => {
      if (!tensor || tensorsToKeep.has(tensor.id)) {
        return;
      }
      const remaining = counts[tensor.id];
      if (remaining === 1) {
        tensor.dispose();
        delete counts[tensor.id];
      } else if (remaining != null) {
        counts[tensor.id]--;
      }
    });
  });
};
/**
 * Asynchronous counterpart of execute(). It is needed for graphs containing the control-flow, dynamic-shape
 * or hash-table ops that compile() rejects.
 * Delegates to _executeAsync(inputs, outputs). Synchronous errors surface as a rejected promise.
 * @returns a Promise resolving to the output tensors produced by _executeAsync.
 */
GraphExecutor2.prototype.executeAsync = async function(inputs, outputs) {
  return this._executeAsync(inputs, outputs);
};
/**
 * Shared implementation of executeAsync() and executeFunctionAsync().
 * @param inputs map of input names to tensors.
 * @param outputs output tensor names.
 * @param isFunctionExecution true when called for a function sub-graph (from executeFunctionAsync);
 *   input/output mapping and validation are skipped in that case. Defaults to false.
 * @param tensorArrayMap / tensorListMap shared maps handed to the ExecutionContext (default: new {}).
 * @returns a Promise of the output tensors. After execution, every tensor left in the tensor map is
 *   disposed unless it is an output, an input or a weight. The top-level executor (no parent) also
 *   disposes the context's tensor arrays/lists.
 */
GraphExecutor2.prototype._executeAsync = function(inputs, outputs, isFunctionExecution, tensorArrayMap, tensorListMap) {
if (isFunctionExecution === void 0) {
isFunctionExecution = false;
}
if (tensorArrayMap === void 0) {
tensorArrayMap = {};
}
if (tensorListMap === void 0) {
tensorListMap = {};
}
// tslib-transpiled async body: `case 0` runs until the await; `case 1` resumes with the awaited value.
return __awaiter(this, void 0, void 0, function() {
var context, tensorMap, results, outputIds, inputIds, keepIds;
return __generator(this, function(_a) {
switch (_a.label) {
case 0:
if (!isFunctionExecution) {
inputs = this.mapInputs(inputs);
this.checkInputs(inputs);
this.checkInputShapeAndType(inputs);
outputs = this.mapOutputs(outputs);
this.checkOutputs(outputs);
}
context = new ExecutionContext(this.weightMap, tensorArrayMap, tensorListMap, this.functionExecutorMap);
// [4, promise]: await executeWithControlFlow; its result is read via _a.sent() in case 1.
return [4, this.executeWithControlFlow(inputs, context, outputs, isFunctionExecution)];
case 1:
tensorMap = _a.sent();
results = outputs.map(function(name) {
return getTensor(name, tensorMap, context);
});
// Tensors that must survive cleanup: the results, the caller's inputs and the weights.
outputIds = results.map(function(t) {
return t.id;
});
inputIds = Object.keys(inputs).map(function(name) {
return inputs[name].id;
});
keepIds = new Set(__spread(outputIds, inputIds, this.weightIds));
// Release every other tensor still referenced from the tensor map (skipping already-disposed ones).
Object.keys(tensorMap).forEach(function(key) {
var tensorArray = tensorMap[key];
tensorArray.forEach(function(tensor) {
if (tensor && !tensor.isDisposed && !keepIds.has(tensor.id)) {
tensor.dispose();
}
});
});
if (this.parent == null) {
context.dispose(keepIds);
}
return [2, results];
}
});
});
};
/**
 * Runs this executor as a function sub-graph.
 * Positional `inputs` are keyed by this.inputs[i].name, and execution goes through _executeAsync with
 * isFunctionExecution = true, requesting this.outputNodes.
 * @returns a Promise of the output tensors; synchronous errors surface as a rejected promise.
 */
GraphExecutor2.prototype.executeFunctionAsync = async function(inputs, tensorArrayMap, tensorListMap) {
  const mappedInputs = {};
  inputs.forEach((tensor, index) => {
    mappedInputs[this.inputs[index].name] = tensor;
  });
  return this._executeAsync(mappedInputs, this.outputNodes, true, tensorArrayMap, tensorListMap);
};
// Runs the graph with support for control flow and async ops.
// processStack drains nodes that are ready from a work stack. Any promises it
// returns are awaited, and the stack is drained again, until it is empty.
// @param inputs       object mapping input tensor names to tensors.
// @param context      ExecutionContext shared by every node in this run.
// @param outputNames  requested output names; when empty, this._outputs is used.
// @param isFunctionExecution  suppresses the "no control flow" console warning.
// @returns Promise of tensorsMap (node name -> array of tensors).
// @throws Error when a requested non-control-flow output was not computed.
GraphExecutor2.prototype.executeWithControlFlow = function(inputs, context, outputNames, isFunctionExecution) {
return __awaiter(this, void 0, void 0, function() {
var names, inputNodes, outputNodeNames, outputNodes, _a, usedNodes, missingInputs, dynamicNode, syncInputs, stack, tensorsMap, intermediateTensorConsumerCount, tensorsToKeep, added, promises, missingOutputs, alternativeMsg;
var _this2 = this;
return __generator(this, function(_b) {
switch (_b.label) {
case 0:
names = Object.keys(inputs);
inputNodes = names.map(function(name) {
return _this2.graph.nodes[parseNodeName(name)[0]];
});
outputNodeNames = outputNames.map(function(name) {
return parseNodeName(name)[0];
});
outputNodes = outputNodeNames.map(function(name) {
return _this2.graph.nodes[name];
});
if (outputNodes.length === 0) {
outputNodes = this._outputs;
}
_a = getExecutionSubgraph(inputs, outputNodes, this.weightMap, this._initNodes), usedNodes = _a.usedNodes, missingInputs = _a.missingInputs, dynamicNode = _a.dynamicNode, syncInputs = _a.syncInputs;
// Seed the work stack with the input, weight and init nodes, all in the current context.
stack = __spread(inputNodes, this.graph.weights, this._initNodes || []).map(function(node) {
return {node, contexts: context.currentContext};
});
// Start from a shallow copy of the weights. Then store each input tensor under its
// node name, at the index that parseNodeName extracts from the input's name.
tensorsMap = __assign({}, this.weightMap);
Object.keys(inputs).forEach(function(name) {
var _a2 = __read(parseNodeName(name), 2), nodeName = _a2[0], index = _a2[1];
var tensors = [];
tensors[index] = inputs[name];
tensorsMap[nodeName] = tensors;
});
intermediateTensorConsumerCount = {};
tensorsToKeep = this.getFrozenTensorIds(tensorsMap);
added = {};
_b.label = 1;
case 1:
// Loop head: while the stack is non-empty, drain it and await the async
// results it produced (label 2), then come back here (opcode 3 = jump).
if (!(stack.length > 0))
return [3, 3];
promises = this.processStack(inputNodes, stack, context, tensorsMap, added, tensorsToKeep, outputNodeNames, intermediateTensorConsumerCount, usedNodes);
return [4, Promise.all(promises)];
case 2:
_b.sent();
return [3, 1];
case 3:
if (dynamicNode == null && !isFunctionExecution) {
console.warn("This model execution did not contain any nodes with control flow or dynamic output shapes. You can use model.execute() instead.");
}
// Control-flow nodes are excluded from the "was it computed" check.
missingOutputs = outputNodes.filter(function(node) {
return !isControlFlow(node) && !getTensor(node.name, tensorsMap, context);
}).map(function(node) {
return node.name;
});
if (missingOutputs.length > 0) {
alternativeMsg = "";
if (dynamicNode != null) {
alternativeMsg = "Alternatively, to avoid the dynamic ops, use model.execute() " + ("and specify the inputs [" + syncInputs + "]");
}
throw new Error("Cannot compute the outputs [" + missingOutputs + "] from the provided " + ("inputs [" + names + "]. Consider providing the following inputs: ") + ("[" + missingInputs + "]. " + alternativeMsg));
}
return [2, tensorsMap];
}
});
});
};
/**
 * Pops and executes stack entries until the stack is empty.
 * Each executed node's result is stored in tensorMap under its context-qualified name.
 * Disposal bookkeeping runs for it, and its ready children are pushed onto the stack.
 * Ops that return a promise have that bookkeeping deferred into the promise chain.
 * @returns array of promises for the asynchronous ops started in this pass
 */
GraphExecutor2.prototype.processStack = function(inputNodes, stack, context, tensorMap, added, tensorsToKeep, outputNames, intermediateTensorConsumerCount, usedNodes) {
  var executor = this;
  var pending = [];
  // Handles one popped entry. Being a separate call gives each entry its own
  // variables for the asynchronous callback below to close over.
  function visit(entry) {
    var node = entry.node;
    context.currentContext = entry.contexts;
    var resolvedName = "";
    if (node.op === "Enter" && getParamValue("isConstant", node, tensorMap, context)) {
      resolvedName = __read(getNodeNameAndIndex(node.name, context), 1)[0];
    }
    if (tensorMap[node.name] != null) {
      // Already computed: only schedule the children.
      executor.processChildNodes(node, stack, context, tensorMap, added, usedNodes);
      return;
    }
    var produced = executeOp$h(node, tensorMap, context, executor._resourceManager);
    if (!resolvedName) {
      resolvedName = __read(getNodeNameAndIndex(node.name, context), 1)[0];
    }
    var savedContext = context.currentContext;
    if (!tfOps.util.isPromise(produced)) {
      tensorMap[resolvedName] = produced;
      executor.checkTensorForDisposal(resolvedName, node, tensorMap, context, tensorsToKeep, outputNames, intermediateTensorConsumerCount);
      executor.processChildNodes(node, stack, context, tensorMap, added, usedNodes);
      return;
    }
    pending.push(produced.then(function(t) {
      tensorMap[resolvedName] = t;
      // Other entries may have switched the shared context meanwhile; restore ours.
      context.currentContext = savedContext;
      executor.checkTensorForDisposal(resolvedName, node, tensorMap, context, tensorsToKeep, outputNames, intermediateTensorConsumerCount);
      executor.processChildNodes(node, stack, context, tensorMap, added, usedNodes);
      return t;
    }));
  }
  while (stack.length > 0) {
    visit(stack.pop());
  }
  return pending;
};
/**
 * Pushes onto `stack` every used, not-yet-added child of `node` whose inputs are ready.
 * A Merge child needs any one input to be available; every other op needs all inputs.
 */
GraphExecutor2.prototype.processChildNodes = function(node, stack, context, tensorMap, added, usedNodes) {
  var isAvailable = function(name) {
    return !!getTensor(name, tensorMap, context);
  };
  node.children.forEach(function(childNode) {
    var scopedName = __read(getNodeNameAndIndex(childNode.name, context), 1)[0];
    if (added[scopedName] || !usedNodes.has(childNode.name)) {
      return;
    }
    var ready = childNode.op === "Merge" ? childNode.inputNames.some(isAvailable) : childNode.inputNames.every(isAvailable);
    if (ready) {
      added[scopedName] = true;
      stack.push({contexts: context.currentContext, node: childNode});
    }
  });
};
/** Disposes every weight tensor held by this executor (each weightMap entry is an array). */
GraphExecutor2.prototype.dispose = function() {
  var executor = this;
  Object.keys(executor.weightMap).forEach(function(name) {
    var weights = executor.weightMap[name];
    weights.forEach(function(weight) {
      weight.dispose();
    });
  });
};
/**
 * Asserts that every provided input matches its placeholder's declared shape and dtype.
 * A declared dimension of -1 matches any size. Attributes that are absent are not checked.
 */
GraphExecutor2.prototype.checkInputShapeAndType = function(inputs) {
  var executor = this;
  Object.keys(inputs).forEach(function(name) {
    var input = inputs[name];
    var node = executor.graph.nodes[__read(parseNodeName(name), 1)[0]];
    var shapeParam = node.attrParams["shape"];
    if (shapeParam && shapeParam.value) {
      var expected = shapeParam.value;
      var sameRank = expected.length === input.shape.length;
      var compatible = sameRank && input.shape.every(function(dim, axis) {
        return expected[axis] === -1 || expected[axis] === dim;
      });
      tfOps.util.assert(compatible, function() {
        return "The shape of dict['" + node.name + "'] provided in " + ("model.execute(dict) must be [" + expected + "], but was ") + ("[" + input.shape + "]");
      });
    }
    var dtypeParam = node.attrParams["dtype"];
    if (dtypeParam && dtypeParam.value) {
      tfOps.util.assert(input.dtype === dtypeParam.value, function() {
        return "The dtype of dict['" + node.name + "'] provided in model.execute(dict) must be " + (dtypeParam.value + ", but was " + input.dtype);
      });
    }
  });
};
/**
 * Renames input keys that the model signature declares to their graph tensor names.
 * Keys without a signature entry are kept as-is.
 * @returns a new object; the values are the same tensors
 */
GraphExecutor2.prototype.mapInputs = function(inputs) {
  var mapped = {};
  for (var key in inputs) {
    var signature = this._signature;
    var aliased = signature != null && signature.inputs != null ? signature.inputs[key] : null;
    mapped[aliased != null ? aliased.name : key] = inputs[key];
  }
  return mapped;
};
/**
 * Throws when any input key refers to a node that does not exist in the graph.
 * The error message lists all offending keys at once.
 */
GraphExecutor2.prototype.checkInputs = function(inputs) {
  var executor = this;
  var unknown = [];
  Object.keys(inputs).forEach(function(name) {
    var baseName = __read(parseNodeName(name), 1)[0];
    if (executor.graph.nodes[baseName] == null) {
      unknown.push(name);
    }
  });
  if (unknown.length === 0) {
    return;
  }
  throw new Error("The dict provided in model.execute(dict) has " + ("keys: [" + unknown + "] that are not part of graph"));
};
/**
 * Translates requested output names into graph tensor names via the model signature.
 * Names without a signature entry are returned unchanged.
 * The callback never uses `this`, so no thisArg is passed to map. The previous `{}` thisArg was dead and misleading.
 * @param outputs array of requested output names
 * @returns array of graph tensor names, in the same order
 */
GraphExecutor2.prototype.mapOutputs = function(outputs) {
  var signature = this._signature;
  var aliases = signature != null ? signature.outputs : null;
  return outputs.map(function(name) {
    var aliased = aliases != null ? aliases[name] : null;
    return aliased != null ? aliased.name : name;
  });
};
/** Throws for the first requested output whose base node name is not in the graph. */
GraphExecutor2.prototype.checkOutputs = function(outputs) {
  var executor = this;
  var firstMissing;
  var hasMissing = outputs.some(function(name) {
    var baseName = __read(parseNodeName(name), 1)[0];
    if (executor.graph.nodes[baseName]) {
      return false;
    }
    firstMissing = name;
    return true;
  });
  if (hasMissing) {
    throw new Error("The output '" + firstMissing + "' is not found in the graph");
  }
};
return GraphExecutor2;
}();
/**
 * Registry of hash tables used during graph execution.
 * Tables are looked up by name (yielding their handle) or by table id (yielding the table).
 */
var ResourceManager = function() {
  /**
   * @param hashTableNameToHandle optional initial name -> handle map (defaults to {} when undefined)
   * @param hashTableMap optional initial id -> table map (defaults to {} when undefined)
   */
  function ResourceManager2(hashTableNameToHandle, hashTableMap) {
    this.hashTableNameToHandle = hashTableNameToHandle === void 0 ? {} : hashTableNameToHandle;
    this.hashTableMap = hashTableMap === void 0 ? {} : hashTableMap;
  }
  // Calls `release` on every value for-in yields from `map`, deleting each key afterwards.
  function drain(map, release) {
    for (var key in map) {
      release(map[key]);
      delete map[key];
    }
  }
  var proto = ResourceManager2.prototype;
  /** Registers `hashTable2` under `name` (stores its handle) and under its id (stores the table). */
  proto.addHashTable = function(name, hashTable2) {
    this.hashTableNameToHandle[name] = hashTable2.handle;
    this.hashTableMap[hashTable2.id] = hashTable2;
  };
  /** @returns the handle registered for `name`, or undefined */
  proto.getHashTableHandleByName = function(name) {
    return this.hashTableNameToHandle[name];
  };
  /** @returns the table registered for `id`, or undefined */
  proto.getHashTableById = function(id) {
    return this.hashTableMap[id];
  };
  /** Closes every table, then disposes every handle, emptying both maps. */
  proto.dispose = function() {
    drain(this.hashTableMap, function(table) {
      table.clearAndClose();
    });
    drain(this.hashTableNameToHandle, function(handle) {
      handle.dispose();
    });
  };
  return ResourceManager2;
}();
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
// Query string that loadGraphModel appends to TF Hub URLs when options.fromTFHub is set.
var TFHUB_SEARCH_PARAM = "?tfjs-format=file";
// Model file name that loadGraphModel appends to a TF Hub directory URL.
var DEFAULT_MODEL_NAME = "model.json";
/**
 * A model whose converted graph topology is executed by a GraphExecutor.
 * The executor (and an optional initializer executor) is created by loadSync.
 */
var GraphModel = function() {
  /**
   * @param modelUrl URL string, or an IOHandler (any object with a `load` method).
   * @param loadOptions options used when resolving the IO handler; undefined and null both become {}.
   */
  function GraphModel2(modelUrl, loadOptions) {
    if (loadOptions === void 0) {
      loadOptions = {};
    }
    this.modelUrl = modelUrl;
    this.loadOptions = loadOptions;
    this.version = "n/a";
    // An explicit null bypasses the `void 0` default above, so normalize it here.
    if (loadOptions == null) {
      this.loadOptions = {};
    }
    this.resourceManager = new ResourceManager();
  }
  // "producer.minConsumer" version string set by loadSync ("n/a" before loading).
  Object.defineProperty(GraphModel2.prototype, "modelVersion", {
    get: function() {
      return this.version;
    },
    enumerable: true,
    configurable: true
  });
  // The accessors below delegate to the executor created by loadSync.
  Object.defineProperty(GraphModel2.prototype, "inputNodes", {
    get: function() {
      return this.executor.inputNodes;
    },
    enumerable: true,
    configurable: true
  });
  Object.defineProperty(GraphModel2.prototype, "outputNodes", {
    get: function() {
      return this.executor.outputNodes;
    },
    enumerable: true,
    configurable: true
  });
  Object.defineProperty(GraphModel2.prototype, "inputs", {
    get: function() {
      return this.executor.inputs;
    },
    enumerable: true,
    configurable: true
  });
  Object.defineProperty(GraphModel2.prototype, "outputs", {
    get: function() {
      return this.executor.outputs;
    },
    enumerable: true,
    configurable: true
  });
  Object.defineProperty(GraphModel2.prototype, "weights", {
    get: function() {
      return this.executor.weightMap;
    },
    enumerable: true,
    configurable: true
  });
  /**
   * Chooses the IOHandler for this.modelUrl and stores it on this.handler.
   * Order of preference:
   *   1. modelUrl itself, if it has a `load` method;
   *   2. an HTTP handler, when loadOptions.requestInit is set;
   *   3. the registered load handlers, falling back to an HTTP handler when none match.
   * @throws Error when more than one registered load handler matches.
   */
  GraphModel2.prototype.findIOHandler = function() {
    var path = this.modelUrl;
    if (path.load != null) {
      this.handler = path;
    } else if (this.loadOptions.requestInit != null) {
      this.handler = tfOps.io.browserHTTPRequest(path, this.loadOptions);
    } else {
      var handlers = tfOps.io.getLoadHandlers(path, this.loadOptions);
      if (handlers.length === 0) {
        handlers.push(tfOps.io.browserHTTPRequest(path, this.loadOptions));
      } else if (handlers.length > 1) {
        throw new Error("Found more than one (" + handlers.length + ") load handlers for " + ("URL '" + [path] + "'"));
      }
      this.handler = handlers[0];
    }
  };
  /**
   * Loads the model artifacts through the resolved IOHandler and builds the executors.
   * @returns Promise resolving to the result of loadSync (true).
   * @throws Error when the handler has no `load` method.
   */
  GraphModel2.prototype.load = function() {
    return __awaiter(this, void 0, void 0, function() {
      var artifacts;
      return __generator(this, function(_a) {
        switch (_a.label) {
          case 0:
            this.findIOHandler();
            if (this.handler.load == null) {
              throw new Error("Cannot proceed with model loading because the IOHandler provided does not have the `load` method implemented.");
            }
            return [4, this.handler.load()];
          case 1:
            artifacts = _a.sent();
            return [2, this.loadSync(artifacts)];
        }
      });
    });
  };
  /**
   * Builds the executor (and the optional initializer executor) from loaded artifacts.
   * @param artifacts model artifacts: topology, weight specs/data, optional metadata and initializer
   * @returns true
   */
  GraphModel2.prototype.loadSync = function(artifacts) {
    this.artifacts = artifacts;
    var graph2 = this.artifacts.modelTopology;
    var signature = {};
    if (this.artifacts.userDefinedMetadata != null) {
      signature = this.artifacts.userDefinedMetadata.signature;
    }
    this.version = graph2.versions.producer + "." + graph2.versions.minConsumer;
    var weightMap = tfOps.io.decodeWeights(this.artifacts.weightData, this.artifacts.weightSpecs);
    this.executor = new GraphExecutor(OperationMapper.Instance.transformGraph(graph2, signature));
    this.executor.weightMap = this.convertTensorMapToTensorsMap(weightMap);
    this.executor.resourceManager = this.resourceManager;
    if (artifacts.modelInitializer != null) {
      var initializer = OperationMapper.Instance.transformGraph(artifacts.modelInitializer);
      this.initializer = new GraphExecutor(initializer);
      this.initializer.weightMap = this.executor.weightMap;
      this.initializer.resourceManager = this.resourceManager;
      // NOTE(review): this promise is neither awaited nor observed, so loadSync returns
      // before initialization finishes, and a failure here is an unhandled rejection.
      this.initializer.executeAsync({}, []);
    }
    return true;
  };
  /**
   * Saves the loaded artifacts through a save handler or a URL resolved to one.
   * `config` is accepted but not used.
   * @throws Error when the URL matches no save handler or more than one, or when the handler lacks `save`.
   */
  GraphModel2.prototype.save = function(handlerOrURL, config) {
    return __awaiter(this, void 0, void 0, function() {
      var handlers;
      return __generator(this, function(_a) {
        if (typeof handlerOrURL === "string") {
          handlers = tfOps.io.getSaveHandlers(handlerOrURL);
          if (handlers.length === 0) {
            throw new Error("Cannot find any save handlers for URL '" + handlerOrURL + "'");
          } else if (handlers.length > 1) {
            throw new Error("Found more than one (" + handlers.length + ") save handlers for " + ("URL '" + handlerOrURL + "'"));
          }
          handlerOrURL = handlers[0];
        }
        if (handlerOrURL.save == null) {
          throw new Error("GraphModel.save() cannot proceed because the IOHandler provided does not have the `save` attribute defined.");
        }
        return [2, handlerOrURL.save(this.artifacts)];
      });
    });
  };
  /** Synchronous execution for all output nodes; `config` is accepted but not used. */
  GraphModel2.prototype.predict = function(inputs, config) {
    return this.execute(inputs, this.outputNodes);
  };
  /**
   * Turns a tensor or an array of tensors into a { inputNodeName: tensor } dict.
   * Any other value (e.g. an existing dict) is returned unchanged.
   * @throws Error when the array length differs from the number of input nodes.
   */
  GraphModel2.prototype.normalizeInputs = function(inputs) {
    if (!(inputs instanceof tfOps.Tensor) && !Array.isArray(inputs)) {
      return inputs;
    }
    inputs = Array.isArray(inputs) ? inputs : [inputs];
    if (inputs.length !== this.inputNodes.length) {
      // Fixed: the message previously lacked a space after "mismatch,".
      throw new Error("Input tensor count mismatch, " + ("the graph model has " + this.inputNodes.length + " placeholders, ") + ("while there are " + inputs.length + " input tensors."));
    }
    return this.inputNodes.reduce(function(map, inputName, i) {
      map[inputName] = inputs[i];
      return map;
    }, {});
  };
  /** Defaults to all output nodes and wraps a single name in an array. */
  GraphModel2.prototype.normalizeOutputs = function(outputs) {
    outputs = outputs || this.outputNodes;
    return !Array.isArray(outputs) ? [outputs] : outputs;
  };
  /** Runs the graph synchronously; returns a single tensor when exactly one output is produced. */
  GraphModel2.prototype.execute = function(inputs, outputs) {
    inputs = this.normalizeInputs(inputs);
    outputs = this.normalizeOutputs(outputs);
    var result = this.executor.execute(inputs, outputs);
    return result.length > 1 ? result : result[0];
  };
  /** Async counterpart of execute, for graphs with control flow or dynamic ops. */
  GraphModel2.prototype.executeAsync = function(inputs, outputs) {
    return __awaiter(this, void 0, void 0, function() {
      var result;
      return __generator(this, function(_a) {
        switch (_a.label) {
          case 0:
            inputs = this.normalizeInputs(inputs);
            outputs = this.normalizeOutputs(outputs);
            return [4, this.executor.executeAsync(inputs, outputs)];
          case 1:
            result = _a.sent();
            return [2, result.length > 1 ? result : result[0]];
        }
      });
    });
  };
  /** Wraps each value of `map` in a one-element array: { name: tensor } -> { name: [tensor] }. */
  GraphModel2.prototype.convertTensorMapToTensorsMap = function(map) {
    return Object.keys(map).reduce(function(newMap, key) {
      newMap[key] = [map[key]];
      return newMap;
    }, {});
  };
  /** Disposes the executor's weights, the initializer (if any) and all hash-table resources. */
  GraphModel2.prototype.dispose = function() {
    this.executor.dispose();
    if (this.initializer) {
      this.initializer.dispose();
    }
    this.resourceManager.dispose();
  };
  return GraphModel2;
}();
/**
 * Creates a GraphModel for `modelUrl` and loads it.
 * @param modelUrl URL string or IOHandler; must not be null/undefined.
 * @param options load options; with `fromTFHub`, a string URL is expanded to
 *   "<url>/" + DEFAULT_MODEL_NAME + TFHUB_SEARCH_PARAM.
 * @returns Promise resolving to the loaded GraphModel.
 * @throws Error when modelUrl is null or undefined.
 */
function loadGraphModel(modelUrl, options) {
  if (options === void 0) {
    options = {};
  }
  return __awaiter(this, void 0, void 0, function() {
    var model;
    return __generator(this, function(state) {
      if (state.label === 0) {
        if (modelUrl == null) {
          throw new Error("modelUrl in loadGraphModel() cannot be null. Please provide a url or an IOHandler that loads the model");
        }
        if (options == null) {
          options = {};
        }
        // Only string URLs are rewritten; IOHandlers (objects with `load`) pass through.
        if (options.fromTFHub && modelUrl.load == null) {
          var directory = modelUrl.endsWith("/") ? modelUrl : modelUrl + "/";
          modelUrl = "" + directory + DEFAULT_MODEL_NAME + TFHUB_SEARCH_PARAM;
        }
        model = new GraphModel(modelUrl, options);
        return [4, model.load()];
      }
      if (state.label === 1) {
        // sent() rethrows if model.load() rejected.
        state.sent();
        return [2, model];
      }
    });
  });
}
/** @license See the LICENSE file. */
var version = "2.7.0";
// Public surface of the converter module; the package version is exported as `version_converter`.
Object.assign(exports, {
  GraphModel: GraphModel,
  deregisterOp: deregisterOp,
  loadGraphModel: loadGraphModel,
  registerOp: registerOp,
  version_converter: version
});
});
// empty:/home/vlado/dev/human/node_modules/string_decoder/lib/string_decoder.js
// Stubbed out as an empty module: requiring it returns an empty exports object.
var require_string_decoder = __commonJS(function() {
});
// node_modules/@tensorflow/tfjs-data/dist/tf-data.node.js
var require_tf_data_node = __commonJS((exports) => {
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
"use strict";
// Everything below in this module callback runs in strict mode, so `this` is
// undefined inside plain (non-method) function calls.
Object.defineProperty(exports, "__esModule", {value: true});
// Load the bundled tfjs-core module (defined earlier in this file).
var tf = require_tf_core_node();
/*! *****************************************************************************
Copyright (c) Microsoft Corporation. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use
this file except in compliance with the License. You may obtain a copy of the
License at http://www.apache.org/licenses/LICENSE-2.0
THIS CODE IS PROVIDED ON AN *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
MERCHANTABLITY OR NON-INFRINGEMENT.
See the Apache Version 2.0 License for specific language governing permissions
and limitations under the License.
***************************************************************************** */
/**
 * Links static members of base constructor `b` to derived constructor `d`.
 * This is the TypeScript __extends helper.
 * On the first call the variable replaces itself with the best available strategy:
 * Object.setPrototypeOf, then `__proto__` assignment, then an own-property copy.
 */
var extendStatics = function(d, b) {
  extendStatics = Object.setPrototypeOf || {__proto__: []} instanceof Array && function(d2, b2) {
    d2.__proto__ = b2;
  } || function(d2, b2) {
    for (var p in b2) {
      // Borrow hasOwnProperty from Object.prototype: b2 may shadow it or lack Object.prototype.
      if (Object.prototype.hasOwnProperty.call(b2, p)) {
        d2[p] = b2[p];
      }
    }
  };
  return extendStatics(d, b);
};
/**
 * Sets up prototypal inheritance of constructor `d` from `b` (TypeScript helper).
 * Statics are linked via extendStatics.
 * For b === null, d.prototype becomes a null-prototype object.
 * Otherwise d.prototype is a fresh object inheriting from b.prototype, with `constructor` set to d.
 */
function __extends(d, b) {
  extendStatics(d, b);
  if (b === null) {
    d.prototype = Object.create(b);
    return;
  }
  function Bridge() {
    this.constructor = d;
  }
  Bridge.prototype = b.prototype;
  d.prototype = new Bridge();
}
/**
 * Runs a generator function as an async function (TypeScript helper).
 * Every yielded value is resolved through P, and its result is sent back into the generator.
 * A rejection is thrown into the generator at the yield point.
 * @param thisArg this for the generator function
 * @param _arguments arguments for the generator function (defaults to [])
 * @param P promise constructor to use (defaults to Promise)
 * @param generator2 generator function implementing the async body
 * @returns a P resolving to the generator's return value, or rejecting with an uncaught error
 */
function __awaiter(thisArg, _arguments, P, generator2) {
  if (!P) {
    P = Promise;
  }
  return new P(function(resolve, reject) {
    function advance(result) {
      if (result.done) {
        resolve(result.value);
        return;
      }
      new P(function(adopt) {
        adopt(result.value);
      }).then(onFulfilled, onRejected);
    }
    function onFulfilled(value) {
      try {
        advance(generator2.next(value));
      } catch (e) {
        reject(e);
      }
    }
    function onRejected(reason) {
      try {
        advance(generator2["throw"](reason));
      } catch (e) {
        reject(e);
      }
    }
    generator2 = generator2.apply(thisArg, _arguments || []);
    advance(generator2.next());
  });
}
// Generator emulation used by downlevelled async code (TypeScript helper).
// `body` is called repeatedly with the state object `_`. `_.label` selects the
// switch case to resume at, and the body returns an instruction [opcode, value]:
//   2 = return value, 3 = jump to label, 4 = yield value, 5 = delegate to an
//   inner iterator (yield*), 7 = end of a finally block.
// Opcodes 0/1/2 also encode how the generator was resumed: next / throw / return.
// 6 is produced internally when the body throws.
// _.sent() returns the value passed to next(), or rethrows an error passed to throw().
function __generator(thisArg, body) {
var _ = {label: 0, sent: function() {
if (t[0] & 1)
throw t[1];
return t[1];
}, trys: [], ops: []}, f, y, t, g;
return g = {next: verb(0), throw: verb(1), return: verb(2)}, typeof Symbol === "function" && (g[Symbol.iterator] = function() {
return this;
}), g;
function verb(n) {
return function(v) {
return step([n, v]);
};
}
// Runs the state machine until it yields (returns {done: false}) or completes.
// `f` guards against re-entrancy, `y` is the delegated iterator (opcode 5), and
// `_.trys` holds [try, catch, finally, end] label tuples for active try blocks.
function step(op) {
if (f)
throw new TypeError("Generator is already executing.");
while (_)
try {
// While delegating, forward next/throw/return to the inner iterator first.
if (f = 1, y && (t = op[0] & 2 ? y["return"] : op[0] ? y["throw"] || ((t = y["return"]) && t.call(y), 0) : y.next) && !(t = t.call(y, op[1])).done)
return t;
if (y = 0, t)
op = [op[0] & 2, t.value];
switch (op[0]) {
case 0:
case 1:
t = op;
break;
case 4:
_.label++;
return {value: op[1], done: false};
case 5:
_.label++;
y = op[1];
op = [0];
continue;
case 7:
op = _.ops.pop();
_.trys.pop();
continue;
default:
// return (2) or exception (6) with no enclosing try: finish the generator.
if (!(t = _.trys, t = t.length > 0 && t[t.length - 1]) && (op[0] === 6 || op[0] === 2)) {
_ = 0;
continue;
}
// Jump (3) within the current try range, or with no try active.
if (op[0] === 3 && (!t || op[1] > t[0] && op[1] < t[3])) {
_.label = op[1];
break;
}
// Exception inside a try block: resume at the catch label.
if (op[0] === 6 && _.label < t[1]) {
_.label = t[1];
t = op;
break;
}
// Run the finally block first and remember the pending completion.
if (t && _.label < t[2]) {
_.label = t[2];
_.ops.push(op);
break;
}
if (t[2])
_.ops.pop();
_.trys.pop();
continue;
}
op = body.call(thisArg, _);
} catch (e) {
op = [6, e];
y = 0;
} finally {
f = t = 0;
}
// Completed: rethrow for opcodes 1 and 6; otherwise report done with the return value.
if (op[0] & 5)
throw op[1];
return {value: op[0] ? op[1] : void 0, done: true};
}
}
// Global object for the current environment: globalThis, window, global or self
// (first one defined), else a fresh empty object.
var commonjsGlobal = function() {
  if (typeof globalThis !== "undefined") {
    return globalThis;
  }
  if (typeof window !== "undefined") {
    return window;
  }
  if (typeof global !== "undefined") {
    return global;
  }
  if (typeof self !== "undefined") {
    return self;
  }
  return {};
}();
/**
 * Runs a CommonJS-style factory `fn(module, exports)` against a fresh module object.
 * @returns whatever the factory left in module.exports. The `module2` argument is ignored.
 */
function createCommonjsModule(fn, module2) {
  var mod = {exports: {}};
  fn(mod, mod.exports);
  return mod.exports;
}
// Alea generator module; its export is impl(seed, opts).
var alea = createCommonjsModule(function(module2) {
(function(global2, module3, define2) {
// State: three floats s0..s2 plus an integer carry c.
// The seed is mixed in by subtracting its Mash hash from each state word (wrapping into [0, 1)).
// Note: Mash calls data.toString(), so a null/undefined seed throws.
function Alea(seed) {
var me = this, mash = Mash();
me.next = function() {
// 23283064365386963e-26 is 2^-32.
var t = 2091639 * me.s0 + me.c * 23283064365386963e-26;
me.s0 = me.s1;
me.s1 = me.s2;
return me.s2 = t - (me.c = t | 0);
};
me.c = 1;
me.s0 = mash(" ");
me.s1 = mash(" ");
me.s2 = mash(" ");
me.s0 -= mash(seed);
if (me.s0 < 0) {
me.s0 += 1;
}
me.s1 -= mash(seed);
if (me.s1 < 0) {
me.s1 += 1;
}
me.s2 -= mash(seed);
if (me.s2 < 0) {
me.s2 += 1;
}
mash = null;
}
// Copies generator state from f to t and returns t (used for state snapshots).
function copy(f, t) {
t.c = f.c;
t.s0 = f.s0;
t.s1 = f.s1;
t.s2 = f.s2;
return t;
}
// Returns the generator function. Note: that function is xg.next itself, so the
// helpers below are attached to it as properties.
// int32() gives signed 32-bit integers. double() adds 21 extra bits from a second
// draw (2097152 = 2^21, 11102230246251565e-32 = 2^-53).
// With opts.state truthy, prng.state() snapshots the state; if opts.state is an
// object, that state is restored first.
function impl(seed, opts) {
var xg = new Alea(seed), state = opts && opts.state, prng = xg.next;
prng.int32 = function() {
return xg.next() * 4294967296 | 0;
};
prng.double = function() {
return prng() + (prng() * 2097152 | 0) * 11102230246251565e-32;
};
prng.quick = prng;
if (state) {
if (typeof state == "object")
copy(state, xg);
prng.state = function() {
return copy(xg, {});
};
}
return prng;
}
// Returns a stateful hash function mapping the string form of its argument to a
// float in [0, 1). State `n` carries over between calls.
function Mash() {
var n = 4022871197;
var mash = function(data) {
data = data.toString();
for (var i = 0; i < data.length; i++) {
n += data.charCodeAt(i);
var h = 0.02519603282416938 * n;
n = h >>> 0;
h -= n;
h *= n;
n = h >>> 0;
h -= n;
n += h * 4294967296;
}
return (n >>> 0) * 23283064365386963e-26;
};
return mash;
}
// createCommonjsModule always passes a module with `exports`, so the first branch runs here.
if (module3 && module3.exports) {
module3.exports = impl;
} else if (define2 && define2.amd) {
define2(function() {
return impl;
});
} else {
this.alea = impl;
}
})(commonjsGlobal, module2, false);
});
// xor128 generator module; its export is impl(seed, opts).
var xor128 = createCommonjsModule(function(module2) {
(function(global2, module3, define2) {
// State: four 32-bit words x, y, z, w.
// An int32 seed sets x directly. Any other seed is appended to a string and its
// character codes are XORed into x. Then 64 extra steps run
// (charCodeAt past the end gives NaN | 0 = 0).
function XorGen(seed) {
var me = this, strseed = "";
me.x = 0;
me.y = 0;
me.z = 0;
me.w = 0;
// One step: shift the words down and return the new w as a signed 32-bit integer.
me.next = function() {
var t = me.x ^ me.x << 11;
me.x = me.y;
me.y = me.z;
me.z = me.w;
return me.w ^= me.w >>> 19 ^ t ^ t >>> 8;
};
if (seed === (seed | 0)) {
me.x = seed;
} else {
strseed += seed;
}
for (var k = 0; k < strseed.length + 64; k++) {
me.x ^= strseed.charCodeAt(k) | 0;
me.next();
}
}
// Copies generator state from f to t and returns t.
function copy(f, t) {
t.x = f.x;
t.y = f.y;
t.z = f.z;
t.w = f.w;
return t;
}
// prng() gives a float in [0, 1) from one 32-bit output.
// prng.double() combines 21 + 32 bits and loops until the result is non-zero.
// prng.int32() returns raw signed outputs.
// opts.state works as in the other generators (snapshot/restore).
function impl(seed, opts) {
var xg = new XorGen(seed), state = opts && opts.state, prng = function() {
return (xg.next() >>> 0) / 4294967296;
};
prng.double = function() {
do {
var top = xg.next() >>> 11, bot = (xg.next() >>> 0) / 4294967296, result = (top + bot) / (1 << 21);
} while (result === 0);
return result;
};
prng.int32 = xg.next;
prng.quick = prng;
if (state) {
if (typeof state == "object")
copy(state, xg);
prng.state = function() {
return copy(xg, {});
};
}
return prng;
}
// createCommonjsModule always passes a module with `exports`, so the first branch runs here.
if (module3 && module3.exports) {
module3.exports = impl;
} else if (define2 && define2.amd) {
define2(function() {
return impl;
});
} else {
this.xor128 = impl;
}
})(commonjsGlobal, module2, false);
});
// xorwow generator module; its export is impl(seed, opts).
var xorwow = createCommonjsModule(function(module2) {
(function(global2, module3, define2) {
// State: five 32-bit words x..v plus a counter d that advances by 362437 per step
// and is added to the output.
// Seeding: an int32 seed sets x directly; other seeds are mixed in as string character codes.
// Then 64 extra steps run. d is initialized from x at the end of the seed string.
function XorGen(seed) {
var me = this, strseed = "";
me.next = function() {
var t = me.x ^ me.x >>> 2;
me.x = me.y;
me.y = me.z;
me.z = me.w;
me.w = me.v;
return (me.d = me.d + 362437 | 0) + (me.v = me.v ^ me.v << 4 ^ (t ^ t << 1)) | 0;
};
me.x = 0;
me.y = 0;
me.z = 0;
me.w = 0;
me.v = 0;
if (seed === (seed | 0)) {
me.x = seed;
} else {
strseed += seed;
}
for (var k = 0; k < strseed.length + 64; k++) {
me.x ^= strseed.charCodeAt(k) | 0;
if (k == strseed.length) {
me.d = me.x << 10 ^ me.x >>> 4;
}
me.next();
}
}
// Copies generator state from f to t and returns t.
function copy(f, t) {
t.x = f.x;
t.y = f.y;
t.z = f.z;
t.w = f.w;
t.v = f.v;
t.d = f.d;
return t;
}
// Same public API as xor128: prng(), double(), int32(), quick, optional state().
function impl(seed, opts) {
var xg = new XorGen(seed), state = opts && opts.state, prng = function() {
return (xg.next() >>> 0) / 4294967296;
};
prng.double = function() {
do {
var top = xg.next() >>> 11, bot = (xg.next() >>> 0) / 4294967296, result = (top + bot) / (1 << 21);
} while (result === 0);
return result;
};
prng.int32 = xg.next;
prng.quick = prng;
if (state) {
if (typeof state == "object")
copy(state, xg);
prng.state = function() {
return copy(xg, {});
};
}
return prng;
}
// createCommonjsModule always passes a module with `exports`, so the first branch runs here.
if (module3 && module3.exports) {
module3.exports = impl;
} else if (define2 && define2.amd) {
define2(function() {
return impl;
});
} else {
this.xorwow = impl;
}
})(commonjsGlobal, module2, false);
});
// xorshift7 generator module; its export is impl(seed, opts).
var xorshift7 = createCommonjsModule(function(module2) {
(function(global2, module3, define2) {
// State: an array X of eight 32-bit words plus a rotating index i (mod 8).
function XorGen(seed) {
var me = this;
me.next = function() {
var X = me.x, i = me.i, t, v;
t = X[i];
t ^= t >>> 7;
v = t ^ t << 24;
t = X[i + 1 & 7];
v ^= t ^ t >>> 10;
t = X[i + 3 & 7];
v ^= t ^ t >>> 3;
t = X[i + 4 & 7];
v ^= t ^ t << 7;
t = X[i + 7 & 7];
t = t ^ t << 13;
v ^= t ^ t << 9;
X[i] = v;
me.i = i + 1 & 7;
return v;
};
// Fills X from an int32 seed (X[0]) or from a string's character codes, pads to 8
// words, forces a non-zero state (X[7] = -1 if all words are 0), then discards 256
// outputs. Note: the local `w` is assigned but never read.
function init(me2, seed2) {
var j, w, X = [];
if (seed2 === (seed2 | 0)) {
w = X[0] = seed2;
} else {
seed2 = "" + seed2;
for (j = 0; j < seed2.length; ++j) {
X[j & 7] = X[j & 7] << 15 ^ seed2.charCodeAt(j) + X[j + 1 & 7] << 13;
}
}
while (X.length < 8)
X.push(0);
for (j = 0; j < 8 && X[j] === 0; ++j)
;
if (j == 8)
w = X[7] = -1;
else
w = X[j];
me2.x = X;
me2.i = 0;
for (j = 256; j > 0; --j) {
me2.next();
}
}
init(me, seed);
}
// Copies generator state from f to t (cloning the word array) and returns t.
function copy(f, t) {
t.x = f.x.slice();
t.i = f.i;
return t;
}
// A null/undefined seed falls back to the current time.
// Restoring from opts.state only happens when it has an `x` array.
function impl(seed, opts) {
if (seed == null)
seed = +new Date();
var xg = new XorGen(seed), state = opts && opts.state, prng = function() {
return (xg.next() >>> 0) / 4294967296;
};
prng.double = function() {
do {
var top = xg.next() >>> 11, bot = (xg.next() >>> 0) / 4294967296, result = (top + bot) / (1 << 21);
} while (result === 0);
return result;
};
prng.int32 = xg.next;
prng.quick = prng;
if (state) {
if (state.x)
copy(state, xg);
prng.state = function() {
return copy(xg, {});
};
}
return prng;
}
// createCommonjsModule always passes a module with `exports`, so the first branch runs here.
if (module3 && module3.exports) {
module3.exports = impl;
} else if (define2 && define2.amd) {
define2(function() {
return impl;
});
} else {
this.xorshift7 = impl;
}
})(commonjsGlobal, module2, false);
});
// xor4096 generator module; its export is impl(seed, opts).
var xor4096 = createCommonjsModule(function(module2) {
(function(global2, module3, define2) {
// State: 128 words X, index i (mod 128), and a counter w advanced by
// 1640531527 per step; w is mixed into the output.
function XorGen(seed) {
var me = this;
me.next = function() {
var w = me.w, X = me.X, i = me.i, t, v;
me.w = w = w + 1640531527 | 0;
v = X[i + 34 & 127];
t = X[i = i + 1 & 127];
v ^= v << 13;
t ^= t << 17;
v ^= v >>> 15;
t ^= t >>> 12;
v = X[i] = v ^ t;
me.i = i;
return v + (w ^ w >>> 16) | 0;
};
// Seeding: an int32 seed starts v directly. Other seeds become a string with "\0"
// appended, whose character codes are mixed in cyclically (at least 128 rounds).
// `i` counts a trailing run of zero words; if every word is zero, one is set to -1.
// Finally 512 (4 * 128) steps are run to scramble the state.
function init(me2, seed2) {
var t, v, i, j, w, X = [], limit = 128;
if (seed2 === (seed2 | 0)) {
v = seed2;
seed2 = null;
} else {
seed2 = seed2 + "\0";
v = 0;
limit = Math.max(limit, seed2.length);
}
for (i = 0, j = -32; j < limit; ++j) {
if (seed2)
v ^= seed2.charCodeAt((j + 32) % seed2.length);
if (j === 0)
w = v;
v ^= v << 10;
v ^= v >>> 15;
v ^= v << 4;
v ^= v >>> 13;
if (j >= 0) {
w = w + 1640531527 | 0;
t = X[j & 127] ^= v + w;
i = t == 0 ? i + 1 : 0;
}
}
if (i >= 128) {
X[(seed2 && seed2.length || 0) & 127] = -1;
}
i = 127;
for (j = 4 * 128; j > 0; --j) {
v = X[i + 34 & 127];
t = X[i = i + 1 & 127];
v ^= v << 13;
t ^= t << 17;
v ^= v >>> 15;
t ^= t >>> 12;
X[i] = v ^ t;
}
me2.w = w;
me2.X = X;
me2.i = i;
}
init(me, seed);
}
// Copies generator state from f to t (cloning the word array) and returns t.
function copy(f, t) {
t.i = f.i;
t.w = f.w;
t.X = f.X.slice();
return t;
}
// A null/undefined seed falls back to the current time.
// Restoring from opts.state only happens when it has an `X` array.
function impl(seed, opts) {
if (seed == null)
seed = +new Date();
var xg = new XorGen(seed), state = opts && opts.state, prng = function() {
return (xg.next() >>> 0) / 4294967296;
};
prng.double = function() {
do {
var top = xg.next() >>> 11, bot = (xg.next() >>> 0) / 4294967296, result = (top + bot) / (1 << 21);
} while (result === 0);
return result;
};
prng.int32 = xg.next;
prng.quick = prng;
if (state) {
if (state.X)
copy(state, xg);
prng.state = function() {
return copy(xg, {});
};
}
return prng;
}
// createCommonjsModule always passes a module with `exports`, so the first branch runs here.
if (module3 && module3.exports) {
module3.exports = impl;
} else if (define2 && define2.amd) {
define2(function() {
return impl;
});
} else {
this.xor4096 = impl;
}
})(commonjsGlobal, module2, false);
});
// tychei generator module; its export is impl(seed, opts).
var tychei = createCommonjsModule(function(module2) {
(function(global2, module3, define2) {
// State: four 32-bit words a, b, c, d, with fixed initial values for c and d.
// An integer seed is split into a (high part, seed / 2^32) and b (low 32 bits).
// Any other seed is mixed into b as string character codes. 20 extra steps follow.
function XorGen(seed) {
var me = this, strseed = "";
me.next = function() {
var b = me.b, c = me.c, d = me.d, a = me.a;
b = b << 25 ^ b >>> 7 ^ c;
c = c - d | 0;
d = d << 24 ^ d >>> 8 ^ a;
a = a - b | 0;
me.b = b = b << 20 ^ b >>> 12 ^ c;
me.c = c = c - d | 0;
me.d = d << 16 ^ c >>> 16 ^ a;
return me.a = a - b | 0;
};
me.a = 0;
me.b = 0;
me.c = 2654435769 | 0;
me.d = 1367130551;
if (seed === Math.floor(seed)) {
me.a = seed / 4294967296 | 0;
me.b = seed | 0;
} else {
strseed += seed;
}
for (var k = 0; k < strseed.length + 20; k++) {
me.b ^= strseed.charCodeAt(k) | 0;
me.next();
}
}
// Copies generator state from f to t and returns t.
function copy(f, t) {
t.a = f.a;
t.b = f.b;
t.c = f.c;
t.d = f.d;
return t;
}
// Same public API as the other xor* generators: prng(), double(), int32(), quick, state().
function impl(seed, opts) {
var xg = new XorGen(seed), state = opts && opts.state, prng = function() {
return (xg.next() >>> 0) / 4294967296;
};
prng.double = function() {
do {
var top = xg.next() >>> 11, bot = (xg.next() >>> 0) / 4294967296, result = (top + bot) / (1 << 21);
} while (result === 0);
return result;
};
prng.int32 = xg.next;
prng.quick = prng;
if (state) {
if (typeof state == "object")
copy(state, xg);
prng.state = function() {
return copy(xg, {});
};
}
return prng;
}
// createCommonjsModule always passes a module with `exports`, so the first branch runs here.
if (module3 && module3.exports) {
module3.exports = impl;
} else if (define2 && define2.amd) {
define2(function() {
return impl;
});
} else {
this.tychei = impl;
}
})(commonjsGlobal, module2, false);
});
// seedrandom: an ARC4-based seedable PRNG. The module export is seedrandom2, and
// loading the module also installs it as Math.seedrandom.
var seedrandom = createCommonjsModule(function(module2) {
(function(pool, math) {
// `this` is undefined inside this IIFE because the enclosing module callback is
// "use strict", so take the global object from commonjsGlobal instead.
// With `this`, global2 was undefined: autoseed() threw on global2.crypto, and its
// fallback threw again on global2.navigator.
var global2 = commonjsGlobal, width = 256, chunks = 6, digits = 52, rngname = "random", startdenom = math.pow(width, chunks), significance = math.pow(2, digits), overflow = significance * 2, mask = width - 1, nodecrypto;
// Creates a seeded generator.
// @param seed      seed value; null/undefined requests an automatic seed (autoseed()).
// @param options   `true` is shorthand for {entropy: true}. Recognized keys:
//                  entropy (also mix in the accumulated pool), pass (custom result
//                  function), global (install as Math.random), state (snapshot/restore).
// @param callback  used like options.pass when options.pass is absent.
// @returns by default the generator, or the short seed string when it was installed
//          as Math.random; with pass/callback, whatever that function returns.
function seedrandom2(seed, options, callback) {
var key = [];
options = options == true ? {entropy: true} : options || {};
var shortseed = mixkey(flatten(options.entropy ? [seed, tostring(pool)] : seed == null ? autoseed() : seed, 3), key);
var arc4 = new ARC4(key);
// Builds a float in [0, 1) with `digits` (52) bits of precision from ARC4 bytes.
var prng = function() {
var n = arc4.g(chunks), d = startdenom, x = 0;
while (n < significance) {
n = (n + x) * width;
d *= width;
x = arc4.g(1);
}
while (n >= overflow) {
n /= 2;
d /= 2;
x >>>= 1;
}
return (n + x) / d;
};
prng.int32 = function() {
return arc4.g(4) | 0;
};
prng.quick = function() {
return arc4.g(4) / 4294967296;
};
prng.double = prng;
// Feed generator output back into the shared entropy pool.
mixkey(tostring(arc4.S), pool);
return (options.pass || callback || function(prng2, seed2, is_math_call, state) {
if (state) {
if (state.S) {
copy(state, arc4);
}
prng2.state = function() {
return copy(arc4, {});
};
}
if (is_math_call) {
math[rngname] = prng2;
return seed2;
} else
return prng2;
})(prng, shortseed, "global" in options ? options.global : this == math, options.state);
}
math["seed" + rngname] = seedrandom2;
// ARC4 keystream: runs the key schedule, then immediately discards `width` (256) output bytes.
// me.g(count) returns the next `count` bytes combined into one big-endian number.
function ARC4(key) {
var t, keylen = key.length, me = this, i = 0, j = me.i = me.j = 0, s = me.S = [];
if (!keylen) {
key = [keylen++];
}
while (i < width) {
s[i] = i++;
}
for (i = 0; i < width; i++) {
s[i] = s[j = mask & j + key[i % keylen] + (t = s[i])];
s[j] = t;
}
(me.g = function(count) {
var t2, r = 0, i2 = me.i, j2 = me.j, s2 = me.S;
while (count--) {
t2 = s2[i2 = mask & i2 + 1];
r = r * width + s2[mask & (s2[i2] = s2[j2 = mask & j2 + t2]) + (s2[j2] = t2)];
}
me.i = i2;
me.j = j2;
return r;
})(width);
}
// Copies ARC4 state from f to t (cloning S) and returns t.
function copy(f, t) {
t.i = f.i;
t.j = f.j;
t.S = f.S.slice();
return t;
}
// Flattens `obj` into nested arrays of strings, down to `depth` levels.
// Properties that throw on access are skipped.
function flatten(obj, depth) {
var result = [], typ = typeof obj, prop;
if (depth && typ == "object") {
for (prop in obj) {
try {
result.push(flatten(obj[prop], depth - 1));
} catch (e) {
}
}
}
return result.length ? result : typ == "string" ? obj : obj + "\0";
}
// XOR-mixes the string form of `seed` into the byte array `key`.
// @returns key converted to a string.
function mixkey(seed, key) {
var stringseed = seed + "", smear, j = 0;
while (j < stringseed.length) {
key[mask & j] = mask & (smear ^= key[mask & j] * 19) + stringseed.charCodeAt(j++);
}
return tostring(key);
}
// Gathers `width` random bytes from node crypto or Web Crypto. When neither is
// available, falls back to weak entropy: time, global objects, and the pool.
function autoseed() {
try {
var out;
if (nodecrypto && (out = nodecrypto.randomBytes)) {
out = out(width);
} else {
out = new Uint8Array(width);
(global2.crypto || global2.msCrypto).getRandomValues(out);
}
return tostring(out);
} catch (e) {
var browser = global2.navigator, plugins = browser && browser.plugins;
return [+new Date(), global2, plugins, global2.screen, tostring(pool)];
}
}
// Converts an array of char codes to a string.
function tostring(a) {
return String.fromCharCode.apply(0, a);
}
mixkey(math.random(), pool);
if (module2.exports) {
module2.exports = seedrandom2;
try {
nodecrypto = require_crypto();
} catch (ex) {
}
}
})([], Math);
});
// Attach the alternative seedrandom generators as properties of the main
// factory, then expose the `alea` generator for ShuffleIterator's seeding.
Object.assign(seedrandom, {
  alea: alea,
  xor128: xor128,
  xorwow: xorwow,
  xorshift7: xorshift7,
  xor4096: xor4096,
  tychei: tychei
});
var seedrandom$1 = seedrandom;
var seedrandom_1 = seedrandom$1.alea;
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
* =============================================================================
*/
/**
 * Applies `mapFn` to every node of a nested structure (see deepMapInternal),
 * starting with a fresh cache and cycle-detection set.
 * @param input - the root value/structure.
 * @param mapFn - returns {value, recurse} for each node.
 * @returns the mapped structure.
 */
function deepMap(input, mapFn) {
  var mapped = deepMapInternal(input, mapFn);
  return mapped;
}
/**
 * Recursive worker for deepMap.
 *
 * `mapFn(node)` must return {value, recurse}. When `recurse` is falsy, `value`
 * replaces the node and is cached in `seen`. When `recurse` is true, `value`
 * must be null and the node's enumerable keys are mapped recursively into a
 * new array (for arrays) or a plain object.
 *
 * @param input - current node; null/undefined maps to null.
 * @param mapFn - node mapping function.
 * @param seen - Map cache of already-mapped non-recursed nodes (created when undefined).
 * @param containedIn - Set of containers on the current path, used to detect cycles.
 * @throws {Error} on circular references, on {value !== null, recurse: true},
 *   or when asked to recurse into a non-iterable node.
 */
function deepMapInternal(input, mapFn, seen, containedIn) {
  var cache = seen === void 0 ? new Map() : seen;
  var ancestors = containedIn === void 0 ? new Set() : containedIn;
  if (input == null) {
    return null;
  }
  if (ancestors.has(input)) {
    throw new Error("Circular references are not supported.");
  }
  if (cache.has(input)) {
    return cache.get(input);
  }
  var outcome = mapFn(input);
  if (outcome.recurse && outcome.value !== null) {
    throw new Error("A deep map function may not return both a value and recurse=true.");
  }
  if (!outcome.recurse) {
    cache.set(input, outcome.value);
    return outcome.value;
  }
  if (!isIterable(input)) {
    throw new Error("Can't recurse into non-iterable type: " + input);
  }
  var output = Array.isArray(input) ? [] : {};
  // Track the container while its children are visited so a child that points
  // back at it is reported as a cycle.
  ancestors.add(input);
  for (var key in input) {
    output[key] = deepMapInternal(input[key], mapFn, cache, ancestors);
  }
  ancestors.delete(input);
  return output;
}
/**
 * Zips a list of parallel nested structures into one structure (see
 * deepZipInternal).
 * @param inputs - array of structures with matching shape.
 * @param zipFn - returns {value, recurse} for a list of parallel nodes;
 *   defaults to zipToList when undefined.
 */
function deepZip(inputs, zipFn) {
  var zipper = zipFn === void 0 ? zipToList : zipFn;
  return deepZipInternal(inputs, zipper);
}
/**
 * Recursive worker for deepZip. `inputs[0]` determines the shape: when
 * `zipFn(inputs)` asks to recurse, each enumerable key of `inputs[0]` is zipped
 * from the corresponding entries of all inputs.
 * @param inputs - parallel nodes from each structure.
 * @param zipFn - returns {value, recurse}; value must be null when recursing.
 * @param containedIn - Set of containers on the current path (cycle detection).
 * @throws {Error} on circular references, on {value !== null, recurse: true},
 *   or when asked to recurse into a non-iterable node.
 */
function deepZipInternal(inputs, zipFn, containedIn) {
  var ancestors = containedIn === void 0 ? new Set() : containedIn;
  var head = inputs[0];
  if (ancestors.has(head)) {
    throw new Error("Circular references are not supported.");
  }
  var outcome = zipFn(inputs);
  if (outcome.recurse && outcome.value !== null) {
    throw new Error("A deep zip function may not return both a value and recurse=true.");
  }
  if (!outcome.recurse) {
    return outcome.value;
  }
  if (!isIterable(head)) {
    throw new Error("Can't recurse into non-iterable type: " + head);
  }
  var zipped = Array.isArray(head) ? [] : {};
  ancestors.add(head);
  for (var key in head) {
    // The map runs synchronously, so reading `key` inside the callback is safe.
    var column = inputs.map(function(entry) {
      return entry[key];
    });
    zipped[key] = deepZipInternal(column, zipFn, ancestors);
  }
  ancestors.delete(head);
  return zipped;
}
/**
 * Default deepZip function: recurses while the first element is iterable,
 * otherwise returns the list of parallel values as the zipped leaf.
 * @param x - list of parallel nodes (null yields null).
 */
function zipToList(x) {
  if (x === null) {
    return null;
  }
  return isIterable(x[0]) ? {value: null, recurse: true} : {value: x, recurse: false};
}
/**
 * Like deepMap, but `mapFn` may produce promises as leaf values; the returned
 * promise resolves to the structure with every such promise replaced by its
 * resolved value.
 *
 * Works in two passes sharing one `seen` cache:
 *   1. deepMapInternal populates `seen` (leaf -> mapped value, possibly a promise);
 *   2. each cached promise is awaited in turn (sequentially) and written back;
 *   3. deepMapInternal runs again. Cached leaves are returned from `seen`
 *      without calling mapFn, while containers (never cached) are re-mapped.
 * The body is down-leveled async code; it appears to follow tslib's
 * `__generator` protocol ([4, p] awaits p, [3, n] jumps to case n, [2, v] returns v).
 */
function deepMapAndAwaitAll(input, mapFn) {
return __awaiter(this, void 0, void 0, function() {
var seen, _i, _a, key, value, mappedValue, result;
return __generator(this, function(_b) {
switch (_b.label) {
case 0:
// First pass: fill the cache.
seen = new Map();
deepMapInternal(input, mapFn, seen);
_i = 0, _a = Array.from(seen.keys());
_b.label = 1;
case 1:
// Loop over cached keys; non-promise values are skipped.
if (!(_i < _a.length))
return [3, 4];
key = _a[_i];
value = seen.get(key);
if (!tf.util.isPromise(value))
return [3, 3];
return [4, value];
case 2:
mappedValue = _b.sent();
seen.set(key, mappedValue);
_b.label = 3;
case 3:
_i++;
return [3, 1];
case 4:
// Second pass: rebuild the structure from the resolved cache.
result = deepMapInternal(input, mapFn, seen);
return [2, result];
}
});
});
}
/**
 * True for values deep-map/zip should recurse into: arrays and plain objects,
 * excluding null/undefined, ArrayBuffer views (typed arrays, DataView) and
 * tf.Tensor instances.
 */
function isIterable(obj) {
  if (obj == null || ArrayBuffer.isView(obj)) {
    return false;
  }
  if (Array.isArray(obj)) {
    return true;
  }
  return typeof obj === "object" && !(obj instanceof tf.Tensor);
}
/**
 * Whether `obj` is something that can be converted to a tensor:
 * null/undefined, a primitive, an array, a tf.Tensor, or a typed array
 * (as reported by tf.util.isTypedArray).
 */
function canTensorify(obj) {
  if (obj == null || isPrimitive(obj) || Array.isArray(obj)) {
    return true;
  }
  if (typeof obj === "object" && obj instanceof tf.Tensor) {
    return true;
  }
  return tf.util.isTypedArray(obj);
}
/** True for null and for any value that is neither an object nor a function. */
function isPrimitive(value) {
  if (value === null) {
    return true;
  }
  var kind = typeof value;
  return kind !== "object" && kind !== "function";
}
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
* =============================================================================
*/
/**
 * Returns a structural copy of `container` in which every tf.Tensor is
 * replaced by a clone (see cloneIfTensor); other leaves are kept as-is.
 */
function deepClone(container) {
  var cloned = deepMap(container, cloneIfTensor);
  return cloned;
}
/**
 * deepMap callback: tensors are cloned, iterable containers are recursed into,
 * and any other value is passed through unchanged.
 */
function cloneIfTensor(item) {
  if (item instanceof tf.Tensor) {
    return {value: item.clone(), recurse: false};
  }
  return isIterable(item) ? {value: null, recurse: true} : {value: item, recurse: false};
}
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
* =============================================================================
*/
/**
 * Fixed-capacity circular buffer with push/pop at the end and shift/unshift
 * at the front.
 *
 * `begin` and `end` are logical positions kept in [0, 2 * capacity) by
 * wrap(). The storage slot for a position is `position % capacity`. Using
 * twice the capacity for positions lets a full buffer
 * (end - begin === capacity) be told apart from an empty one (end === begin).
 */
var RingBuffer = function() {
  /**
   * @param {number} capacity - maximum number of elements; must be >= 1.
   * @throws {RangeError} if capacity is null/undefined or < 1.
   */
  function RingBuffer2(capacity) {
    this.capacity = capacity;
    this.begin = 0;
    this.end = 0;
    if (capacity == null) {
      throw new RangeError("Can't create a ring buffer of unknown capacity.");
    }
    if (capacity < 1) {
      throw new RangeError("Can't create ring buffer of capacity < 1.");
    }
    this.data = new Array(capacity);
    this.doubledCapacity = 2 * capacity;
  }
  /** Normalizes a (possibly negative) logical position into [0, 2 * capacity). */
  RingBuffer2.prototype.wrap = function(index) {
    while (index < 0) {
      index += this.doubledCapacity;
    }
    return index % this.doubledCapacity;
  };
  /** Reads the storage slot for a non-negative logical position. */
  RingBuffer2.prototype.get = function(index) {
    if (index < 0) {
      throw new RangeError("Can't get item at a negative index.");
    }
    return this.data[index % this.capacity];
  };
  /** Writes the storage slot for a non-negative logical position. */
  RingBuffer2.prototype.set = function(index, value) {
    if (index < 0) {
      throw new RangeError("Can't set item at a negative index.");
    }
    this.data[index % this.capacity] = value;
  };
  /** Number of stored elements. */
  RingBuffer2.prototype.length = function() {
    var length = this.end - this.begin;
    if (length < 0) {
      length = this.doubledCapacity + length;
    }
    return length;
  };
  RingBuffer2.prototype.isFull = function() {
    return this.length() === this.capacity;
  };
  RingBuffer2.prototype.isEmpty = function() {
    return this.length() === 0;
  };
  /** Appends at the end. @throws {RangeError} when full. */
  RingBuffer2.prototype.push = function(value) {
    if (this.isFull()) {
      throw new RangeError("Ring buffer is full.");
    }
    this.set(this.end, value);
    this.end = this.wrap(this.end + 1);
  };
  /** Pushes every element of `values` in order. */
  RingBuffer2.prototype.pushAll = function(values) {
    for (var _i = 0, values_1 = values; _i < values_1.length; _i++) {
      var value = values_1[_i];
      this.push(value);
    }
  };
  /** Removes and returns the element at the end. @throws {RangeError} when empty. */
  RingBuffer2.prototype.pop = function() {
    if (this.isEmpty()) {
      throw new RangeError("Ring buffer is empty.");
    }
    this.end = this.wrap(this.end - 1);
    var result = this.get(this.end);
    // Clear the vacated slot so the buffer does not keep the element alive.
    this.set(this.end, void 0);
    return result;
  };
  /** Inserts at the front. @throws {RangeError} when full. */
  RingBuffer2.prototype.unshift = function(value) {
    if (this.isFull()) {
      throw new RangeError("Ring buffer is full.");
    }
    this.begin = this.wrap(this.begin - 1);
    this.set(this.begin, value);
  };
  /** Removes and returns the element at the front. @throws {RangeError} when empty. */
  RingBuffer2.prototype.shift = function() {
    if (this.isEmpty()) {
      throw new RangeError("Ring buffer is empty.");
    }
    var result = this.get(this.begin);
    this.set(this.begin, void 0);
    this.begin = this.wrap(this.begin + 1);
    return result;
  };
  /**
   * Removes and returns the element at `relativeIndex` (0 = the element shift()
   * would return) in O(1) by moving the end element into its slot. The order
   * of the remaining elements is not preserved.
   * @param {number} relativeIndex - offset from `begin`, in [0, length()).
   * @throws {RangeError} if the buffer is empty or the index is out of range.
   */
  RingBuffer2.prototype.shuffleExcise = function(relativeIndex) {
    if (this.isEmpty()) {
      throw new RangeError("Ring buffer is empty.");
    }
    if (relativeIndex < 0 || relativeIndex >= this.length()) {
      throw new RangeError("Can't excise item at relative index " + relativeIndex + " of a ring buffer of length " + this.length() + ".");
    }
    var index = this.wrap(this.begin + relativeIndex);
    var result = this.get(index);
    var last = this.pop();
    // When the excised element was itself the end element, pop() already
    // removed it. Writing it back would leave a stale entry outside [begin, end).
    if (index !== this.end) {
      this.set(index, last);
    }
    return result;
  };
  return RingBuffer2;
}();
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
* =============================================================================
*/
/**
 * RingBuffer subclass that never reports itself full. When the underlying
 * storage is full, push/unshift first double the capacity (starting from
 * INITIAL_CAPACITY) and re-pack the elements from storage index 0.
 */
var GrowingRingBuffer = function(Base) {
  __extends(GrowingRingBuffer2, Base);
  function GrowingRingBuffer2() {
    return Base.call(this, GrowingRingBuffer2.INITIAL_CAPACITY) || this;
  }
  GrowingRingBuffer2.INITIAL_CAPACITY = 32;
  var proto = GrowingRingBuffer2.prototype;
  // Always false: the buffer grows on demand instead of rejecting writes.
  proto.isFull = function() {
    return false;
  };
  // Grows first if the fixed-size storage is full, then appends at the end.
  proto.push = function(value) {
    if (Base.prototype.isFull.call(this)) {
      this.expand();
    }
    Base.prototype.push.call(this, value);
  };
  // Grows first if the fixed-size storage is full, then inserts at the front.
  proto.unshift = function(value) {
    if (Base.prototype.isFull.call(this)) {
      this.expand();
    }
    Base.prototype.unshift.call(this, value);
  };
  // Doubles the capacity, copying the live elements (front first) into
  // indices 0..length-1 of fresh storage.
  proto.expand = function() {
    var liveCount = this.length();
    var grown = new Array(this.capacity * 2);
    for (var offset = 0; offset < liveCount; offset++) {
      grown[offset] = this.get(this.wrap(this.begin + offset));
    }
    this.capacity = grown.length;
    this.doubledCapacity = 2 * this.capacity;
    this.data = grown;
    this.begin = 0;
    this.end = liveCount;
  };
  return GrowingRingBuffer2;
}(RingBuffer);
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
* =============================================================================
*/
/** Wraps an in-memory array in a lazy iterator (ArrayIterator). */
function iteratorFromItems(items) {
  var iterator = new ArrayIterator(items);
  return iterator;
}
/** Creates an iterator whose elements come from repeated calls to `func2`. */
function iteratorFromFunction(func2) {
  var iterator = new FunctionCallIterator(func2);
  return iterator;
}
/** Concatenates the iterators produced by `baseIterators` (ChainedIterator). */
function iteratorFromConcatenated(baseIterators, baseErrorHandler) {
  var iterator = new ChainedIterator(baseIterators, baseErrorHandler);
  return iterator;
}
/**
 * Zips a (possibly nested) structure of iterators element by element.
 * @param mismatchMode - behaviour when the iterators end at different times;
 *   ZipMismatchMode.FAIL when undefined.
 */
function iteratorFromZipped(iterators, mismatchMode) {
  var mode = mismatchMode === void 0 ? ZipMismatchMode.FAIL : mismatchMode;
  return new ZipIterator(iterators, mode);
}
/*
 * Abstract base for the lazy, asynchronous dataset iterators below.
 * Subclasses supply `next()` (resolving to {value, done}) and `summary()`;
 * neither is defined here. The async methods are down-leveled TypeScript that
 * appears to follow tslib's `__generator` protocol: `[4, p]` awaits p (the
 * result is read with `_a.sent()` in the next case), `[3, n]` jumps to
 * `case n`, and `[2, v]` returns v.
 */
var LazyIterator = function() {
function LazyIterator2() {
}
// Drains the iterator and resolves to an array of every remaining value.
LazyIterator2.prototype.toArray = function() {
return __awaiter(this, void 0, void 0, function() {
var result, x;
return __generator(this, function(_a) {
switch (_a.label) {
case 0:
result = [];
return [4, this.next()];
case 1:
x = _a.sent();
_a.label = 2;
case 2:
if (!!x.done)
return [3, 4];
result.push(x.value);
return [4, this.next()];
case 3:
x = _a.sent();
return [3, 2];
case 4:
return [2, result];
}
});
});
};
// Same as toArray, but reads through `prefetch(100)`; presumably intended for
// tests, per its name.
LazyIterator2.prototype.toArrayForTest = function() {
return __awaiter(this, void 0, void 0, function() {
var stream, result, x;
return __generator(this, function(_a) {
switch (_a.label) {
case 0:
stream = this.prefetch(100);
result = [];
return [4, stream.next()];
case 1:
x = _a.sent();
_a.label = 2;
case 2:
if (!!x.done)
return [3, 4];
result.push(x.value);
return [4, stream.next()];
case 3:
x = _a.sent();
return [3, 2];
case 4:
return [2, result];
}
});
});
};
// Reads until done, discarding the values (used to drive side effects).
LazyIterator2.prototype.resolveFully = function() {
return __awaiter(this, void 0, void 0, function() {
var x;
return __generator(this, function(_a) {
switch (_a.label) {
case 0:
return [4, this.next()];
case 1:
x = _a.sent();
_a.label = 2;
case 2:
if (!!x.done)
return [3, 4];
return [4, this.next()];
case 3:
x = _a.sent();
return [3, 2];
case 4:
return [2];
}
});
});
};
// Reads while the iterator is not done and `predicate(value)` is truthy.
// Note: predicate is also called on the value of the final (done) result.
LazyIterator2.prototype.resolveWhile = function(predicate) {
return __awaiter(this, void 0, void 0, function() {
var x, shouldContinue;
return __generator(this, function(_a) {
switch (_a.label) {
case 0:
return [4, this.next()];
case 1:
x = _a.sent();
shouldContinue = predicate(x.value);
_a.label = 2;
case 2:
if (!(!x.done && shouldContinue))
return [3, 4];
return [4, this.next()];
case 3:
x = _a.sent();
shouldContinue = predicate(x.value);
return [3, 2];
case 4:
return [2];
}
});
});
};
// handler(error) -> truthy to skip the failed read and keep going, falsy to end the stream.
LazyIterator2.prototype.handleErrors = function(handler) {
return new ErrorHandlingLazyIterator(this, handler);
};
LazyIterator2.prototype.filter = function(predicate) {
return new FilterIterator(this, predicate);
};
// Synchronous element-wise transform.
LazyIterator2.prototype.map = function(transform) {
return new MapIterator(this, transform);
};
// Element-wise transform that may return a promise.
LazyIterator2.prototype.mapAsync = function(transform) {
return new AsyncMapIterator(this, transform);
};
// mapAsync whose reads are serialized (one at a time).
LazyIterator2.prototype.serialMapAsync = function(transform) {
return new AsyncMapIterator(this, transform).serial();
};
// transform returns an array per element; its entries are emitted individually.
LazyIterator2.prototype.flatmap = function(transform) {
return new FlatmapIterator(this, transform);
};
// Applies f to every element; f's return values are discarded.
LazyIterator2.prototype.forEachAsync = function(f) {
return __awaiter(this, void 0, void 0, function() {
return __generator(this, function(_a) {
return [2, this.map(f).resolveFully()];
});
});
};
// Applies f serially, stopping as soon as f returns anything other than `true`.
LazyIterator2.prototype.serialForEach = function(f) {
return __awaiter(this, void 0, void 0, function() {
return __generator(this, function(_a) {
return [2, this.serialMapAsync(f).resolveWhile(function(x) {
return x === true;
})];
});
});
};
// Groups elements into arrays of `batchSize`; a short final batch is emitted
// unless smallLastBatch is false.
LazyIterator2.prototype.rowMajorBatch = function(batchSize, smallLastBatch) {
if (smallLastBatch === void 0) {
smallLastBatch = true;
}
return new RowMajorBatchIterator(this, batchSize, smallLastBatch);
};
// Row batches transposed with deepZip (zipToList by default), so each batch
// has the element structure with leaf lists across the batch.
LazyIterator2.prototype.columnMajorBatch = function(batchSize, smallLastBatch, zipFn) {
if (smallLastBatch === void 0) {
smallLastBatch = true;
}
if (zipFn === void 0) {
zipFn = zipToList;
}
var rowBatches = this.rowMajorBatch(batchSize, smallLastBatch);
return rowBatches.map(function(x) {
return deepZip(x, zipFn);
});
};
// Emits this iterator's elements followed by `iterator`'s.
LazyIterator2.prototype.concatenate = function(iterator, baseErrorHandler) {
return new ChainedIterator(iteratorFromItems([this, iterator]), baseErrorHandler);
};
// A negative or null/undefined count returns this iterator unchanged.
LazyIterator2.prototype.take = function(count) {
if (count < 0 || count == null) {
return this;
}
return new TakeIterator(this, count);
};
// A negative or null/undefined count returns this iterator unchanged.
LazyIterator2.prototype.skip = function(count) {
if (count < 0 || count == null) {
return this;
}
return new SkipIterator(this, count);
};
LazyIterator2.prototype.prefetch = function(bufferSize) {
return new PrefetchIterator(this, bufferSize);
};
LazyIterator2.prototype.shuffle = function(windowSize, seed) {
return new ShuffleIterator(this, windowSize, seed);
};
LazyIterator2.prototype.serial = function() {
return new SerialIterator(this);
};
return LazyIterator2;
}();
/**
 * Lazy iterator over an in-memory array. Each element is passed through
 * deepClone before it is returned, so consumers get their own copies of any
 * tensors.
 */
var ArrayIterator = function(_super) {
  __extends(ArrayIterator2, _super);
  function ArrayIterator2(items) {
    var self = _super.call(this) || this;
    self.items = items;
    // Index of the next element to hand out.
    self.trav = 0;
    return self;
  }
  var proto = ArrayIterator2.prototype;
  proto.summary = function() {
    return "Array of " + this.items.length + " items";
  };
  proto.next = function() {
    return __awaiter(this, void 0, void 0, function() {
      return __generator(this, function() {
        if (this.trav < this.items.length) {
          var current = this.items[this.trav++];
          return [2, {value: deepClone(current), done: false}];
        }
        return [2, {value: null, done: true}];
      });
    });
  };
  return ArrayIterator2;
}(LazyIterator);
/**
 * Lazy iterator whose elements come from repeatedly calling `nextFn`.
 */
var FunctionCallIterator = function(_super) {
  __extends(FunctionCallIterator2, _super);
  /**
   * @param {Function} nextFn - returns an iterator result ({value, done}) or a
   *   promise of one.
   */
  function FunctionCallIterator2(nextFn) {
    var _this = _super.call(this) || this;
    _this.nextFn = nextFn;
    return _this;
  }
  FunctionCallIterator2.prototype.summary = function() {
    return "Function call";
  };
  /**
   * Calls `nextFn` synchronously (as before) and resolves with its result.
   *
   * Failures are prefixed with "Error thrown while iterating through a
   * dataset: " whether `nextFn` throws synchronously or returns a promise that
   * rejects. Previously the promise was returned without being awaited inside
   * the `try`, so async rejections bypassed the annotation.
   * @returns {Promise} resolving to nextFn's result.
   */
  FunctionCallIterator2.prototype.next = function() {
    var _this = this;
    // The executor runs synchronously, so nextFn is invoked immediately. Both a
    // thrown error and a returned rejected promise reject this promise.
    return new Promise(function(resolve) {
      resolve(_this.nextFn());
    }).catch(function(e) {
      // Only objects can carry a message; primitives are rethrown unchanged.
      if (e !== null && typeof e === "object") {
        e.message = "Error thrown while iterating through a dataset: " + e.message;
      }
      throw e;
    });
  };
  return FunctionCallIterator2;
}(LazyIterator);
/**
 * Serializes reads of `upstream`. Each upstream.next() call is chained after
 * the previous read has resolved. Because the chain uses `.then` with only a
 * success handler, a rejected read makes every later read reject as well.
 */
var SerialIterator = function(_super) {
  __extends(SerialIterator2, _super);
  function SerialIterator2(upstream) {
    var self = _super.call(this) || this;
    self.upstream = upstream;
    self.lastRead = Promise.resolve({value: null, done: false});
    return self;
  }
  var proto = SerialIterator2.prototype;
  proto.summary = function() {
    return this.upstream.summary() + " -> Serial";
  };
  proto.next = function() {
    var self = this;
    return __awaiter(this, void 0, void 0, function() {
      return __generator(this, function() {
        var queued = self.lastRead.then(function() {
          return self.serialNext();
        });
        self.lastRead = queued;
        return [2, queued];
      });
    });
  };
  // Performs one upstream read; only ever invoked through the lastRead chain.
  proto.serialNext = function() {
    return __awaiter(this, void 0, void 0, function() {
      return __generator(this, function() {
        return [2, this.upstream.next()];
      });
    });
  };
  return SerialIterator2;
}(LazyIterator);
// Drops the first `maxCount` upstream elements, disposing any tensors in them
// via tf.dispose, and then passes the remaining elements through. Reads are
// serialized by chaining each one onto `lastRead`.
var SkipIterator = function(_super) {
__extends(SkipIterator2, _super);
function SkipIterator2(upstream, maxCount) {
var _this = _super.call(this) || this;
_this.upstream = upstream;
_this.maxCount = maxCount;
// Incremented on every skip-window check, including after the window has
// been passed.
_this.count = 0;
_this.lastRead = Promise.resolve({value: null, done: false});
return _this;
}
SkipIterator2.prototype.summary = function() {
return this.upstream.summary() + " -> Skip";
};
SkipIterator2.prototype.next = function() {
return __awaiter(this, void 0, void 0, function() {
var _this = this;
return __generator(this, function(_a) {
// Queue this read behind the previous one so reads never overlap.
this.lastRead = this.lastRead.then(function() {
return _this.serialNext();
});
return [2, this.lastRead];
});
});
};
SkipIterator2.prototype.serialNext = function() {
return __awaiter(this, void 0, void 0, function() {
var skipped;
return __generator(this, function(_a) {
switch (_a.label) {
case 0:
// Loop (case 0 -> 1 -> 0) while still inside the skip window.
if (!(this.count++ < this.maxCount))
return [3, 2];
return [4, this.upstream.next()];
case 1:
skipped = _a.sent();
// Upstream ended inside the skip window: propagate the end result.
if (skipped.done) {
return [2, skipped];
}
tf.dispose(skipped.value);
return [3, 0];
case 2:
// Past the skip window: forward upstream's result unchanged.
return [2, this.upstream.next()];
}
});
});
};
return SkipIterator2;
}(LazyIterator);
/**
 * Passes through at most `maxCount` upstream elements. After that it reports
 * done without reading upstream again.
 */
var TakeIterator = function(_super) {
  __extends(TakeIterator2, _super);
  function TakeIterator2(upstream, maxCount) {
    var self = _super.call(this) || this;
    self.upstream = upstream;
    self.maxCount = maxCount;
    // Number of next() calls made so far (incremented on every call).
    self.count = 0;
    return self;
  }
  var proto = TakeIterator2.prototype;
  proto.summary = function() {
    return this.upstream.summary() + " -> Take";
  };
  proto.next = function() {
    return __awaiter(this, void 0, void 0, function() {
      return __generator(this, function() {
        var taken = this.count++;
        if (taken >= this.maxCount) {
          return [2, {value: null, done: true}];
        }
        return [2, this.upstream.next()];
      });
    });
  };
  return TakeIterator2;
}(LazyIterator);
// Groups upstream elements into arrays of `batchSize`. When upstream ends
// mid-batch, the partial batch is emitted if enableSmallLastBatch is true
// (the default) and the batch is non-empty; otherwise the stream ends. Reads
// are serialized via the `lastRead` chain.
var RowMajorBatchIterator = function(_super) {
__extends(RowMajorBatchIterator2, _super);
function RowMajorBatchIterator2(upstream, batchSize, enableSmallLastBatch) {
if (enableSmallLastBatch === void 0) {
enableSmallLastBatch = true;
}
var _this = _super.call(this) || this;
_this.upstream = upstream;
_this.batchSize = batchSize;
_this.enableSmallLastBatch = enableSmallLastBatch;
_this.lastRead = Promise.resolve({value: null, done: false});
return _this;
}
RowMajorBatchIterator2.prototype.summary = function() {
return this.upstream.summary() + " -> RowMajorBatch";
};
RowMajorBatchIterator2.prototype.next = function() {
return __awaiter(this, void 0, void 0, function() {
var _this = this;
return __generator(this, function(_a) {
this.lastRead = this.lastRead.then(function() {
return _this.serialNext();
});
return [2, this.lastRead];
});
});
};
RowMajorBatchIterator2.prototype.serialNext = function() {
return __awaiter(this, void 0, void 0, function() {
var batch, item;
return __generator(this, function(_a) {
switch (_a.label) {
case 0:
batch = [];
_a.label = 1;
case 1:
// Keep reading (case 1 -> 2 -> 1) until the batch is full.
if (!(batch.length < this.batchSize))
return [3, 3];
return [4, this.upstream.next()];
case 2:
item = _a.sent();
// NOTE(review): after a partial batch is emitted, the next call reads
// upstream again; this relies on upstream continuing to report done.
if (item.done) {
if (this.enableSmallLastBatch && batch.length > 0) {
return [2, {value: batch, done: false}];
}
return [2, {value: null, done: true}];
}
batch.push(item.value);
return [3, 1];
case 3:
return [2, {value: batch, done: false}];
}
});
});
};
return RowMajorBatchIterator2;
}(LazyIterator);
// Emits only upstream elements for which `predicate(value)` is truthy.
// Rejected elements have their tensors disposed via tf.dispose. Reads are
// serialized via the `lastRead` chain.
var FilterIterator = function(_super) {
__extends(FilterIterator2, _super);
function FilterIterator2(upstream, predicate) {
var _this = _super.call(this) || this;
_this.upstream = upstream;
_this.predicate = predicate;
_this.lastRead = Promise.resolve({value: null, done: false});
return _this;
}
FilterIterator2.prototype.summary = function() {
return this.upstream.summary() + " -> Filter";
};
FilterIterator2.prototype.next = function() {
return __awaiter(this, void 0, void 0, function() {
var _this = this;
return __generator(this, function(_a) {
this.lastRead = this.lastRead.then(function() {
return _this.serialNext();
});
return [2, this.lastRead];
});
});
};
FilterIterator2.prototype.serialNext = function() {
return __awaiter(this, void 0, void 0, function() {
var item;
return __generator(this, function(_a) {
switch (_a.label) {
case 0:
return [4, this.upstream.next()];
case 1:
item = _a.sent();
// The end of the stream and accepted elements are returned as-is.
if (item.done || this.predicate(item.value)) {
return [2, item];
}
tf.dispose(item.value);
// Rejected: loop back to read the next upstream element.
return [3, 0];
case 2:
// Unreachable: nothing jumps to label 2 (generated by the downleveler).
return [2];
}
});
});
};
return FilterIterator2;
}(LazyIterator);
// Applies a synchronous `transform` to each upstream element. Tensors in the
// input that do not also appear in the output are disposed after the
// transform runs.
var MapIterator = function(_super) {
__extends(MapIterator2, _super);
function MapIterator2(upstream, transform) {
var _this = _super.call(this) || this;
_this.upstream = upstream;
_this.transform = transform;
return _this;
}
MapIterator2.prototype.summary = function() {
return this.upstream.summary() + " -> Map";
};
MapIterator2.prototype.next = function() {
return __awaiter(this, void 0, void 0, function() {
var item, inputTensors, mapped, outputTensors, _i, inputTensors_1, t;
return __generator(this, function(_a) {
switch (_a.label) {
case 0:
return [4, this.upstream.next()];
case 1:
item = _a.sent();
if (item.done) {
return [2, {value: null, done: true}];
}
inputTensors = tf.tensor_util.getTensorsInContainer(item.value);
mapped = this.transform(item.value);
outputTensors = tf.tensor_util.getTensorsInContainer(mapped);
// Free input tensors the transform did not pass through. If transform
// throws, this cleanup is skipped.
for (_i = 0, inputTensors_1 = inputTensors; _i < inputTensors_1.length; _i++) {
t = inputTensors_1[_i];
if (!tf.tensor_util.isTensorInList(t, outputTensors)) {
t.dispose();
}
}
return [2, {value: mapped, done: false}];
}
});
});
};
return MapIterator2;
}(LazyIterator);
// Wraps `upstream` so that a failed read is passed to `handler(error)`.
// A truthy return retries with the next upstream read; a falsy return ends
// the stream ({value: null, done: true}). Reads are serialized via `lastRead`.
var ErrorHandlingLazyIterator = function(_super) {
__extends(ErrorHandlingLazyIterator2, _super);
function ErrorHandlingLazyIterator2(upstream, handler) {
var _this = _super.call(this) || this;
_this.upstream = upstream;
_this.handler = handler;
// Set here but not read anywhere in this class.
_this.count = 0;
_this.lastRead = Promise.resolve({value: null, done: false});
return _this;
}
ErrorHandlingLazyIterator2.prototype.summary = function() {
return this.upstream.summary() + " -> handleErrors";
};
ErrorHandlingLazyIterator2.prototype.next = function() {
return __awaiter(this, void 0, void 0, function() {
var _this = this;
return __generator(this, function(_a) {
this.lastRead = this.lastRead.then(function() {
return _this.serialNext();
});
return [2, this.lastRead];
});
});
};
ErrorHandlingLazyIterator2.prototype.serialNext = function() {
return __awaiter(this, void 0, void 0, function() {
var e_1;
return __generator(this, function(_a) {
switch (_a.label) {
case 0:
_a.label = 1;
case 1:
// Appears to register a try region: try at 1, catch at 3, no finally,
// continue at 4 (tslib __generator trys entry).
_a.trys.push([1, 3, , 4]);
return [4, this.upstream.next()];
case 2:
return [2, _a.sent()];
case 3:
// Catch block: the caught error is read via _a.sent().
e_1 = _a.sent();
if (!this.handler(e_1)) {
return [2, {value: null, done: true}];
}
return [3, 4];
case 4:
// Handler chose to continue: loop back and try the next read.
return [3, 0];
case 5:
return [2];
}
});
});
};
return ErrorHandlingLazyIterator2;
}(LazyIterator);
// Like MapIterator, but `transform` may return a promise, which is awaited.
// Input tensors not present in the awaited output are disposed.
var AsyncMapIterator = function(_super) {
__extends(AsyncMapIterator2, _super);
function AsyncMapIterator2(upstream, transform) {
var _this = _super.call(this) || this;
_this.upstream = upstream;
_this.transform = transform;
return _this;
}
AsyncMapIterator2.prototype.summary = function() {
return this.upstream.summary() + " -> AsyncMap";
};
AsyncMapIterator2.prototype.next = function() {
return __awaiter(this, void 0, void 0, function() {
var item, inputTensors, mapped, outputTensors, _i, inputTensors_2, t;
return __generator(this, function(_a) {
switch (_a.label) {
case 0:
return [4, this.upstream.next()];
case 1:
item = _a.sent();
if (item.done) {
return [2, {value: null, done: true}];
}
// Collect input tensors before the transform can change the container.
inputTensors = tf.tensor_util.getTensorsInContainer(item.value);
return [4, this.transform(item.value)];
case 2:
mapped = _a.sent();
outputTensors = tf.tensor_util.getTensorsInContainer(mapped);
for (_i = 0, inputTensors_2 = inputTensors; _i < inputTensors_2.length; _i++) {
t = inputTensors_2[_i];
if (!tf.tensor_util.isTensorInList(t, outputTensors)) {
t.dispose();
}
}
return [2, {value: mapped, done: false}];
}
});
});
};
return AsyncMapIterator2;
}(LazyIterator);
// Base for iterators that can emit several outputs per upstream element.
// Subclasses implement `pump()` (not defined here). Judging from serialNext,
// pump resolves falsy when there is no more input and otherwise may enqueue
// zero or more values into `outputQueue`. `summary()` is also left to
// subclasses.
var OneToManyIterator = function(_super) {
__extends(OneToManyIterator2, _super);
function OneToManyIterator2() {
var _this = _super.call(this) || this;
_this.outputQueue = new GrowingRingBuffer();
_this.lastRead = Promise.resolve({value: null, done: false});
return _this;
}
OneToManyIterator2.prototype.next = function() {
return __awaiter(this, void 0, void 0, function() {
var _this = this;
return __generator(this, function(_a) {
this.lastRead = this.lastRead.then(function() {
return _this.serialNext();
});
return [2, this.lastRead];
});
});
};
OneToManyIterator2.prototype.serialNext = function() {
return __awaiter(this, void 0, void 0, function() {
return __generator(this, function(_a) {
switch (_a.label) {
case 0:
// Pump until at least one output is queued or the input is exhausted.
if (!(this.outputQueue.length() === 0))
return [3, 2];
return [4, this.pump()];
case 1:
if (!_a.sent()) {
return [2, {value: null, done: true}];
}
return [3, 0];
case 2:
// Outputs are emitted in FIFO order.
return [2, {value: this.outputQueue.shift(), done: false}];
}
});
});
};
return OneToManyIterator2;
}(LazyIterator);
// OneToManyIterator where `transform(value)` returns an array whose entries
// are emitted one at a time. Input tensors not present in the transform's
// output are disposed.
var FlatmapIterator = function(_super) {
__extends(FlatmapIterator2, _super);
function FlatmapIterator2(upstream, transform) {
var _this = _super.call(this) || this;
_this.upstream = upstream;
_this.transform = transform;
return _this;
}
FlatmapIterator2.prototype.summary = function() {
return this.upstream.summary() + " -> Flatmap";
};
// Reads one upstream element and enqueues its mapped entries. Resolves to
// false when upstream is done, otherwise true.
FlatmapIterator2.prototype.pump = function() {
return __awaiter(this, void 0, void 0, function() {
var item, inputTensors, mappedArray, outputTensors, _i, inputTensors_3, t;
return __generator(this, function(_a) {
switch (_a.label) {
case 0:
return [4, this.upstream.next()];
case 1:
item = _a.sent();
if (item.done) {
return [2, false];
}
inputTensors = tf.tensor_util.getTensorsInContainer(item.value);
mappedArray = this.transform(item.value);
outputTensors = tf.tensor_util.getTensorsInContainer(mappedArray);
this.outputQueue.pushAll(mappedArray);
for (_i = 0, inputTensors_3 = inputTensors; _i < inputTensors_3.length; _i++) {
t = inputTensors_3[_i];
if (!tf.tensor_util.isTensorInList(t, outputTensors)) {
t.dispose();
}
}
return [2, true];
}
});
});
};
return FlatmapIterator2;
}(OneToManyIterator);
// Concatenates a stream of iterators. `iterators` is itself a lazy iterator
// whose values are iterators; each one is drained before the next is
// fetched. When `baseErrorHandler` is given, each inner iterator is wrapped
// with handleErrors(baseErrorHandler).
var ChainedIterator = function(_super) {
__extends(ChainedIterator2, _super);
function ChainedIterator2(iterators, baseErrorHandler) {
var _this = _super.call(this) || this;
_this.baseErrorHandler = baseErrorHandler;
_this.lastRead = null;
// Inner iterator currently being drained (null between iterators).
_this.iterator = null;
_this.moreIterators = iterators;
return _this;
}
ChainedIterator2.prototype.summary = function() {
var upstreamSummaries = "TODO: fill in upstream of chained summaries";
return upstreamSummaries + " -> Chained";
};
// Each read waits for the previous one (passed in as `lastRead`).
ChainedIterator2.prototype.next = function() {
return __awaiter(this, void 0, void 0, function() {
return __generator(this, function(_a) {
this.lastRead = this.readFromChain(this.lastRead);
return [2, this.lastRead];
});
});
};
ChainedIterator2.prototype.readFromChain = function(lastRead) {
return __awaiter(this, void 0, void 0, function() {
var iteratorResult, itemResult;
return __generator(this, function(_a) {
switch (_a.label) {
case 0:
// Wait for the prior read; if it rejected, this read rejects too.
return [4, lastRead];
case 1:
_a.sent();
if (!(this.iterator == null))
return [3, 3];
return [4, this.moreIterators.next()];
case 2:
iteratorResult = _a.sent();
// No more inner iterators: the chain is finished.
if (iteratorResult.done) {
return [2, {value: null, done: true}];
}
this.iterator = iteratorResult.value;
if (this.baseErrorHandler != null) {
this.iterator = this.iterator.handleErrors(this.baseErrorHandler);
}
_a.label = 3;
case 3:
return [4, this.iterator.next()];
case 4:
itemResult = _a.sent();
// Current inner iterator is exhausted: drop it and read from the next.
if (itemResult.done) {
this.iterator = null;
return [2, this.readFromChain(lastRead)];
}
return [2, itemResult];
}
});
});
};
return ChainedIterator2;
}(LazyIterator);
/**
 * How ZipIterator reacts when some zipped iterators end before others:
 * FAIL (0) throws, SHORTEST (1) ends the zip, LONGEST (2) keeps going.
 * Bidirectional TypeScript-style enum: name -> number and number -> name.
 */
var ZipMismatchMode;
(function(modes) {
  ["FAIL", "SHORTEST", "LONGEST"].forEach(function(name, ordinal) {
    modes[name] = ordinal;
    modes[ordinal] = name;
  });
})(ZipMismatchMode || (ZipMismatchMode = {}));
// Zips a (possibly nested) structure of lazy iterators: each element is the
// same structure with every iterator replaced by its next value. Reads are
// serialized by chaining each nextState() onto `currentPromise`.
var ZipIterator = function(_super) {
__extends(ZipIterator2, _super);
function ZipIterator2(iterators, mismatchMode) {
if (mismatchMode === void 0) {
mismatchMode = ZipMismatchMode.FAIL;
}
var _this = _super.call(this) || this;
_this.iterators = iterators;
_this.mismatchMode = mismatchMode;
// Number of elements emitted so far (used in the mismatch error message).
_this.count = 0;
_this.currentPromise = null;
return _this;
}
ZipIterator2.prototype.summary = function() {
var upstreamSummaries = "TODO: fill in upstream of zip summaries";
return "{" + upstreamSummaries + "} -> Zip";
};
ZipIterator2.prototype.nextState = function(afterState) {
return __awaiter(this, void 0, void 0, function() {
// deepMap callback: for each LazyIterator leaf, starts a read and returns a
// promise of its value, counting iterators and finished iterators as the
// reads resolve. Any other node is recursed into.
function getNext(container) {
if (container instanceof LazyIterator) {
var result = container.next();
return {
value: result.then(function(x) {
numIterators++;
if (x.done) {
iteratorsDone++;
}
return x.value;
}),
recurse: false
};
} else {
return {value: null, recurse: true};
}
}
var numIterators, iteratorsDone, mapped;
return __generator(this, function(_a) {
switch (_a.label) {
case 0:
return [4, afterState];
case 1:
_a.sent();
numIterators = 0;
iteratorsDone = 0;
// deepMapAndAwaitAll awaits every read promise, so both counters are
// final when this resolves.
return [4, deepMapAndAwaitAll(this.iterators, getNext)];
case 2:
mapped = _a.sent();
// Every iterator finished: the zip is finished.
if (numIterators === iteratorsDone) {
return [2, {value: null, done: true}];
}
if (iteratorsDone > 0) {
switch (this.mismatchMode) {
case ZipMismatchMode.FAIL:
throw new Error("Zipped streams should have the same length. " + ("Mismatched at element " + this.count + "."));
case ZipMismatchMode.SHORTEST:
return [2, {value: null, done: true}];
case ZipMismatchMode.LONGEST:
// Falls through: emit the element. Finished iterators contribute
// their done-result value (null for the iterators in this file).
}
}
this.count++;
return [2, {value: mapped, done: false}];
}
});
});
};
ZipIterator2.prototype.next = function() {
return __awaiter(this, void 0, void 0, function() {
return __generator(this, function(_a) {
this.currentPromise = this.nextState(this.currentPromise);
return [2, this.currentPromise];
});
});
};
return ZipIterator2;
}(LazyIterator);
/**
 * Iterator that keeps up to `bufferSize` pending results of `upstream.next()` in a
 * RingBuffer. Each `next()` first tops the buffer up, then hands out the oldest entry.
 * Subclasses (ShuffleIterator) reuse `buffer` and `refill()`.
 */
var PrefetchIterator = function(_super) {
  __extends(PrefetchIterator2, _super);

  function PrefetchIterator2(upstream, bufferSize) {
    var self = _super.call(this) || this;
    self.upstream = upstream;
    self.bufferSize = bufferSize;
    self.buffer = new RingBuffer(bufferSize);
    return self;
  }

  var proto = PrefetchIterator2.prototype;

  proto.summary = function() {
    var upstreamSummary = this.upstream.summary();
    return upstreamSummary + " -> Prefetch";
  };

  // Requests upstream elements until the buffer has no free slot left.
  proto.refill = function() {
    for (; !this.buffer.isFull(); ) {
      this.buffer.push(this.upstream.next());
    }
  };

  // Returns the oldest buffered upstream result after topping the buffer up.
  proto.next = function() {
    this.refill();
    return this.buffer.shift();
  };

  return PrefetchIterator2;
}(LazyIterator);
/**
* PrefetchIterator that emits upstream elements in pseudo-random order.
* A window of `windowSize` pending upstream results is buffered, and each element is
* drawn from a random slot of that window. The RNG is `seedrandom_1` seeded with
* `seed`, or with the current time when no seed is given.
*/
var ShuffleIterator = function(_super) {
__extends(ShuffleIterator2, _super);
function ShuffleIterator2(upstream, windowSize, seed) {
var _this = _super.call(this, upstream, windowSize) || this;
_this.upstream = upstream;
_this.windowSize = windowSize;
_this.upstreamExhausted = false;
_this.random = seedrandom_1(seed || tf.util.now().toString());
// Seed of the read chain: each next() runs only after the previous one has settled.
_this.lastRead = Promise.resolve({value: null, done: false});
return _this;
}
// Serializes calls so concurrent next() callers cannot interleave buffer mutations.
ShuffleIterator2.prototype.next = function() {
return __awaiter(this, void 0, void 0, function() {
var _this = this;
return __generator(this, function(_a) {
this.lastRead = this.lastRead.then(function() {
return _this.serialNext();
});
return [2, this.lastRead];
});
});
};
// Uniform integer in [0, max) from the seeded RNG.
ShuffleIterator2.prototype.randomInt = function(max) {
return Math.floor(this.random() * max);
};
ShuffleIterator2.prototype.chooseIndex = function() {
return this.randomInt(this.buffer.length());
};
/**
* Draws one element from a random buffer slot. Done-results found in the buffer mark
* the upstream as exhausted and are discarded while draining continues. Resolves to
* {value: null, done: true} once the buffer is empty.
*/
ShuffleIterator2.prototype.serialNext = function() {
return __awaiter(this, void 0, void 0, function() {
var chosenIndex, result;
return __generator(this, function(_a) {
switch (_a.label) {
case 0:
if (!this.upstreamExhausted) {
this.refill();
}
_a.label = 1;
case 1:
// Loop head: leave the loop (case 3) once the buffer is empty.
if (!!this.buffer.isEmpty())
return [3, 3];
chosenIndex = this.chooseIndex();
// NOTE(review): shuffleExcise is a RingBuffer method defined elsewhere; presumably
// it removes the entry at chosenIndex. The entry is an upstream next() promise.
return [4, this.buffer.shuffleExcise(chosenIndex)];
case 2:
result = _a.sent();
if (result.done) {
this.upstreamExhausted = true;
} else {
// Replace the consumed slot before returning the element.
this.refill();
return [2, result];
}
// Discarded a done-result: go back to the loop head.
return [3, 1];
case 3:
return [2, {value: null, done: true}];
}
});
});
};
return ShuffleIterator2;
}(PrefetchIterator);
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
* =============================================================================
*/
/**
 * Base class for a lazily evaluated, re-iterable collection of elements.
 * Concrete datasets provide an async `iterator()` method (see `datasetFromIteratorFn`).
 * Every transformation below returns a new Dataset whose iterator wraps a fresh
 * iterator of `this`.
 *
 * `size` is the number of elements, `Infinity` for endless streams, or `null` when it
 * is unknown.
 */
var Dataset = function() {
  function Dataset2() {
    this.size = null;
  }

  /**
   * Groups consecutive elements into batches via `columnMajorBatch(..., deepBatchConcat)`.
   * @param batchSize positive number of elements per batch (asserted).
   * @param smallLastBatch when true (default), a trailing partial batch is emitted and
   *     the reported size rounds up; otherwise it rounds down.
   */
  Dataset2.prototype.batch = function(batchSize, smallLastBatch) {
    var _this = this;
    if (smallLastBatch === void 0) {
      smallLastBatch = true;
    }
    var base = this;
    tf.util.assert(batchSize > 0, function() {
      return "batchSize needs to be positive, but it is\n " + batchSize;
    });
    var size;
    if (this.size === Infinity || this.size == null) {
      size = this.size;
    } else if (smallLastBatch) {
      size = Math.ceil(this.size / batchSize);
    } else {
      size = Math.floor(this.size / batchSize);
    }
    return datasetFromIteratorFn(function() {
      return __awaiter(_this, void 0, void 0, function() {
        return __generator(this, function(_a) {
          switch (_a.label) {
            case 0:
              return [4, base.iterator()];
            case 1:
              return [2, _a.sent().columnMajorBatch(batchSize, smallLastBatch, deepBatchConcat)];
          }
        });
      });
    }, size);
  };

  /**
   * Elements of `this` followed by elements of `dataset`.
   * The size is Infinity if either is infinite, the sum if both are known, else null.
   */
  Dataset2.prototype.concatenate = function(dataset) {
    var _this = this;
    var base = this;
    var size;
    if (this.size === Infinity || dataset.size === Infinity) {
      size = Infinity;
    } else if (this.size != null && dataset.size != null) {
      size = this.size + dataset.size;
    } else {
      size = null;
    }
    return datasetFromIteratorFn(function() {
      return __awaiter(_this, void 0, void 0, function() {
        var _a, _b;
        return __generator(this, function(_c) {
          switch (_c.label) {
            case 0:
              return [4, base.iterator()];
            case 1:
              _b = (_a = _c.sent()).concatenate;
              return [4, dataset.iterator()];
            case 2:
              return [2, _b.apply(_a, [_c.sent()])];
          }
        });
      });
    }, size);
  };

  /**
   * Keeps the elements for which `predicate` (run inside tf.tidy) is truthy.
   * The result's size is unknown (null) unless the input is infinite.
   */
  Dataset2.prototype.filter = function(predicate) {
    var _this = this;
    var base = this;
    var size;
    if (this.size === Infinity) {
      size = Infinity;
    } else {
      size = null;
    }
    return datasetFromIteratorFn(function() {
      return __awaiter(_this, void 0, void 0, function() {
        return __generator(this, function(_a) {
          switch (_a.label) {
            case 0:
              return [4, base.iterator()];
            case 1:
              return [2, _a.sent().filter(function(x) {
                return tf.tidy(function() {
                  return predicate(x);
                });
              })];
          }
        });
      });
    }, size);
  };

  /** Applies `f` to every element via the iterator's `forEachAsync`. */
  Dataset2.prototype.forEachAsync = function(f) {
    return __awaiter(this, void 0, void 0, function() {
      return __generator(this, function(_a) {
        switch (_a.label) {
          case 0:
            return [4, this.iterator()];
          case 1:
            return [2, _a.sent().forEachAsync(f)];
        }
      });
    });
  };

  /** Synchronously transforms each element inside tf.tidy; the size is preserved. */
  Dataset2.prototype.map = function(transform) {
    var _this = this;
    var base = this;
    return datasetFromIteratorFn(function() {
      return __awaiter(_this, void 0, void 0, function() {
        return __generator(this, function(_a) {
          switch (_a.label) {
            case 0:
              return [4, base.iterator()];
            case 1:
              return [2, _a.sent().map(function(x) {
                return tf.tidy(function() {
                  return transform(x);
                });
              })];
          }
        });
      });
    }, this.size);
  };

  /** Asynchronously transforms each element (no tf.tidy wrapper); the size is preserved. */
  Dataset2.prototype.mapAsync = function(transform) {
    var _this = this;
    var base = this;
    return datasetFromIteratorFn(function() {
      return __awaiter(_this, void 0, void 0, function() {
        return __generator(this, function(_a) {
          switch (_a.label) {
            case 0:
              return [4, base.iterator()];
            case 1:
              return [2, _a.sent().mapAsync(transform)];
          }
        });
      });
    }, this.size);
  };

  /**
   * Reads up to `bufferSize` elements ahead of the consumer.
   * @throws RangeError when bufferSize is null/undefined.
   */
  Dataset2.prototype.prefetch = function(bufferSize) {
    var _this = this;
    if (bufferSize == null) {
      throw new RangeError("`Dataset.prefetch()` requires bufferSize to be specified.");
    }
    var base = this;
    return datasetFromIteratorFn(function() {
      return __awaiter(_this, void 0, void 0, function() {
        return __generator(this, function(_a) {
          switch (_a.label) {
            case 0:
              return [4, base.iterator()];
            case 1:
              return [2, _a.sent().prefetch(bufferSize)];
          }
        });
      });
    }, this.size);
  };

  /**
   * Iterates `this` `count` times by concatenating fresh base iterators.
   * With an undefined or negative count on a known-size dataset, the size is reported
   * as Infinity; count 0 gives size 0.
   */
  Dataset2.prototype.repeat = function(count) {
    var _this = this;
    var base = this;
    var size;
    if (this.size != null && count > 0) {
      size = this.size * count;
    } else if (count === 0) {
      size = 0;
    } else if (this.size != null && (count === void 0 || count < 0)) {
      size = Infinity;
    } else {
      size = null;
    }
    return datasetFromIteratorFn(function() {
      return __awaiter(_this, void 0, void 0, function() {
        var iteratorIterator;
        var _this2 = this;
        return __generator(this, function(_a) {
          iteratorIterator = iteratorFromFunction(function() {
            return __awaiter(_this2, void 0, void 0, function() {
              var _a2;
              return __generator(this, function(_b) {
                switch (_b.label) {
                  case 0:
                    _a2 = {};
                    return [4, base.iterator()];
                  case 1:
                    return [2, (_a2.value = _b.sent(), _a2.done = false, _a2)];
                }
              });
            });
          });
          return [2, iteratorFromConcatenated(iteratorIterator.take(count))];
        });
      });
    }, size);
  };

  /** Drops the first `count` elements; the size is clamped at 0 when known. */
  Dataset2.prototype.skip = function(count) {
    var _this = this;
    var base = this;
    var size;
    if (this.size != null && count >= 0 && this.size >= count) {
      size = this.size - count;
    } else if (this.size != null && (this.size < count || count === void 0 || count < 0)) {
      size = 0;
    } else {
      size = null;
    }
    return datasetFromIteratorFn(function() {
      return __awaiter(_this, void 0, void 0, function() {
        return __generator(this, function(_a) {
          switch (_a.label) {
            case 0:
              return [4, base.iterator()];
            case 1:
              return [2, _a.sent().skip(count)];
          }
        });
      });
    }, size);
  };

  /**
   * Randomly shuffles elements using a sliding window of `bufferSize` elements.
   * @param bufferSize window size; required and must be >= 0 (RangeError otherwise).
   * @param seed optional RNG seed string (defaults to the current time).
   * @param reshuffleEachIteration when true (default), each new iterator gets a new
   *     order. When false, every iterator replays the same order.
   */
  Dataset2.prototype.shuffle = function(bufferSize, seed, reshuffleEachIteration) {
    var _this = this;
    if (reshuffleEachIteration === void 0) {
      reshuffleEachIteration = true;
    }
    if (bufferSize == null || bufferSize < 0) {
      if (this.size == null) {
        throw new RangeError("`Dataset.shuffle()` requires bufferSize to be specified.");
      } else {
        throw new RangeError("`Dataset.shuffle()` requires bufferSize to be specified. If your data fits in main memory (for regular JS objects), and/or GPU memory (for `tf.Tensor`s), consider setting " + ("bufferSize to the dataset size (" + this.size + " elements)"));
      }
    }
    var base = this;
    var random = seedrandom_1(seed || tf.util.now().toString());
    // `random` is stateful. Drawing from it inside the iterator factory would give every
    // iteration a different seed even when reshuffling is disabled, so the non-reshuffle
    // seed is drawn exactly once, here.
    var fixedSeed = reshuffleEachIteration ? null : random.int32();
    return datasetFromIteratorFn(function() {
      return __awaiter(_this, void 0, void 0, function() {
        var seed2;
        return __generator(this, function(_a) {
          switch (_a.label) {
            case 0:
              if (reshuffleEachIteration) {
                seed2 = random.int32();
                seed2 += random.int32();
              } else {
                seed2 = fixedSeed;
              }
              return [4, base.iterator()];
            case 1:
              return [2, _a.sent().shuffle(bufferSize, seed2.toString())];
          }
        });
      });
    }, this.size);
  };

  /**
   * Keeps at most the first `count` elements.
   * tf.data documents an undefined or negative `count` as "take every element", so the
   * reported size is then the input size rather than `count`.
   */
  Dataset2.prototype.take = function(count) {
    var _this = this;
    var base = this;
    var size;
    if (this.size == null) {
      size = null;
    } else if (count == null || count < 0) {
      size = this.size;
    } else if (this.size > count) {
      size = count;
    } else if (this.size <= count) {
      size = this.size;
    } else {
      // Non-comparable count (e.g. NaN): size cannot be determined.
      size = null;
    }
    return datasetFromIteratorFn(function() {
      return __awaiter(_this, void 0, void 0, function() {
        return __generator(this, function(_a) {
          switch (_a.label) {
            case 0:
              return [4, base.iterator()];
            case 1:
              return [2, _a.sent().take(count)];
          }
        });
      });
    }, size);
  };

  /**
   * Collects every element into an array.
   * @throws Error for infinite datasets.
   */
  Dataset2.prototype.toArray = function() {
    return __awaiter(this, void 0, void 0, function() {
      return __generator(this, function(_a) {
        switch (_a.label) {
          case 0:
            if (this.size === Infinity) {
              throw new Error("Can not convert infinite data stream to array.");
            }
            return [4, this.iterator()];
          case 1:
            return [2, _a.sent().toArray()];
        }
      });
    });
  };

  /** Test-only variant of toArray delegating to the iterator's toArrayForTest. */
  Dataset2.prototype.toArrayForTest = function() {
    return __awaiter(this, void 0, void 0, function() {
      return __generator(this, function(_a) {
        switch (_a.label) {
          case 0:
            if (this.size === Infinity) {
              throw new Error("Can not convert infinite data stream to array.");
            }
            return [4, this.iterator()];
          case 1:
            return [2, _a.sent().toArrayForTest()];
        }
      });
    });
  };

  Dataset2.MAX_BUFFER_SIZE = 1e4;
  return Dataset2;
}();
/**
 * Wraps an iterator factory in an instance of an anonymous Dataset subclass.
 * @param iteratorFn zero-argument function whose return value becomes the result of
 *     `iterator()` (awaited by callers).
 * @param size element count of the dataset; `undefined` is normalized to `null` (unknown).
 * @returns a Dataset whose `iterator()` resolves to `iteratorFn()`.
 */
function datasetFromIteratorFn(iteratorFn, size) {
  var knownSize = size === void 0 ? null : size;
  var IteratorFnDataset = function(_super) {
    __extends(class_1, _super);
    function class_1() {
      var self = (_super !== null && _super.apply(this, arguments)) || this;
      self.size = knownSize;
      return self;
    }
    class_1.prototype.iterator = function() {
      return __awaiter(this, void 0, void 0, function() {
        return __generator(this, function(_a) {
          return [2, iteratorFn()];
        });
      });
    };
    return class_1;
  }(Dataset);
  return new IteratorFnDataset();
}
/**
 * Creates a Dataset over an in-memory array.
 * @param items the elements; each iteration builds a fresh iteratorFromItems(items).
 * @returns a Dataset whose size is `items.length`.
 */
function array(items) {
  var self = this;
  var makeIterator = function() {
    return __awaiter(self, void 0, void 0, function() {
      return __generator(this, function(_a) {
        return [2, iteratorFromItems(items)];
      });
    });
  };
  var itemCount = items.length;
  return datasetFromIteratorFn(makeIterator, itemCount);
}
/**
 * Zips an array or object of Datasets (nesting allowed) into one Dataset.
 * Each element mirrors the structure of `datasets`, and iteration stops with the
 * shortest member (ZipMismatchMode.SHORTEST).
 * @param datasets array/object whose leaves are Datasets.
 * @returns the zipped Dataset. Its size is the minimum top-level member size, or null
 *     when any member's size is unknown.
 * @throws Error if `datasets` is not an object or array. Leaves that are neither
 *     Datasets nor containers are rejected lazily, when an iterator is created.
 */
function zip(datasets) {
  var _this = this;
  if (!isIterable(datasets)) {
    throw new Error("The argument to zip() must be an object or array.");
  }
  // A member with an unknown size (null), or a nested container (which has no `size`),
  // makes the zipped size unknown as well. Feeding such values to Math.min would
  // silently yield 0 or NaN, and the outcome would depend on member order.
  var size;
  var sizeUnknown = false;
  var includeMemberSize = function(member) {
    var memberSize = member.size;
    if (memberSize == null) {
      sizeUnknown = true;
    } else {
      size = size == null ? memberSize : Math.min(size, memberSize);
    }
  };
  if (Array.isArray(datasets)) {
    for (var i = 0; i < datasets.length; i++) {
      includeMemberSize(datasets[i]);
    }
  } else if (datasets instanceof Object) {
    for (var ds in datasets) {
      includeMemberSize(datasets[ds]);
    }
  }
  if (sizeUnknown) {
    size = null;
  }
  return datasetFromIteratorFn(function() {
    return __awaiter(_this, void 0, void 0, function() {
      var streams;
      return __generator(this, function(_a) {
        switch (_a.label) {
          case 0:
            return [4, deepMapAndAwaitAll(datasets, function(d) {
              if (d instanceof Dataset) {
                return {value: d.iterator(), recurse: false};
              } else if (isIterable(d)) {
                return {value: null, recurse: true};
              } else {
                throw new Error("Leaves of the structure passed to zip() must be Datasets, not primitives.");
              }
            })];
          case 1:
            streams = _a.sent();
            return [2, iteratorFromZipped(streams, ZipMismatchMode.SHORTEST)];
        }
      });
    });
  }, size);
}
/**
 * Leaf callback used when batching nested elements.
 * @param rows the values collected at one position of the element structure, or null.
 * @returns null for null input. Otherwise, if the first row can be tensorified,
 *     {value: batchConcat(rows), recurse: false}; else {value: null, recurse: true} so
 *     the caller descends further.
 */
function deepBatchConcat(rows) {
  if (rows === null) return null;
  if (!canTensorify(rows[0])) {
    return {value: null, recurse: true};
  }
  return {value: batchConcat(rows), recurse: false};
}
/**
 * Combines a non-empty list of rows into one batch.
 * If the first row is a tf.Tensor, the rows are combined with tf.stack; otherwise they
 * are combined with tf.tensor.
 * @throws Error when `arrays` is empty.
 */
function batchConcat(arrays) {
  var rowCount = arrays.length;
  if (rowCount === 0) throw new Error("Can't make a batch of zero elements.");
  var firstRow = arrays[0];
  return firstRow instanceof tf.Tensor ? tf.stack(arrays) : tf.tensor(arrays);
}
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
* =============================================================================
*/
/**
 * Dataset of text lines.
 * The input's iterator is decoded with `decodeUTF8()`, split on "\n", and each line has
 * one trailing "\r" removed so CRLF input yields the same lines as LF input.
 */
var TextLineDataset = function(_super) {
  __extends(TextLineDataset2, _super);

  function TextLineDataset2(input) {
    var self = _super.call(this) || this;
    self.input = input;
    return self;
  }

  // Removes a single trailing carriage return, if present.
  function stripTrailingCR(line) {
    return line.endsWith("\r") ? line.slice(0, -1) : line;
  }

  TextLineDataset2.prototype.iterator = function() {
    return __awaiter(this, void 0, void 0, function() {
      var chunks;
      return __generator(this, function(_a) {
        if (_a.label === 0) {
          return [4, this.input.iterator()];
        }
        chunks = _a.sent();
        var lines = chunks.decodeUTF8().split("\n");
        return [2, lines.map(function(line) {
          return stripTrailingCR(line);
        })];
      });
    });
  };

  return TextLineDataset2;
}(Dataset);
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
* =============================================================================
*/
// Character that opens and closes a quoted CSV field.
var CODE_QUOTE = "\"";
// States of the CSVDataset.parseRow state machine:
//   out              - between fields (line start or right after a delimiter)
//   field            - inside an unquoted field
//   quote            - inside a quoted field
//   quoteafterquote  - a quote was just read while inside a quoted field
//   quoteinquote     - non-delimiter text followed that quote; waiting for the next quote
var STATE_OUT = Symbol("out");
var STATE_FIELD = Symbol("field");
var STATE_QUOTE = Symbol("quote");
var STATE_QUOTE_AFTER_QUOTE = Symbol("quoteafterquote");
var STATE_WITHIN_QUOTE_IN_QUOTE = Symbol("quoteinquote");
/**
 * Dataset of parsed CSV rows, read line by line through a TextLineDataset over `input`.
 * Rows become plain objects keyed by column name. When any column is configured with
 * `isLabel`, each row is instead {xs: features, ys: labels}.
 */
var CSVDataset = function(_super) {
  __extends(CSVDataset2, _super);

  /**
   * @param input source passed to `new TextLineDataset(input)`.
   * @param csvConfig optional settings:
   *     hasHeader (default true), columnNames, columnConfigs, configuredColumnsOnly,
   *     delimiter (default ","), delimWhitespace (treats runs of " " as one delimiter;
   *     incompatible with `delimiter`, which is asserted).
   */
  function CSVDataset2(input, csvConfig) {
    var _this = _super.call(this) || this;
    _this.input = input;
    _this.hasHeader = true;
    _this.fullColumnNames = null;
    _this.columnNamesValidated = false;
    _this.columnConfigs = null;
    _this.configuredColumnsOnly = false;
    _this.delimiter = ",";
    _this.delimWhitespace = false;
    _this.base = new TextLineDataset(input);
    if (!csvConfig) {
      csvConfig = {};
    }
    _this.hasHeader = csvConfig.hasHeader === false ? false : true;
    _this.fullColumnNames = csvConfig.columnNames;
    _this.columnConfigs = csvConfig.columnConfigs;
    _this.configuredColumnsOnly = csvConfig.configuredColumnsOnly;
    if (csvConfig.delimWhitespace) {
      tf.util.assert(csvConfig.delimiter == null, function() {
        return "Delimiter should not be provided when delimWhitespace is true.";
      });
      _this.delimWhitespace = true;
      _this.delimiter = " ";
    } else {
      _this.delimiter = csvConfig.delimiter ? csvConfig.delimiter : ",";
    }
    return _this;
  }

  /**
   * Resolves to the column names, validating them on first use. Only the configured
   * columns are returned when configuredColumnsOnly is set.
   */
  CSVDataset2.prototype.columnNames = function() {
    return __awaiter(this, void 0, void 0, function() {
      return __generator(this, function(_a) {
        switch (_a.label) {
          case 0:
            if (!!this.columnNamesValidated)
              return [3, 2];
            return [4, this.setColumnNames()];
          case 1:
            _a.sent();
            _a.label = 2;
          case 2:
            return [2, this.configuredColumnsOnly ? Object.keys(this.columnConfigs) : this.fullColumnNames];
        }
      });
    });
  };

  /**
   * Determines fullColumnNames from the config and/or the header line, then validates
   * them. Throws when no names are available, when there are duplicates, or when a
   * columnConfigs key is not a column. The config and header lengths must agree (asserted).
   */
  CSVDataset2.prototype.setColumnNames = function() {
    return __awaiter(this, void 0, void 0, function() {
      var columnNamesFromFile, counts, duplicateNames, _i, _a, key, index;
      var _this = this;
      return __generator(this, function(_b) {
        switch (_b.label) {
          case 0:
            return [4, this.maybeReadHeaderLine()];
          case 1:
            columnNamesFromFile = _b.sent();
            if (!this.fullColumnNames && !columnNamesFromFile) {
              throw new Error("Column names must be provided if there is no header line.");
            } else if (this.fullColumnNames && columnNamesFromFile) {
              tf.util.assert(columnNamesFromFile.length === this.fullColumnNames.length, function() {
                return "The length of provided columnNames (" + _this.fullColumnNames.length.toString() + ") does not match the length of the header line read from file (" + columnNamesFromFile.length.toString() + ").";
              });
            }
            if (!this.fullColumnNames) {
              this.fullColumnNames = columnNamesFromFile;
            }
            counts = this.fullColumnNames.reduce(function(countAcc, name) {
              countAcc[name] = countAcc[name] + 1 || 1;
              return countAcc;
            }, {});
            duplicateNames = Object.keys(counts).filter(function(name) {
              return counts[name] > 1;
            });
            tf.util.assert(duplicateNames.length === 0, function() {
              return "Duplicate column names found: " + duplicateNames.toString();
            });
            if (this.columnConfigs) {
              for (_i = 0, _a = Object.keys(this.columnConfigs); _i < _a.length; _i++) {
                key = _a[_i];
                index = this.fullColumnNames.indexOf(key);
                if (index === -1) {
                  throw new Error('The key "' + key + '" provided in columnConfigs does not match any of the column names (' + this.fullColumnNames.toString() + ").");
                }
              }
            }
            this.columnNamesValidated = true;
            return [2];
        }
      });
    });
  };

  /**
   * When hasHeader is set, parses the first line (without element-count validation)
   * and resolves to its fields; otherwise resolves to null.
   * @throws Error when the input has no lines at all.
   */
  CSVDataset2.prototype.maybeReadHeaderLine = function() {
    return __awaiter(this, void 0, void 0, function() {
      var iter, firstElement, firstLine, headers;
      return __generator(this, function(_a) {
        switch (_a.label) {
          case 0:
            if (!this.hasHeader)
              return [3, 3];
            return [4, this.base.iterator()];
          case 1:
            iter = _a.sent();
            return [4, iter.next()];
          case 2:
            firstElement = _a.sent();
            if (firstElement.done) {
              throw new Error("No data was found for CSV parsing.");
            }
            firstLine = firstElement.value;
            headers = this.parseRow(firstLine, false);
            return [2, headers];
          case 3:
            return [2, null];
        }
      });
    });
  };

  /** Iterator over parsed rows; the header line (if any) is skipped. */
  CSVDataset2.prototype.iterator = function() {
    return __awaiter(this, void 0, void 0, function() {
      var lines;
      var _this = this;
      return __generator(this, function(_a) {
        switch (_a.label) {
          case 0:
            if (!!this.columnNamesValidated)
              return [3, 2];
            return [4, this.setColumnNames()];
          case 1:
            _a.sent();
            _a.label = 2;
          case 2:
            return [4, this.base.iterator()];
          case 3:
            lines = _a.sent();
            if (this.hasHeader) {
              lines = lines.skip(1);
            }
            return [2, lines.map(function(x) {
              return _this.makeDataElement(x);
            })];
        }
      });
    });
  };

  /**
   * Converts one CSV line into an element.
   * Empty fields take `config.default`; if there is no default, they throw for
   * required/label columns and are undefined otherwise. Numeric text becomes a number
   * (int32 floors, bool maps to 1/0); other text is kept as a string, or mapped to 1/0
   * for bool columns.
   */
  CSVDataset2.prototype.makeDataElement = function(line) {
    var values = this.parseRow(line);
    var features = {};
    var labels = {};
    for (var i = 0; i < this.fullColumnNames.length; i++) {
      var key = this.fullColumnNames[i];
      var config = this.columnConfigs ? this.columnConfigs[key] : null;
      if (this.configuredColumnsOnly && !config) {
        continue;
      }
      var value = values[i];
      var parsedValue = null;
      if (value === "") {
        if (config && config.default !== void 0) {
          parsedValue = config.default;
        } else if (config && (config.required || config.isLabel)) {
          throw new Error("Required column " + key + " is empty in this line: " + line);
        } else {
          parsedValue = void 0;
        }
      } else {
        var valueAsNum = Number(value);
        if (isNaN(valueAsNum)) {
          if (config && config.dtype === "bool") {
            parsedValue = this.getBoolean(value);
          } else {
            parsedValue = value;
          }
        } else if (!config || !config.dtype) {
          parsedValue = valueAsNum;
        } else {
          switch (config.dtype) {
            case "float32":
              parsedValue = valueAsNum;
              break;
            case "int32":
              parsedValue = Math.floor(valueAsNum);
              break;
            case "bool":
              parsedValue = this.getBoolean(value);
              break;
            default:
              parsedValue = valueAsNum;
          }
        }
      }
      if (config && config.isLabel) {
        labels[key] = parsedValue;
      } else {
        features[key] = parsedValue;
      }
    }
    if (Object.keys(labels).length === 0) {
      return features;
    } else {
      return {xs: features, ys: labels};
    }
  };

  /** "1" or (case-insensitive) "true" -> 1; anything else -> 0. */
  CSVDataset2.prototype.getBoolean = function(value) {
    if (value === "1" || value.toLowerCase() === "true") {
      return 1;
    } else {
      return 0;
    }
  };

  /**
   * Splits one line into raw field strings using a small state machine. Quoted fields
   * may contain the delimiter; the surrounding quotes are removed but doubled quotes
   * are left as-is.
   * @param validateElementCount when true (default), throws unless the field count
   *     equals fullColumnNames.length.
   */
  CSVDataset2.prototype.parseRow = function(line, validateElementCount) {
    if (validateElementCount === void 0) {
      validateElementCount = true;
    }
    var result = [];
    var readOffset = 0;
    var readLength = line.length;
    var currentState = STATE_OUT;
    for (var i = 0; i < readLength; i++) {
      switch (currentState) {
        case STATE_OUT:
          switch (line.charAt(i)) {
            case CODE_QUOTE:
              readOffset = i + 1;
              currentState = STATE_QUOTE;
              break;
            case this.delimiter:
              readOffset = i + 1;
              // Runs of whitespace act as a single delimiter.
              if (this.delimiter === " " && this.delimWhitespace) {
                break;
              }
              result.push("");
              currentState = STATE_OUT;
              break;
            default:
              currentState = STATE_FIELD;
              readOffset = i;
              break;
          }
          break;
        case STATE_FIELD:
          switch (line.charAt(i)) {
            case this.delimiter:
              result.push(line.substring(readOffset, i));
              currentState = STATE_OUT;
              readOffset = i + 1;
              break;
          }
          break;
        case STATE_QUOTE:
          switch (line.charAt(i)) {
            case CODE_QUOTE:
              currentState = STATE_QUOTE_AFTER_QUOTE;
              break;
          }
          break;
        case STATE_QUOTE_AFTER_QUOTE:
          switch (line.charAt(i)) {
            case this.delimiter:
              result.push(line.substring(readOffset, i - 1));
              currentState = STATE_OUT;
              readOffset = i + 1;
              break;
            case CODE_QUOTE:
              currentState = STATE_QUOTE;
              break;
            default:
              currentState = STATE_WITHIN_QUOTE_IN_QUOTE;
              break;
          }
          break;
        case STATE_WITHIN_QUOTE_IN_QUOTE:
          switch (line.charAt(i)) {
            case CODE_QUOTE:
              currentState = STATE_QUOTE;
              break;
          }
          break;
      }
    }
    // In whitespace mode, ending in STATE_OUT after at least one field means the line
    // had trailing whitespace. There is no further field, so emitting "" would add a
    // phantom column. With a real delimiter (e.g. "a,b,"), the trailing empty field is
    // genuine and is still pushed.
    var endsInWhitespaceRun = this.delimWhitespace && currentState === STATE_OUT && result.length > 0;
    if (currentState === STATE_QUOTE_AFTER_QUOTE) {
      result.push(line.substring(readOffset, readLength - 1));
    } else if (!endsInWhitespaceRun) {
      result.push(line.substring(readOffset));
    }
    if (validateElementCount && result.length !== this.fullColumnNames.length) {
      throw new Error("Invalid row in csv file. Should have " + this.fullColumnNames.length + " elements in a row, but got " + result);
    }
    return result;
  };

  return CSVDataset2;
}(Dataset);
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
* =============================================================================
*/
/**
 * Iterator over microphone audio captured with the Web Audio API (browser only).
 * Each element holds
 *   spectrogram: tensor of shape [numFrames, columnTruncateLength, 1]
 *                (frequency data from AnalyserNode.getFloatFrequencyData), and/or
 *   waveform:    tensor of shape [numFrames * fftSize, 1]
 *                (time-domain data from getFloatTimeDomainData),
 * depending on includeSpectrogram (default true) / includeWaveform (default false).
 */
var MicrophoneIterator = function(_super) {
  __extends(MicrophoneIterator2, _super);

  /**
   * @param microphoneConfig fftSize (power of 2 in [2^4, 2^14], default 1024),
   *     numFramesPerSpectrogram (default 43), sampleRateHz, columnTruncateLength
   *     (default fftSize), audioTrackConstraints, smoothingTimeConstant (default 0),
   *     includeSpectrogram, includeWaveform.
   * @throws Error on an invalid fftSize, or when both outputs are disabled.
   */
  function MicrophoneIterator2(microphoneConfig) {
    var _this = _super.call(this) || this;
    _this.microphoneConfig = microphoneConfig;
    _this.isClosed = false;
    _this.fftSize = microphoneConfig.fftSize || 1024;
    var fftSizeLog2 = Math.log2(_this.fftSize);
    if (_this.fftSize < 0 || fftSizeLog2 < 4 || fftSizeLog2 > 14 || !Number.isInteger(fftSizeLog2)) {
      throw new Error("Invalid fftSize: it must be a power of 2 between " + ("2^4 and 2^14, but got " + _this.fftSize));
    }
    _this.numFrames = microphoneConfig.numFramesPerSpectrogram || 43;
    _this.sampleRateHz = microphoneConfig.sampleRateHz;
    _this.columnTruncateLength = microphoneConfig.columnTruncateLength || _this.fftSize;
    _this.audioTrackConstraints = microphoneConfig.audioTrackConstraints;
    _this.smoothingTimeConstant = microphoneConfig.smoothingTimeConstant || 0;
    _this.includeSpectrogram = microphoneConfig.includeSpectrogram === false ? false : true;
    _this.includeWaveform = microphoneConfig.includeWaveform === true ? true : false;
    if (!_this.includeSpectrogram && !_this.includeWaveform) {
      throw new Error("Both includeSpectrogram and includeWaveform are false. At least one type of data should be returned.");
    }
    return _this;
  }

  MicrophoneIterator2.prototype.summary = function() {
    return "microphone";
  };

  /**
   * Builds and starts a MicrophoneIterator.
   * @throws Error when running under Node (tf env flag IS_NODE).
   */
  MicrophoneIterator2.create = function(microphoneConfig) {
    if (microphoneConfig === void 0) {
      microphoneConfig = {};
    }
    return __awaiter(this, void 0, void 0, function() {
      var microphoneIterator;
      return __generator(this, function(_a) {
        switch (_a.label) {
          case 0:
            if (tf.env().get("IS_NODE")) {
              throw new Error("microphone API is only supported in browser environment.");
            }
            microphoneIterator = new MicrophoneIterator2(microphoneConfig);
            return [4, microphoneIterator.start()];
          case 1:
            _a.sent();
            return [2, microphoneIterator];
        }
      });
    });
  };

  /**
   * Requests an audio-only MediaStream and wires it into an AnalyserNode.
   * The analyser's fftSize is 2 * this.fftSize, and the float buffers hold
   * this.fftSize values. Throws if access fails or the sample rate does not match the
   * configured sampleRateHz.
   */
  MicrophoneIterator2.prototype.start = function() {
    return __awaiter(this, void 0, void 0, function() {
      var _a, e_1, ctxConstructor, streamSource;
      return __generator(this, function(_b) {
        switch (_b.label) {
          case 0:
            _b.trys.push([0, 2, , 3]);
            _a = this;
            return [4, navigator.mediaDevices.getUserMedia({
              audio: this.audioTrackConstraints == null ? true : this.audioTrackConstraints,
              video: false
            })];
          case 1:
            _a.stream = _b.sent();
            return [3, 3];
          case 2:
            e_1 = _b.sent();
            throw new Error("Error thrown while initializing audio stream: " + e_1.message);
          case 3:
            if (!this.stream) {
              throw new Error("Could not obtain audio from microphone.");
            }
            ctxConstructor = window.AudioContext || window.webkitAudioContext;
            this.audioContext = new ctxConstructor();
            if (!this.sampleRateHz) {
              this.sampleRateHz = this.audioContext.sampleRate;
            } else if (this.audioContext.sampleRate !== this.sampleRateHz) {
              throw new Error("Mismatch in sampling rate: " + ("Expected: " + this.sampleRateHz + "; ") + ("Actual: " + this.audioContext.sampleRate));
            }
            streamSource = this.audioContext.createMediaStreamSource(this.stream);
            this.analyser = this.audioContext.createAnalyser();
            this.analyser.fftSize = this.fftSize * 2;
            this.analyser.smoothingTimeConstant = this.smoothingTimeConstant;
            streamSource.connect(this.analyser);
            this.freqData = new Float32Array(this.fftSize);
            this.timeData = new Float32Array(this.fftSize);
            return [2];
        }
      });
    });
  };

  /** Captures one window of audio, or resolves to {value: null, done: true} once stopped. */
  MicrophoneIterator2.prototype.next = function() {
    return __awaiter(this, void 0, void 0, function() {
      var spectrogramTensor, waveformTensor, audioDataQueue, freqData, timeData;
      return __generator(this, function(_a) {
        switch (_a.label) {
          case 0:
            if (this.isClosed) {
              return [2, {value: null, done: true}];
            }
            return [4, this.getAudioData()];
          case 1:
            audioDataQueue = _a.sent();
            if (this.includeSpectrogram) {
              freqData = this.flattenQueue(audioDataQueue.freqDataQueue);
              spectrogramTensor = this.getTensorFromAudioDataArray(freqData, [this.numFrames, this.columnTruncateLength, 1]);
            }
            if (this.includeWaveform) {
              timeData = this.flattenQueue(audioDataQueue.timeDataQueue);
              waveformTensor = this.getTensorFromAudioDataArray(timeData, [this.numFrames * this.fftSize, 1]);
            }
            return [2, {
              value: {spectrogram: spectrogramTensor, waveform: waveformTensor},
              done: false
            }];
        }
      });
    });
  };

  /** Convenience wrapper returning just the `value` of next(). */
  MicrophoneIterator2.prototype.capture = function() {
    return __awaiter(this, void 0, void 0, function() {
      return __generator(this, function(_a) {
        switch (_a.label) {
          case 0:
            return [4, this.next()];
          case 1:
            return [2, _a.sent().value];
        }
      });
    });
  };

  /**
   * Samples the analyser every fftSize / sampleRateHz seconds until numFrames frames
   * are collected. It finishes early after the first frame whose leading frequency bin
   * is -Infinity, which is treated as the stream not yet producing data.
   * Resolves to {freqDataQueue, timeDataQueue}.
   */
  MicrophoneIterator2.prototype.getAudioData = function() {
    return __awaiter(this, void 0, void 0, function() {
      var freqDataQueue, timeDataQueue, currentFrames;
      var _this = this;
      return __generator(this, function(_a) {
        freqDataQueue = [];
        timeDataQueue = [];
        currentFrames = 0;
        return [2, new Promise(function(resolve) {
          var intervalID = setInterval(function() {
            var finishEarly = false;
            if (_this.includeSpectrogram) {
              _this.analyser.getFloatFrequencyData(_this.freqData);
              finishEarly = _this.freqData[0] === -Infinity;
              freqDataQueue.push(_this.freqData.slice(0, _this.columnTruncateLength));
            }
            if (_this.includeWaveform) {
              _this.analyser.getFloatTimeDomainData(_this.timeData);
              timeDataQueue.push(_this.timeData.slice());
            }
            // Always clear the timer when resolving. Otherwise an early finish leaves the
            // interval running and it keeps sampling into queues already handed to the caller.
            if (finishEarly || ++currentFrames === _this.numFrames) {
              clearInterval(intervalID);
              resolve({freqDataQueue, timeDataQueue});
            }
          }, _this.fftSize / _this.sampleRateHz * 1e3);
        })];
      });
    });
  };

  /**
   * Stops capture once: disconnects the analyser, closes the AudioContext, and stops
   * every track of the stream (not only the first).
   */
  MicrophoneIterator2.prototype.stop = function() {
    if (!this.isClosed) {
      this.isClosed = true;
      this.analyser.disconnect();
      this.audioContext.close();
      if (this.stream != null) {
        this.stream.getTracks().forEach(function(track) {
          track.stop();
        });
      }
    }
  };

  MicrophoneIterator2.prototype.toArray = function() {
    throw new Error("Can not convert infinite audio stream to array.");
  };

  MicrophoneIterator2.prototype.getSampleRate = function() {
    return this.sampleRateHz;
  };

  /**
   * Concatenates equally sized Float32Array frames into one Float32Array.
   * The frame size is taken from queue[0], so the queue must be non-empty.
   */
  MicrophoneIterator2.prototype.flattenQueue = function(queue) {
    var frameSize = queue[0].length;
    var freqData = new Float32Array(queue.length * frameSize);
    queue.forEach(function(data, i) {
      return freqData.set(data, i * frameSize);
    });
    return freqData;
  };

  /**
   * Builds a tensor of `shape`, right-aligning `freqData` and zero-padding the front
   * when fewer values than the shape holds were collected.
   */
  MicrophoneIterator2.prototype.getTensorFromAudioDataArray = function(freqData, shape) {
    var vals = new Float32Array(tf.util.sizeFromShape(shape));
    vals.set(freqData, vals.length - freqData.length);
    return tf.tensor(vals, shape);
  };

  return MicrophoneIterator2;
}(LazyIterator);
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
* =============================================================================
*/
var WebcamIterator = function(_super) {
  __extends(WebcamIterator2, _super);
  /**
   * Lazy iterator that yields one tensor per next() call, captured from a
   * <video> element fed by navigator.mediaDevices.getUserMedia().
   *
   * @param webcamVideoElement HTMLVideoElement that receives the camera stream.
   * @param webcamConfig Reads resizeWidth, resizeHeight, centerCrop,
   *     facingMode ("user" | "environment") and deviceId.
   */
  function WebcamIterator2(webcamVideoElement, webcamConfig) {
    var _this = _super.call(this) || this;
    _this.webcamVideoElement = webcamVideoElement;
    _this.webcamConfig = webcamConfig;
    // No frames are produced until start() has attached a stream.
    _this.isClosed = true;
    _this.resize = false;
    if (_this.needToResize()) {
      _this.resize = true;
      // Output size in [height, width] order, with a single crop box (index 0).
      _this.cropSize = [_this.webcamConfig.resizeHeight, _this.webcamConfig.resizeWidth];
      _this.cropBoxInd = tf.tensor1d([0], "int32");
      if (_this.webcamConfig.centerCrop) {
        // Normalized [heightStart, widthStart, heightEnd, widthEnd] box whose size is
        // the target/element ratio, centred in the frame.
        var widthCroppingRatio = _this.webcamConfig.resizeWidth * 1 / _this.webcamVideoElement.width;
        var heightCroppingRatio = _this.webcamConfig.resizeHeight * 1 / _this.webcamVideoElement.height;
        var widthCropStart = (1 - widthCroppingRatio) / 2;
        var heightCropStart = (1 - heightCroppingRatio) / 2;
        var widthCropEnd = widthCropStart + widthCroppingRatio;
        var heightCropEnd = heightCroppingRatio + heightCropStart;
        _this.cropBox = tf.tensor2d([heightCropStart, widthCropStart, heightCropEnd, widthCropEnd], [1, 4]);
      } else {
        // Full-frame box: the frame is only resized, not cropped.
        _this.cropBox = tf.tensor2d([0, 0, 1, 1], [1, 4]);
      }
    }
    return _this;
  }
  WebcamIterator2.prototype.summary = function() {
    return "webcam";
  };
  /**
   * Build a WebcamIterator and start it.
   * Without an element, a detached <video> sized resizeWidth x resizeHeight is
   * created, and both sizes are then required.
   * @returns Promise resolving to the started iterator.
   * @throws Error when running under Node, or when neither an element nor both
   *     resize dimensions are given.
   */
  WebcamIterator2.create = function(webcamVideoElement, webcamConfig) {
    if (webcamConfig === void 0) {
      webcamConfig = {};
    }
    return __awaiter(this, void 0, void 0, function() {
      var webcamIterator;
      return __generator(this, function(_a) {
        switch (_a.label) {
          case 0:
            if (tf.env().get("IS_NODE")) {
              throw new Error("tf.data.webcam is only supported in browser environment.");
            }
            if (!webcamVideoElement) {
              webcamVideoElement = document.createElement("video");
              if (!webcamConfig.resizeWidth || !webcamConfig.resizeHeight) {
                throw new Error("Please provide webcam video element, or resizeWidth and resizeHeight to create a hidden video element.");
              }
              webcamVideoElement.width = webcamConfig.resizeWidth;
              webcamVideoElement.height = webcamConfig.resizeHeight;
            }
            webcamIterator = new WebcamIterator2(webcamVideoElement, webcamConfig);
            return [4, webcamIterator.start()];
          case 1:
            _a.sent();
            return [2, webcamIterator];
        }
      });
    });
  };
  /**
   * Request the camera, attach the stream to the element and start playback.
   * Resolves once the element fires loadedmetadata.
   * @throws the getUserMedia error, with its message prefixed, or an Error when
   *     no stream was obtained.
   */
  WebcamIterator2.prototype.start = function() {
    return __awaiter(this, void 0, void 0, function() {
      var _a, e_1;
      var _this = this;
      return __generator(this, function(_b) {
        switch (_b.label) {
          case 0:
            if (this.webcamConfig.facingMode) {
              tf.util.assert(this.webcamConfig.facingMode === "user" || this.webcamConfig.facingMode === "environment", function() {
                return "Invalid webcam facing mode: " + _this.webcamConfig.facingMode + ". Please provide 'user' or 'environment'";
              });
            }
            _b.label = 1;
          case 1:
            // Labels 1..3 form a try/catch: a getUserMedia rejection lands in case 3.
            _b.trys.push([1, 3, , 4]);
            _a = this;
            return [4, navigator.mediaDevices.getUserMedia({
              video: {
                deviceId: this.webcamConfig.deviceId,
                facingMode: this.webcamConfig.facingMode ? this.webcamConfig.facingMode : "user",
                width: this.webcamVideoElement.width,
                height: this.webcamVideoElement.height
              }
            })];
          case 2:
            _a.stream = _b.sent();
            return [3, 4];
          case 3:
            e_1 = _b.sent();
            e_1.message = "Error thrown while initializing video stream: " + e_1.message;
            throw e_1;
          case 4:
            if (!this.stream) {
              throw new Error("Could not obtain video from webcam.");
            }
            try {
              this.webcamVideoElement.srcObject = this.stream;
            } catch (error) {
              // Fallback for elements that reject srcObject.
              console.log(error);
              this.webcamVideoElement.src = window.URL.createObjectURL(this.stream);
            }
            this.webcamVideoElement.play();
            this.isClosed = false;
            return [2, new Promise(function(resolve) {
              _this.webcamVideoElement.onloadedmetadata = function() {
                resolve();
              };
            })];
        }
      });
    });
  };
  /**
   * Resolve {value: tensor, done: false} for the current video frame, or
   * {value: null, done: true} once the iterator is closed. When resizing, the
   * raw frame tensor is disposed after cropping.
   */
  WebcamIterator2.prototype.next = function() {
    return __awaiter(this, void 0, void 0, function() {
      var img;
      return __generator(this, function(_a) {
        if (this.isClosed) {
          return [2, {value: null, done: true}];
        }
        try {
          img = tf.browser.fromPixels(this.webcamVideoElement);
        } catch (e) {
          // JSON.stringify(anError) yields "{}", so prefer the message when one exists.
          var detail = e && e.message ? e.message : JSON.stringify(e);
          throw new Error("Error thrown converting video to pixels: " + detail);
        }
        if (this.resize) {
          try {
            return [2, {value: this.cropAndResizeFrame(img), done: false}];
          } catch (e) {
            throw new Error("Error thrown cropping the video: " + e.message);
          } finally {
            img.dispose();
          }
        } else {
          return [2, {value: img, done: false}];
        }
        return [2];
      });
    });
  };
  /** True when both resize dimensions are set and differ from the element size. */
  WebcamIterator2.prototype.needToResize = function() {
    if (this.webcamConfig.resizeWidth && this.webcamConfig.resizeHeight && (this.webcamVideoElement.width !== this.webcamConfig.resizeWidth || this.webcamVideoElement.height !== this.webcamConfig.resizeHeight)) {
      return true;
    }
    return false;
  };
  /** Crop and resize one frame with the precomputed box; the batch axis is removed from the result. */
  WebcamIterator2.prototype.cropAndResizeFrame = function(img) {
    var _this = this;
    return tf.tidy(function() {
      var expandedImage = img.toFloat().expandDims(0);
      var resizedImage;
      resizedImage = tf.image.cropAndResize(expandedImage, _this.cropBox, _this.cropBoxInd, _this.cropSize, "bilinear");
      var shape = resizedImage.shape;
      return resizedImage.reshape(shape.slice(1));
    });
  };
  /** Resolve the next frame tensor (or null once closed). */
  WebcamIterator2.prototype.capture = function() {
    return __awaiter(this, void 0, void 0, function() {
      return __generator(this, function(_a) {
        switch (_a.label) {
          case 0:
            return [4, this.next()];
          case 1:
            return [2, _a.sent().value];
        }
      });
    });
  };
  /**
   * Stop all camera tracks, detach the stream and mark the iterator closed.
   * Safe to call when start() never obtained a stream.
   */
  WebcamIterator2.prototype.stop = function() {
    if (this.stream) {
      var tracks = this.stream.getTracks();
      tracks.forEach(function(track) {
        return track.stop();
      });
    }
    try {
      this.webcamVideoElement.srcObject = null;
    } catch (error) {
      console.log(error);
      this.webcamVideoElement.src = null;
    }
    this.isClosed = true;
  };
  WebcamIterator2.prototype.toArray = function() {
    throw new Error("Can not convert infinite video stream to array.");
  };
  return WebcamIterator2;
}(LazyIterator);
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
* =============================================================================
*/
var DataSource = (function() {
  /**
   * Abstract base for data sources; subclasses provide an async iterator().
   * Deliberately a plain constructor function rather than an ES class, so that
   * __extends-based subclasses can invoke it with `_super.call(this)`.
   */
  var DataSource2 = function DataSource2() {
  };
  return DataSource2;
})();
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
* =============================================================================
*/
var StringIterator = (function(Base) {
  __extends(StringIterator2, Base);
  /** Lazy iterator of strings; adds string-specific transforms such as split(). */
  function StringIterator2() {
    // Forward every constructor argument to the parent iterator.
    var created = Base !== null && Base.apply(this, arguments);
    return created || this;
  }
  /**
   * Re-chunk this string stream on `separator`.
   * @param separator passed to String.prototype.split for each upstream chunk.
   * @returns SplitIterator over the separated pieces.
   */
  StringIterator2.prototype.split = function(separator) {
    return new SplitIterator(this, separator);
  };
  return StringIterator2;
})(LazyIterator);
var SplitIterator = (function(Base) {
  __extends(SplitIterator2, Base);
  /**
   * String iterator that splits an upstream string stream on `separator`.
   * All work is delegated to a SplitIteratorImpl instance.
   */
  function SplitIterator2(upstream, separator) {
    var self = Base.call(this) || this;
    self.upstream = upstream;
    self.impl = new SplitIteratorImpl(upstream, separator);
    return self;
  }
  SplitIterator2.prototype.summary = function() {
    return this.impl.summary();
  };
  /** Resolve with the next piece produced by the delegate. */
  SplitIterator2.prototype.next = function() {
    var delegate = this.impl;
    return __awaiter(this, void 0, void 0, function() {
      return __generator(this, function() {
        return [2, delegate.next()];
      });
    });
  };
  return SplitIterator2;
})(StringIterator);
var SplitIteratorImpl = (function(Base) {
  __extends(SplitIteratorImpl2, Base);
  /**
   * One-to-many worker behind SplitIterator: each upstream string chunk may
   * enqueue zero or more pieces on `outputQueue`.
   * @param upstream iterator whose next() resolves to {value: string, done}.
   * @param separator passed to String.prototype.split.
   */
  function SplitIteratorImpl2(upstream, separator) {
    var self = Base.call(this) || this;
    self.upstream = upstream;
    self.separator = separator;
    // Trailing text not yet terminated by a separator.
    self.carryover = "";
    return self;
  }
  SplitIteratorImpl2.prototype.summary = function() {
    return this.upstream.summary() + " -> Split('" + this.separator + "')";
  };
  /**
   * Pull one upstream chunk and enqueue every complete piece.
   * Resolves true while output may still be produced; resolves false once the
   * upstream is exhausted and nothing is left in carryover.
   */
  SplitIteratorImpl2.prototype.pump = function() {
    return __awaiter(this, void 0, void 0, function() {
      var result, pieces, lastIndex, k;
      return __generator(this, function(state) {
        switch (state.label) {
          case 0:
            return [4, this.upstream.next()];
          case 1:
            result = state.sent();
            if (!result.done) {
              pieces = result.value.split(this.separator);
              // The first piece continues the partial text left by the previous chunk.
              pieces[0] = this.carryover + pieces[0];
              lastIndex = pieces.length - 1;
              for (k = 0; k < lastIndex; k++) {
                this.outputQueue.push(pieces[k]);
              }
              // The last piece may be incomplete; hold it for the next chunk.
              this.carryover = pieces[lastIndex];
              return [2, true];
            }
            if (this.carryover === "") {
              return [2, false];
            }
            // Upstream finished: flush the remaining partial piece.
            this.outputQueue.push(this.carryover);
            this.carryover = "";
            return [2, true];
        }
      });
    });
  };
  return SplitIteratorImpl2;
})(OneToManyIterator);
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
* =============================================================================
*/
var ByteChunkIterator = (function(Base) {
  __extends(ByteChunkIterator2, Base);
  /** Lazy iterator of byte chunks; decodeUTF8() converts it into a string stream. */
  function ByteChunkIterator2() {
    // Forward every constructor argument to the parent iterator.
    var created = Base !== null && Base.apply(this, arguments);
    return created || this;
  }
  /** @returns Utf8Iterator that decodes this byte stream as UTF-8 text. */
  ByteChunkIterator2.prototype.decodeUTF8 = function() {
    return new Utf8Iterator(this);
  };
  return ByteChunkIterator2;
})(LazyIterator);
var Utf8Iterator = (function(Base) {
  __extends(Utf8Iterator2, Base);
  /**
   * String iterator that decodes an upstream byte-chunk stream as UTF-8.
   * All work is delegated to a Utf8IteratorImpl instance.
   */
  function Utf8Iterator2(upstream) {
    var self = Base.call(this) || this;
    self.upstream = upstream;
    self.impl = new Utf8IteratorImpl(upstream);
    return self;
  }
  Utf8Iterator2.prototype.summary = function() {
    return this.impl.summary();
  };
  /** Resolve with the next decoded text chunk from the delegate. */
  Utf8Iterator2.prototype.next = function() {
    var delegate = this.impl;
    return __awaiter(this, void 0, void 0, function() {
      return __generator(this, function() {
        return [2, delegate.next()];
      });
    });
  };
  return Utf8Iterator2;
})(StringIterator);
var Utf8IteratorImpl = function(_super) {
  __extends(Utf8IteratorImpl2, _super);
  /**
   * One-to-many worker behind Utf8Iterator: decodes each upstream byte chunk to
   * text and enqueues it on `outputQueue`.
   * Browsers use TextDecoder with {stream: true}; other environments use Node's
   * StringDecoder. Both keep multi-byte sequences that are split across chunk
   * boundaries pending until the next chunk arrives.
   * @param upstream iterator whose next() resolves to {value: Uint8Array, done}.
   */
  function Utf8IteratorImpl2(upstream) {
    var _this = _super.call(this) || this;
    _this.upstream = upstream;
    if (tf.env().get("IS_BROWSER")) {
      _this.decoder = new TextDecoder("utf-8");
    } else {
      var StringDecoder = require_string_decoder().StringDecoder;
      _this.decoder = new StringDecoder("utf8");
    }
    return _this;
  }
  Utf8IteratorImpl2.prototype.summary = function() {
    return this.upstream.summary() + " -> Utf8";
  };
  /**
   * Pull one upstream chunk and enqueue its decoded text.
   * Resolves true after a chunk was decoded, false once upstream is exhausted.
   */
  Utf8IteratorImpl2.prototype.pump = function() {
    return __awaiter(this, void 0, void 0, function() {
      var chunkResult, chunk, text;
      return __generator(this, function(_a) {
        switch (_a.label) {
          case 0:
            return [4, this.upstream.next()];
          case 1:
            chunkResult = _a.sent();
            if (chunkResult.done) {
              return [2, false];
            }
            chunk = chunkResult.value;
            if (tf.env().get("IS_BROWSER")) {
              text = this.decoder.decode(chunk, {stream: true});
            } else {
              // A Uint8Array may be a window into a larger ArrayBuffer; wrapping the
              // whole `chunk.buffer` would decode bytes outside the chunk. Restrict
              // the Buffer to exactly the view's bytes (no copy is made).
              text = this.decoder.write(Buffer.from(chunk.buffer, chunk.byteOffset, chunk.byteLength));
            }
            this.outputQueue.push(text);
            return [2, true];
        }
      });
    });
  };
  return Utf8IteratorImpl2;
}(OneToManyIterator);
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
* =============================================================================
*/
var FileChunkIterator = (function(_super) {
  __extends(FileChunkIterator2, _super);
  /**
   * Byte-chunk iterator over an in-memory Uint8Array or, in browsers, a File/Blob.
   * @param file Uint8Array, File or Blob.
   * @param options {offset?, chunkSize?}; falsy values fall back to 0 and 1 MiB.
   */
  function FileChunkIterator2(file, options) {
    var opts = options === void 0 ? {} : options;
    var self = _super.call(this) || this;
    self.file = file;
    self.options = opts;
    var supported = file instanceof Uint8Array ||
      (tf.env().get("IS_BROWSER") ? file instanceof File || file instanceof Blob : false);
    tf.util.assert(supported, function() {
      return "FileChunkIterator only supports File, Blob and Uint8Array right now.";
    });
    self.offset = opts.offset || 0;
    self.chunkSize = opts.chunkSize || 1024 * 1024;
    return self;
  }
  FileChunkIterator2.prototype.summary = function() {
    return "FileChunks " + this.file;
  };
  /**
   * Resolve {value: Uint8Array, done: false} with the next chunk, or
   * {value: null, done: true} once the offset reaches the end of the file.
   * Blob reads reject on abort, on read error, or on an unexpected result type.
   */
  FileChunkIterator2.prototype.next = function() {
    return __awaiter(this, void 0, void 0, function() {
      var totalSize, pending;
      var self = this;
      return __generator(this, function(state) {
        switch (state.label) {
          case 0:
            totalSize = this.file instanceof Uint8Array ? this.file.byteLength : this.file.size;
            if (this.offset >= totalSize) {
              return [2, {value: null, done: true}];
            }
            pending = new Promise(function(resolve, reject) {
              var start = self.offset;
              var end = start + self.chunkSize;
              if (self.file instanceof Uint8Array) {
                // Copy so the chunk never aliases the caller's buffer.
                resolve(new Uint8Array(self.file.slice(start, end)));
              } else {
                var reader = new FileReader();
                reader.onload = function(event) {
                  var payload = reader.result;
                  if (payload instanceof ArrayBuffer) {
                    payload = new Uint8Array(payload);
                  }
                  if (!(payload instanceof Uint8Array)) {
                    return reject(new TypeError("FileReader returned unknown type."));
                  }
                  resolve(payload);
                };
                reader.onabort = function(event) {
                  return reject(new Error("Aborted"));
                };
                reader.onerror = function(event) {
                  return reject(new Error(event.type));
                };
                reader.readAsArrayBuffer(self.file.slice(start, end));
              }
              // Advance synchronously so the next call starts after this chunk.
              self.offset = end;
            });
            return [4, pending];
          case 1:
            return [2, {value: state.sent(), done: false}];
        }
      });
    });
  };
  return FileChunkIterator2;
})(ByteChunkIterator);
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
* =============================================================================
*/
/**
 * Fetch `url` and wrap the whole response body in a FileChunkIterator.
 * @param url string, or a Request-like object whose url and init fields are reused.
 * @param options FileChunkIterator options (offset, chunkSize); defaults to {}.
 * @returns Promise<FileChunkIterator>
 * @throws Error carrying response.statusText when response.ok is false.
 */
function urlChunkIterator(url, options) {
  var chunkOptions = options === void 0 ? {} : options;
  return __awaiter(this, void 0, void 0, function() {
    var target, init, response, bytes;
    return __generator(this, function(state) {
      switch (state.label) {
        case 0:
          if (typeof url !== "string") {
            target = url.url;
            init = getRequestInitFromRequest(url);
          } else {
            target = url;
          }
          return [4, tf.util.fetch(target, init)];
        case 1:
          response = state.sent();
          if (!response.ok) {
            throw new Error(response.statusText);
          }
          return [4, response.arrayBuffer()];
        case 2:
          bytes = new Uint8Array(state.sent());
          return [2, new FileChunkIterator(bytes, chunkOptions)];
      }
    });
  });
}
/**
 * Copy the RequestInit-compatible fields of a Request-like object into a new
 * plain object, so the request can be re-issued as fetch(url, init).
 * Fields missing on `request` are still present on the result, as undefined.
 */
var getRequestInitFromRequest = function(request) {
  var fields = ["method", "headers", "body", "mode", "credentials", "cache", "redirect", "referrer", "integrity"];
  var init = {};
  fields.forEach(function(field) {
    init[field] = request[field];
  });
  return init;
};
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
* =============================================================================
*/
/**
 * True when `source` is a string that starts with the "file://" scheme
 * (case-sensitive); false for any non-string.
 */
function isLocalPath(source) {
  if (typeof source !== "string") {
    return false;
  }
  return source.startsWith("file://");
}
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
* =============================================================================
*/
var FileDataSource = (function(_super) {
  __extends(FileDataSource2, _super);
  /**
   * Data source backed by file contents (Uint8Array/File/Blob) or, under Node,
   * a "file://" path string.
   * @param input file contents or "file://" path.
   * @param options FileChunkIterator options (offset, chunkSize); defaults to {}.
   */
  function FileDataSource2(input, options) {
    var opts = options === void 0 ? {} : options;
    var self = _super.call(this) || this;
    self.input = input;
    self.options = opts;
    return self;
  }
  /**
   * Resolve a FileChunkIterator over the input. Under Node, a "file://" path is
   * read synchronously and the bytes replace `this.input`, so later calls reuse them.
   */
  FileDataSource2.prototype.iterator = function() {
    return __awaiter(this, void 0, void 0, function() {
      return __generator(this, function() {
        if (isLocalPath(this.input) && tf.env().get("IS_NODE")) {
          // substr(7) strips the "file://" prefix.
          this.input = require("fs").readFileSync(this.input.substr(7));
        }
        return [2, new FileChunkIterator(this.input, this.options)];
      });
    });
  };
  return FileDataSource2;
})(DataSource);
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
* =============================================================================
*/
var URLDataSource = (function(_super) {
  __extends(URLDataSource2, _super);
  /**
   * Data source that streams from a URL; "file://" URLs go through FileDataSource.
   * @param url string or Request-like object.
   * @param fileOptions FileChunkIterator options (offset, chunkSize); defaults to {}.
   */
  function URLDataSource2(url, fileOptions) {
    var opts = fileOptions === void 0 ? {} : fileOptions;
    var self = _super.call(this) || this;
    self.url = url;
    self.fileOptions = opts;
    return self;
  }
  /** Resolve a byte-chunk iterator for the configured URL. */
  URLDataSource2.prototype.iterator = function() {
    return __awaiter(this, void 0, void 0, function() {
      return __generator(this, function() {
        var pending = isLocalPath(this.url)
          ? new FileDataSource(this.url, this.fileOptions).iterator()
          : urlChunkIterator(this.url, this.fileOptions);
        return [2, pending];
      });
    });
  };
  return URLDataSource2;
})(DataSource);
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
* =============================================================================
*/
/**
 * Create a CSVDataset that reads from `source`.
 * @param source URL string, "file://" path, or Request-like object (via URLDataSource).
 * @param csvConfig CSVDataset options; defaults to {}.
 * @returns CSVDataset
 */
function csv(source, csvConfig) {
  var config = csvConfig === void 0 ? {} : csvConfig;
  var dataSource = new URLDataSource(source);
  return new CSVDataset(dataSource, config);
}
/**
 * Create a dataset whose elements come from iteratorFromFunction(f).
 * The iterator is built once, and that same instance is handed out on every
 * iteration request.
 */
function func(f) {
  var self = this;
  var sharedIterator = iteratorFromFunction(f);
  var iteratorFn = function() {
    return __awaiter(self, void 0, void 0, function() {
      return __generator(this, function() {
        return [2, sharedIterator];
      });
    });
  };
  return datasetFromIteratorFn(iteratorFn);
}
/**
 * Create a dataset from a generator factory. Each iteration request awaits
 * generator2() and wraps the resulting generator's next() in an iterator.
 */
function generator(generator2) {
  var self = this;
  return datasetFromIteratorFn(function() {
    return __awaiter(self, void 0, void 0, function() {
      var gen;
      return __generator(this, function(state) {
        if (state.label === 0) {
          return [4, generator2()];
        }
        gen = state.sent();
        return [2, iteratorFromFunction(function() {
          return gen.next();
        })];
      });
    });
  });
}
/**
 * Create and start a WebcamIterator via WebcamIterator.create.
 * @param webcamVideoElement optional <video> element to render into.
 * @param webcamConfig optional WebcamIterator configuration.
 * @returns Promise resolving to the started iterator.
 */
function webcam(webcamVideoElement, webcamConfig) {
  return __awaiter(this, void 0, void 0, function() {
    return __generator(this, function() {
      var started = WebcamIterator.create(webcamVideoElement, webcamConfig);
      return [2, started];
    });
  });
}
/**
 * Create a microphone iterator via MicrophoneIterator.create.
 * @param microphoneConfig configuration forwarded unchanged.
 * @returns Promise resolving to whatever MicrophoneIterator.create resolves to.
 */
function microphone(microphoneConfig) {
  return __awaiter(this, void 0, void 0, function() {
    return __generator(this, function() {
      var created = MicrophoneIterator.create(microphoneConfig);
      return [2, created];
    });
  });
}
/** @license See the LICENSE file. */
// Version string of this vendored tfjs-data build; exported below as `version_data`.
var version = "2.7.0";
// Public surface of the bundled tfjs-data module (assignment order kept as-is).
exports.CSVDataset = CSVDataset;
exports.Dataset = Dataset;
exports.FileDataSource = FileDataSource;
exports.TextLineDataset = TextLineDataset;
exports.URLDataSource = URLDataSource;
exports.array = array;
exports.csv = csv;
exports.func = func;
exports.generator = generator;
exports.microphone = microphone;
exports.version_data = version;
exports.webcam = webcam;
exports.zip = zip;
});
// node_modules/@tensorflow/tfjs-backend-cpu/node_modules/seedrandom/lib/alea.js
var require_alea = __commonJS((exports, module) => {
// Vendored seedrandom "alea" generator, wrapped as a lazily evaluated CommonJS module.
(function(global2, module2, define2) {
// Generator state: three fractions s0..s2 and an integer carry c, seeded from Mash hashes.
function Alea(seed) {
var me = this, mash = Mash();
// One step: t = 2091639 * s0 + c * 2^-32 (23283064365386963e-26 == 2^-32).
// The state shifts down; the new s2 is the fractional part of t and the new carry
// c is its integer part, so each output lies in [0, 1).
me.next = function() {
var t = 2091639 * me.s0 + me.c * 23283064365386963e-26;
me.s0 = me.s1;
me.s1 = me.s2;
return me.s2 = t - (me.c = t | 0);
};
me.c = 1;
// Start from successive hashes of " " and subtract hashes of the seed, wrapping back into [0, 1).
me.s0 = mash(" ");
me.s1 = mash(" ");
me.s2 = mash(" ");
me.s0 -= mash(seed);
if (me.s0 < 0) {
me.s0 += 1;
}
me.s1 -= mash(seed);
if (me.s1 < 0) {
me.s1 += 1;
}
me.s2 -= mash(seed);
if (me.s2 < 0) {
me.s2 += 1;
}
// Drop the hash closure once seeding is complete.
mash = null;
}
// Copy generator state (c, s0, s1, s2) from f onto t and return t.
function copy(f, t) {
t.c = f.c;
t.s0 = f.s0;
t.s1 = f.s1;
t.s2 = f.s2;
return t;
}
// Factory: returns prng(), which yields values in [0, 1), plus int32/double/quick variants.
// prng.state() is only attached when opts.state is truthy.
function impl(seed, opts) {
var xg = new Alea(seed), state = opts && opts.state, prng = xg.next;
// Scale one [0, 1) draw by 2^32 and truncate to a signed 32-bit integer.
prng.int32 = function() {
return xg.next() * 4294967296 | 0;
};
// Add 21 more bits from a second draw (2097152 == 2^21, 11102230246251565e-32 == 2^-53).
prng.double = function() {
return prng() + (prng() * 2097152 | 0) * 11102230246251565e-32;
};
prng.quick = prng;
if (state) {
// An object-valued opts.state restores a previously saved state.
if (typeof state == "object")
copy(state, xg);
prng.state = function() {
return copy(xg, {});
};
}
return prng;
}
// String hash returning a fraction in [0, 1) ((n >>> 0) * 2^-32). The accumulator n
// persists across calls, so repeated calls with the same input give different values.
function Mash() {
var n = 4022871197;
var mash = function(data) {
data = data.toString();
for (var i = 0; i < data.length; i++) {
n += data.charCodeAt(i);
var h = 0.02519603282416938 * n;
n = h >>> 0;
h -= n;
h *= n;
n = h >>> 0;
h -= n;
n += h * 4294967296;
}
return (n >>> 0) * 23283064365386963e-26;
};
return mash;
}
// UMD-style export: CommonJS when available, then AMD, otherwise attach to `this`.
if (module2 && module2.exports) {
module2.exports = impl;
} else if (define2 && define2.amd) {
define2(function() {
return impl;
});
} else {
this.alea = impl;
}
})(exports, typeof module == "object" && module, typeof define == "function" && define);
});
// node_modules/@tensorflow/tfjs-backend-cpu/node_modules/seedrandom/lib/xor128.js
var require_xor128 = __commonJS((exports, module) => {
// Vendored seedrandom "xor128" generator, wrapped as a lazily evaluated CommonJS module.
(function(global2, module2, define2) {
// Generator state: four 32-bit words x, y, z, w.
function XorGen(seed) {
var me = this, strseed = "";
me.x = 0;
me.y = 0;
me.z = 0;
me.w = 0;
// One xorshift step: shift the words down and mix the old x into the new w.
me.next = function() {
var t = me.x ^ me.x << 11;
me.x = me.y;
me.y = me.z;
me.z = me.w;
return me.w ^= me.w >>> 19 ^ t ^ t >>> 8;
};
// A 32-bit integer seed initialises x directly; any other seed is stringified.
if (seed === (seed | 0)) {
me.x = seed;
} else {
strseed += seed;
}
// XOR each seed character into x (charCodeAt past the end is NaN, so `| 0` gives 0),
// stepping the generator each time; 64 extra steps run past the seed string.
for (var k = 0; k < strseed.length + 64; k++) {
me.x ^= strseed.charCodeAt(k) | 0;
me.next();
}
}
// Copy generator state (x, y, z, w) from f onto t and return t.
function copy(f, t) {
t.x = f.x;
t.y = f.y;
t.z = f.z;
t.w = f.w;
return t;
}
// Factory: prng() maps one 32-bit output to [0, 1) by dividing by 2^32.
function impl(seed, opts) {
var xg = new XorGen(seed), state = opts && opts.state, prng = function() {
return (xg.next() >>> 0) / 4294967296;
};
// Combine 21 high bits with a second [0, 1) draw, scaled by 2^-21; exact 0 is rejected.
prng.double = function() {
do {
var top = xg.next() >>> 11, bot = (xg.next() >>> 0) / 4294967296, result = (top + bot) / (1 << 21);
} while (result === 0);
return result;
};
prng.int32 = xg.next;
prng.quick = prng;
// prng.state() is only attached when opts.state is truthy; an object restores saved state.
if (state) {
if (typeof state == "object")
copy(state, xg);
prng.state = function() {
return copy(xg, {});
};
}
return prng;
}
// UMD-style export: CommonJS when available, then AMD, otherwise attach to `this`.
if (module2 && module2.exports) {
module2.exports = impl;
} else if (define2 && define2.amd) {
define2(function() {
return impl;
});
} else {
this.xor128 = impl;
}
})(exports, typeof module == "object" && module, typeof define == "function" && define);
});
// node_modules/@tensorflow/tfjs-backend-cpu/node_modules/seedrandom/lib/xorwow.js
var require_xorwow = __commonJS((exports, module) => {
// Vendored seedrandom "xorwow" generator, wrapped as a lazily evaluated CommonJS module.
(function(global2, module2, define2) {
// Generator state: five 32-bit xorshift words x..v plus an additive counter d.
function XorGen(seed) {
var me = this, strseed = "";
// One step: xorshift over x..v, plus a counter d that advances by 362437 per call.
// While d is still undefined during string seeding, `me.d + 362437 | 0` evaluates to 0.
me.next = function() {
var t = me.x ^ me.x >>> 2;
me.x = me.y;
me.y = me.z;
me.z = me.w;
me.w = me.v;
return (me.d = me.d + 362437 | 0) + (me.v = me.v ^ me.v << 4 ^ (t ^ t << 1)) | 0;
};
me.x = 0;
me.y = 0;
me.z = 0;
me.w = 0;
me.v = 0;
// A 32-bit integer seed initialises x directly; any other seed is stringified.
if (seed === (seed | 0)) {
me.x = seed;
} else {
strseed += seed;
}
// Mix seed characters into x; once the seed string is consumed, derive d from x.
for (var k = 0; k < strseed.length + 64; k++) {
me.x ^= strseed.charCodeAt(k) | 0;
if (k == strseed.length) {
me.d = me.x << 10 ^ me.x >>> 4;
}
me.next();
}
}
// Copy generator state (x, y, z, w, v, d) from f onto t and return t.
function copy(f, t) {
t.x = f.x;
t.y = f.y;
t.z = f.z;
t.w = f.w;
t.v = f.v;
t.d = f.d;
return t;
}
// Factory: prng() maps one 32-bit output to [0, 1) by dividing by 2^32.
function impl(seed, opts) {
var xg = new XorGen(seed), state = opts && opts.state, prng = function() {
return (xg.next() >>> 0) / 4294967296;
};
// Combine 21 high bits with a second [0, 1) draw, scaled by 2^-21; exact 0 is rejected.
prng.double = function() {
do {
var top = xg.next() >>> 11, bot = (xg.next() >>> 0) / 4294967296, result = (top + bot) / (1 << 21);
} while (result === 0);
return result;
};
prng.int32 = xg.next;
prng.quick = prng;
// prng.state() is only attached when opts.state is truthy; an object restores saved state.
if (state) {
if (typeof state == "object")
copy(state, xg);
prng.state = function() {
return copy(xg, {});
};
}
return prng;
}
// UMD-style export: CommonJS when available, then AMD, otherwise attach to `this`.
if (module2 && module2.exports) {
module2.exports = impl;
} else if (define2 && define2.amd) {
define2(function() {
return impl;
});
} else {
this.xorwow = impl;
}
})(exports, typeof module == "object" && module, typeof define == "function" && define);
});
// node_modules/@tensorflow/tfjs-backend-cpu/node_modules/seedrandom/lib/xorshift7.js
var require_xorshift7 = __commonJS((exports, module) => {
// Vendored seedrandom "xorshift7" generator, wrapped as a lazily evaluated CommonJS module.
(function(global2, module2, define2) {
// Generator state: an 8-word array x and a rotating index i (all indices masked with & 7).
function XorGen(seed) {
var me = this;
// One step: combine shifted taps at offsets 0, 1, 3, 4 and 7 from i, store the result
// at X[i], and advance i. (`w` is declared but unused here.)
me.next = function() {
var X = me.x, i = me.i, t, v, w;
t = X[i];
t ^= t >>> 7;
v = t ^ t << 24;
t = X[i + 1 & 7];
v ^= t ^ t >>> 10;
t = X[i + 3 & 7];
v ^= t ^ t >>> 3;
t = X[i + 4 & 7];
v ^= t ^ t << 7;
t = X[i + 7 & 7];
t = t ^ t << 13;
v ^= t ^ t << 9;
X[i] = v;
me.i = i + 1 & 7;
return v;
};
// Seed the 8-word state, avoid the all-zero state, then discard 256 outputs.
function init(me2, seed2) {
var j, w, X = [];
if (seed2 === (seed2 | 0)) {
w = X[0] = seed2;
} else {
seed2 = "" + seed2;
// NOTE(review): X starts empty, so `X[j + 1 & 7]` is undefined for j < 7 and
// `charCode + undefined` is NaN, which `<< 13` turns into 0. As a result, the first
// 7 seed characters never affect the state (string seeds of length <= 7 all give
// the same stream). This matches the vendored upstream code; changing it would change outputs.
for (j = 0; j < seed2.length; ++j) {
X[j & 7] = X[j & 7] << 15 ^ seed2.charCodeAt(j) + X[j + 1 & 7] << 13;
}
}
while (X.length < 8)
X.push(0);
// Find the first non-zero word; if all eight are zero, force X[7] = -1.
for (j = 0; j < 8 && X[j] === 0; ++j)
;
if (j == 8)
w = X[7] = -1;
else
w = X[j];
me2.x = X;
me2.i = 0;
for (j = 256; j > 0; --j) {
me2.next();
}
}
init(me, seed);
}
// Copy generator state (a copy of x, plus i) from f onto t and return t.
function copy(f, t) {
t.x = f.x.slice();
t.i = f.i;
return t;
}
// Factory: a null/undefined seed falls back to the current time in ms.
// prng() maps one 32-bit output to [0, 1) by dividing by 2^32.
function impl(seed, opts) {
if (seed == null)
seed = +new Date();
var xg = new XorGen(seed), state = opts && opts.state, prng = function() {
return (xg.next() >>> 0) / 4294967296;
};
// Combine 21 high bits with a second [0, 1) draw, scaled by 2^-21; exact 0 is rejected.
prng.double = function() {
do {
var top = xg.next() >>> 11, bot = (xg.next() >>> 0) / 4294967296, result = (top + bot) / (1 << 21);
} while (result === 0);
return result;
};
prng.int32 = xg.next;
prng.quick = prng;
// prng.state() is only attached when opts.state is truthy; a state with .x restores it.
if (state) {
if (state.x)
copy(state, xg);
prng.state = function() {
return copy(xg, {});
};
}
return prng;
}
// UMD-style export: CommonJS when available, then AMD, otherwise attach to `this`.
if (module2 && module2.exports) {
module2.exports = impl;
} else if (define2 && define2.amd) {
define2(function() {
return impl;
});
} else {
this.xorshift7 = impl;
}
})(exports, typeof module == "object" && module, typeof define == "function" && define);
});
// node_modules/@tensorflow/tfjs-backend-cpu/node_modules/seedrandom/lib/xor4096.js
var require_xor4096 = __commonJS((exports, module) => {
// Vendored seedrandom "xor4096" generator, wrapped as a lazily evaluated CommonJS module.
(function(global2, module2, define2) {
// Generator state: a 128-word array X, an index i (masked with & 127) and a counter w.
function XorGen(seed) {
var me = this;
// One step: w advances by 1640531527 (0x61c88647); X[i] is updated from taps at i+34
// and i+1, and the output adds a mix of w to the new word.
me.next = function() {
var w = me.w, X = me.X, i = me.i, t, v;
me.w = w = w + 1640531527 | 0;
v = X[i + 34 & 127];
t = X[i = i + 1 & 127];
v ^= v << 13;
t ^= t << 17;
v ^= v >>> 15;
t ^= t >>> 12;
v = X[i] = v ^ t;
me.i = i;
return v + (w ^ w >>> 16) | 0;
};
// Fill X from the seed (integer or string), guard against an all-zero array,
// then run 512 mixing steps.
function init(me2, seed2) {
var t, v, i, j, w, X = [], limit = 128;
if (seed2 === (seed2 | 0)) {
v = seed2;
seed2 = null;
} else {
seed2 = seed2 + "\0";
v = 0;
limit = Math.max(limit, seed2.length);
}
// j starts at -32 so that v is mixed 32 times before any word of X is written.
// i counts consecutive zero words written.
for (i = 0, j = -32; j < limit; ++j) {
if (seed2)
v ^= seed2.charCodeAt((j + 32) % seed2.length);
if (j === 0)
w = v;
v ^= v << 10;
v ^= v >>> 15;
v ^= v << 4;
v ^= v >>> 13;
if (j >= 0) {
w = w + 1640531527 | 0;
t = X[j & 127] ^= v + w;
i = t == 0 ? i + 1 : 0;
}
}
if (i >= 128) {
X[(seed2 && seed2.length || 0) & 127] = -1;
}
i = 127;
for (j = 4 * 128; j > 0; --j) {
v = X[i + 34 & 127];
t = X[i = i + 1 & 127];
v ^= v << 13;
t ^= t << 17;
v ^= v >>> 15;
t ^= t >>> 12;
X[i] = v ^ t;
}
me2.w = w;
me2.X = X;
me2.i = i;
}
init(me, seed);
}
// Copy generator state (i, w, and a copy of X) from f onto t and return t.
function copy(f, t) {
t.i = f.i;
t.w = f.w;
t.X = f.X.slice();
return t;
}
;
// Factory: a null/undefined seed falls back to the current time in ms.
// prng() maps one 32-bit output to [0, 1) by dividing by 2^32.
function impl(seed, opts) {
if (seed == null)
seed = +new Date();
var xg = new XorGen(seed), state = opts && opts.state, prng = function() {
return (xg.next() >>> 0) / 4294967296;
};
// Combine 21 high bits with a second [0, 1) draw, scaled by 2^-21; exact 0 is rejected.
prng.double = function() {
do {
var top = xg.next() >>> 11, bot = (xg.next() >>> 0) / 4294967296, result = (top + bot) / (1 << 21);
} while (result === 0);
return result;
};
prng.int32 = xg.next;
prng.quick = prng;
// prng.state() is only attached when opts.state is truthy; a state with .X restores it.
if (state) {
if (state.X)
copy(state, xg);
prng.state = function() {
return copy(xg, {});
};
}
return prng;
}
// UMD-style export: CommonJS when available, then AMD, otherwise attach to `this`.
if (module2 && module2.exports) {
module2.exports = impl;
} else if (define2 && define2.amd) {
define2(function() {
return impl;
});
} else {
this.xor4096 = impl;
}
})(exports, typeof module == "object" && module, typeof define == "function" && define);
});
// node_modules/@tensorflow/tfjs-backend-cpu/node_modules/seedrandom/lib/tychei.js
var require_tychei = __commonJS((exports, module) => {
// Vendored seedrandom "tychei" generator, wrapped as a lazily evaluated CommonJS module.
(function(global2, module2, define2) {
// Generator state: four 32-bit words a, b, c, d.
function XorGen(seed) {
var me = this, strseed = "";
// One step: two rounds of rotate/xor/subtract over (a, b, c, d); returns the new a.
me.next = function() {
var b = me.b, c = me.c, d = me.d, a = me.a;
b = b << 25 ^ b >>> 7 ^ c;
c = c - d | 0;
d = d << 24 ^ d >>> 8 ^ a;
a = a - b | 0;
me.b = b = b << 20 ^ b >>> 12 ^ c;
me.c = c = c - d | 0;
me.d = d << 16 ^ c >>> 16 ^ a;
return me.a = a - b | 0;
};
me.a = 0;
me.b = 0;
me.c = 2654435769 | 0;
me.d = 1367130551;
// Integer seeds (not limited to 32 bits) split into a = seed / 2^32 truncated and
// b = the low 32 bits; any other seed is stringified.
if (seed === Math.floor(seed)) {
me.a = seed / 4294967296 | 0;
me.b = seed | 0;
} else {
strseed += seed;
}
// XOR each seed character into b, stepping the generator; 20 extra steps run past the seed.
for (var k = 0; k < strseed.length + 20; k++) {
me.b ^= strseed.charCodeAt(k) | 0;
me.next();
}
}
// Copy generator state (a, b, c, d) from f onto t and return t.
function copy(f, t) {
t.a = f.a;
t.b = f.b;
t.c = f.c;
t.d = f.d;
return t;
}
;
// Factory: prng() maps one 32-bit output to [0, 1) by dividing by 2^32.
function impl(seed, opts) {
var xg = new XorGen(seed), state = opts && opts.state, prng = function() {
return (xg.next() >>> 0) / 4294967296;
};
// Combine 21 high bits with a second [0, 1) draw, scaled by 2^-21; exact 0 is rejected.
prng.double = function() {
do {
var top = xg.next() >>> 11, bot = (xg.next() >>> 0) / 4294967296, result = (top + bot) / (1 << 21);
} while (result === 0);
return result;
};
prng.int32 = xg.next;
prng.quick = prng;
// prng.state() is only attached when opts.state is truthy; an object restores saved state.
if (state) {
if (typeof state == "object")
copy(state, xg);
prng.state = function() {
return copy(xg, {});
};
}
return prng;
}
// UMD-style export: CommonJS when available, then AMD, otherwise attach to `this`.
if (module2 && module2.exports) {
module2.exports = impl;
} else if (define2 && define2.amd) {
define2(function() {
return impl;
});
} else {
this.tychei = impl;
}
})(exports, typeof module == "object" && module, typeof define == "function" && define);
});
// node_modules/@tensorflow/tfjs-backend-cpu/node_modules/seedrandom/seedrandom.js
var require_seedrandom = __commonJS((exports, module) => {
// Vendored seedrandom core (ARC4-based), wrapped as a lazily evaluated CommonJS module.
// Side effect at load time: installs Math.seedrandom.
(function(pool, math) {
// width = 256 (byte values); chunks = 6 bytes per initial draw; startdenom = 256^6;
// significance = 2^52 and overflow = 2^53 bound the mantissa built in prng().
// NOTE(review): global2 is `this` of a plain call, i.e. the global object only in
// non-strict code — confirm for this bundle.
var global2 = this, width = 256, chunks = 6, digits = 52, rngname = "random", startdenom = math.pow(width, chunks), significance = math.pow(2, digits), overflow = significance * 2, mask = width - 1, nodecrypto;
// Build a seeded prng. `options == true` is shorthand for {entropy: true}, which mixes the
// seed with the accumulated pool; a null/undefined seed uses autoseed().
function seedrandom(seed, options, callback) {
var key = [];
options = options == true ? {entropy: true} : options || {};
var shortseed = mixkey(flatten(options.entropy ? [seed, tostring(pool)] : seed == null ? autoseed() : seed, 3), key);
var arc4 = new ARC4(key);
// Start from 6 keystream bytes, append bytes until the value reaches 2^52, then halve it
// back below 2^53, giving a fraction in [0, 1).
var prng = function() {
var n = arc4.g(chunks), d = startdenom, x = 0;
while (n < significance) {
n = (n + x) * width;
d *= width;
x = arc4.g(1);
}
while (n >= overflow) {
n /= 2;
d /= 2;
x >>>= 1;
}
return (n + x) / d;
};
prng.int32 = function() {
return arc4.g(4) | 0;
};
prng.quick = function() {
return arc4.g(4) / 4294967296;
};
prng.double = prng;
// Feed this generator's permutation back into the shared pool.
mixkey(tostring(arc4.S), pool);
// Hand the result to options.pass / callback when given. Otherwise, when called as
// Math.seedrandom (or options.global is set), replace Math.random and return the
// short seed; else return the prng.
return (options.pass || callback || function(prng2, seed2, is_math_call, state) {
if (state) {
if (state.S) {
copy(state, arc4);
}
prng2.state = function() {
return copy(arc4, {});
};
}
if (is_math_call) {
math[rngname] = prng2;
return seed2;
} else
return prng2;
})(prng, shortseed, "global" in options ? options.global : this == math, options.state);
}
math["seed" + rngname] = seedrandom;
// ARC4 key schedule: permute S = 0..255 under the key, then define g(count), which
// returns the next `count` keystream bytes as one base-256 number. The first 256
// bytes are drawn and discarded.
function ARC4(key) {
var t, keylen = key.length, me = this, i = 0, j = me.i = me.j = 0, s = me.S = [];
if (!keylen) {
key = [keylen++];
}
while (i < width) {
s[i] = i++;
}
for (i = 0; i < width; i++) {
s[i] = s[j = mask & j + key[i % keylen] + (t = s[i])];
s[j] = t;
}
(me.g = function(count) {
var t2, r = 0, i2 = me.i, j2 = me.j, s2 = me.S;
while (count--) {
t2 = s2[i2 = mask & i2 + 1];
r = r * width + s2[mask & (s2[i2] = s2[j2 = mask & j2 + t2]) + (s2[j2] = t2)];
}
me.i = i2;
me.j = j2;
return r;
})(width);
}
// Copy ARC4 state (i, j, and a copy of S) from f onto t and return t.
function copy(f, t) {
t.i = f.i;
t.j = f.j;
t.S = f.S.slice();
return t;
}
;
// Flatten an object to nested arrays of strings, up to `depth` levels. Property reads
// that throw are skipped; non-strings become `value + "\0"`.
function flatten(obj, depth) {
var result = [], typ = typeof obj, prop;
if (depth && typ == "object") {
for (prop in obj) {
try {
result.push(flatten(obj[prop], depth - 1));
} catch (e) {
}
}
}
return result.length ? result : typ == "string" ? obj : obj + "\0";
}
// Mix the characters of String(seed) into the byte array `key` (in place) and return
// key as a string.
function mixkey(seed, key) {
var stringseed = seed + "", smear, j = 0;
while (j < stringseed.length) {
key[mask & j] = mask & (smear ^= key[mask & j] * 19) + stringseed.charCodeAt(j++);
}
return tostring(key);
}
// Entropy for unseeded use: 256 bytes from nodecrypto.randomBytes if present, otherwise
// from crypto/msCrypto.getRandomValues. If both fail, fall back to a weaker mix of the
// current time, global objects and the pool. In this bundle require_crypto is an empty
// stub, so the getRandomValues path is the one expected to run.
function autoseed() {
try {
var out;
if (nodecrypto && (out = nodecrypto.randomBytes)) {
out = out(width);
} else {
out = new Uint8Array(width);
(global2.crypto || global2.msCrypto).getRandomValues(out);
}
return tostring(out);
} catch (e) {
var browser = global2.navigator, plugins = browser && browser.plugins;
return [+new Date(), global2, plugins, global2.screen, tostring(pool)];
}
}
// Convert an array of char codes to a string.
function tostring(a) {
return String.fromCharCode.apply(0, a);
}
// Seed the shared pool from Math.random() at load time.
mixkey(math.random(), pool);
if (typeof module == "object" && module.exports) {
module.exports = seedrandom;
try {
nodecrypto = require_crypto();
} catch (ex) {
}
} else if (typeof define == "function" && define.amd) {
define(function() {
return seedrandom;
});
}
})([], Math);
});
// node_modules/@tensorflow/tfjs-backend-cpu/node_modules/seedrandom/index.js
var require_seedrandom2 = __commonJS((exports, module) => {
  // Entry point of the seedrandom package: the ARC4 factory, decorated with the alternative algorithms.
  // The alternatives are loaded first, in the same order as upstream, before the core module.
  var algorithms = {
    alea: require_alea(),
    xor128: require_xor128(),
    xorwow: require_xorwow(),
    xorshift7: require_xorshift7(),
    xor4096: require_xor4096(),
    tychei: require_tychei()
  };
  var seedrandomFactory = require_seedrandom();
  // Attach each algorithm as a property of the factory (sr.alea, sr.xor128, ...).
  Object.assign(seedrandomFactory, algorithms);
  module.exports = seedrandomFactory;
});
// node_modules/@tensorflow/tfjs-backend-cpu/dist/tf-backend-cpu.node.js
var require_tf_backend_cpu_node = __commonJS((exports) => {
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
"use strict";
Object.defineProperty(exports, "__esModule", {value: true});
var tf = require_tf_core_node();
var seedrandom = require_seedrandom2();
/*! *****************************************************************************
Copyright (c) Microsoft Corporation. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use
this file except in compliance with the License. You may obtain a copy of the
License at http://www.apache.org/licenses/LICENSE-2.0
THIS CODE IS PROVIDED ON AN *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
MERCHANTABLITY OR NON-INFRINGEMENT.
See the Apache Version 2.0 License for specific language governing permissions
and limitations under the License.
***************************************************************************** */
var extendStatics = function(d, b) {
  // First call: pick the best static-inheritance strategy for this engine, replace this function with it,
  // then delegate. Later calls go straight to the chosen strategy.
  if (Object.setPrototypeOf) {
    extendStatics = Object.setPrototypeOf;
  } else if ({__proto__: []} instanceof Array) {
    extendStatics = function(child, parent) {
      child.__proto__ = parent;
    };
  } else {
    // Last resort: copy the parent's own enumerable static properties onto the child.
    extendStatics = function(child, parent) {
      for (var key in parent) {
        if (parent.hasOwnProperty(key)) {
          child[key] = parent[key];
        }
      }
    };
  }
  return extendStatics(d, b);
};
function __extends(d, b) {
  // Static side: make the child constructor inherit from the parent constructor.
  extendStatics(d, b);
  // A null base produces a prototype with no parent at all.
  if (b === null) {
    d.prototype = Object.create(b);
    return;
  }
  // Instance side: a throwaway constructor yields an object whose [[Prototype]] is b.prototype
  // and whose own `constructor` property points back at d.
  function Surrogate() {
    this.constructor = d;
  }
  Surrogate.prototype = b.prototype;
  d.prototype = new Surrogate();
}
function __awaiter(thisArg, _arguments, P, generator) {
  // Wraps a value in a P-promise unless it already is one, so it can be chained with .then.
  function adopt(value) {
    if (value instanceof P) {
      return value;
    }
    return new P(function(resolve) {
      resolve(value);
    });
  }
  if (!P) {
    P = Promise;
  }
  // Drives the generator to completion, feeding resolved values (or rejections) back into it.
  return new P(function(resolve, reject) {
    var iterator;
    function advance(result) {
      if (result.done) {
        resolve(result.value);
      } else {
        adopt(result.value).then(onFulfilled, onRejected);
      }
    }
    function onFulfilled(value) {
      try {
        advance(iterator.next(value));
      } catch (err) {
        reject(err);
      }
    }
    function onRejected(reason) {
      try {
        advance(iterator["throw"](reason));
      } catch (err) {
        reject(err);
      }
    }
    // A synchronous throw here escapes the executor, which rejects the returned promise.
    iterator = generator.apply(thisArg, _arguments || []);
    advance(iterator.next());
  });
}
// tslib down-level generator helper: runs `body` as a state machine driven by [opcode, value] instructions.
// `body(_)` returns an instruction; `_.label` is the resume point, `_.trys` the active try regions
// ([tryLoc, catchLoc, finallyLoc, endLoc]) and `_.ops` the instructions deferred across a finally block.
// Returns an iterator object with next/throw/return (and Symbol.iterator when Symbols exist).
function __generator(thisArg, body) {
// `_.sent()` yields the value passed into next(), or rethrows the error passed into throw() (t[0] & 1).
var _ = {label: 0, sent: function() {
if (t[0] & 1)
throw t[1];
return t[1];
}, trys: [], ops: []}, f, y, t, g;
// verb(0) = next, verb(1) = throw, verb(2) = return.
return g = {next: verb(0), throw: verb(1), return: verb(2)}, typeof Symbol === "function" && (g[Symbol.iterator] = function() {
return this;
}), g;
function verb(n) {
return function(v) {
return step2([n, v]);
};
}
// Executes instructions until one suspends (opcode 4 returns {done:false}) or the machine finishes.
function step2(op) {
// `f` guards against re-entrancy while the body is running.
if (f)
throw new TypeError("Generator is already executing.");
while (_)
try {
// While delegating (`y` set by opcode 5), forward next/throw/return to the inner iterator first.
if (f = 1, y && (t = op[0] & 2 ? y["return"] : op[0] ? y["throw"] || ((t = y["return"]) && t.call(y), 0) : y.next) && !(t = t.call(y, op[1])).done)
return t;
if (y = 0, t)
op = [op[0] & 2, t.value];
switch (op[0]) {
// 0/1: value or error is delivered to the body via `_.sent()`.
case 0:
case 1:
t = op;
break;
// 4: suspend and yield op[1] to the caller.
case 4:
_.label++;
return {value: op[1], done: false};
// 5: start delegating to the iterator op[1].
case 5:
_.label++;
y = op[1];
op = [0];
continue;
// 7: leave a finally block — resume the instruction deferred when it was entered.
case 7:
op = _.ops.pop();
_.trys.pop();
continue;
// Remaining opcodes (2 return, 3 break, 6 exception) are routed through the active try region, if any.
default:
if (!(t = _.trys, t = t.length > 0 && t[t.length - 1]) && (op[0] === 6 || op[0] === 2)) {
// No enclosing try: return/throw terminates the machine.
_ = 0;
continue;
}
if (op[0] === 3 && (!t || op[1] > t[0] && op[1] < t[3])) {
_.label = op[1];
break;
}
if (op[0] === 6 && _.label < t[1]) {
// Exception inside a try body: jump to its catch block.
_.label = t[1];
t = op;
break;
}
if (t && _.label < t[2]) {
// Enter the finally block, deferring the pending instruction until opcode 7.
_.label = t[2];
_.ops.push(op);
break;
}
if (t[2])
_.ops.pop();
_.trys.pop();
continue;
}
op = body.call(thisArg, _);
} catch (e) {
// Any exception from the body or a delegate becomes an opcode-6 instruction.
op = [6, e];
y = 0;
} finally {
f = t = 0;
}
// Finished: rethrow for opcodes 1/4/5/... with bit 0 or 2 set (i.e. uncaught throws); otherwise report done.
if (op[0] & 5)
throw op[1];
return {value: op[0] ? op[1] : void 0, done: true};
}
}
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function assertNotComplex(tensor, opName) {
  // Accept a single tensor or a list; null/undefined entries are ignored.
  var candidates = Array.isArray(tensor) ? tensor : [tensor];
  for (var idx = 0; idx < candidates.length; idx++) {
    var candidate = candidates[idx];
    if (candidate == null) {
      continue;
    }
    // The message is built lazily; tf.util.assert only calls it on failure.
    tf.util.assert(candidate.dtype !== "complex64", function() {
      return opName + " does not support complex64 tensors in the CPU backend.";
    });
  }
}
/**
* @license
* Copyright 2017 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var nonMaxSuppressionV3Impl = tf.kernel_impls.nonMaxSuppressionV3Impl;
var split = tf.kernel_impls.split;
var tile = tf.kernel_impls.tile;
var topkImpl = tf.kernel_impls.topkImpl;
var whereImpl = tf.kernel_impls.whereImpl;
var MathBackendCPU = function(_super) {
__extends(MathBackendCPU2, _super);
function MathBackendCPU2() {
  // Down-levelled class constructor: the parent constructor may return a replacement object.
  var self = _super.call(this) || this;
  self.blockSize = 48;
  // Cleared by the first write(), which emits a one-time Node.js hint.
  self.firstUse = true;
  // Tensor storage for this backend; write() keys it by freshly created dataId objects.
  self.data = new tf.DataStorage(self, tf.engine());
  return self;
}
MathBackendCPU2.prototype.write = function(values, shape, dtype) {
  // Stores `values` under a new identity-keyed dataId with refCount 1 and returns that dataId.
  // `shape` is accepted for interface compatibility but not stored.
  if (this.firstUse) {
    this.firstUse = false;
    var isNode = tf.env().get("IS_NODE");
    if (isNode) {
      tf.backend_util.warn("\n============================\nHi there \u{1F44B}. Looks like you are running TensorFlow.js in Node.js. To speed things up dramatically, install our node backend, which binds to TensorFlow C++, by running npm i @tensorflow/tfjs-node, or npm i @tensorflow/tfjs-node-gpu if you have CUDA. Then call require('@tensorflow/tfjs-node'); (-gpu suffix for CUDA) at the start of your program. Visit https://github.com/tensorflow/tfjs-node for more details.\n============================");
    }
  }
  var id = {};
  var record = {values: values, dtype: dtype, refCount: 1};
  this.data.set(id, record);
  return id;
};
MathBackendCPU2.prototype.makeTensorInfo = function(shape, dtype, values) {
  // String tensors given as JS strings are stored as encoded bytes; everything else is stored as-is.
  var mustEncode = dtype === "string" && values != null && values.length > 0 && tf.util.isString(values[0]);
  var stored = mustEncode ? values.map(function(str) {
    return tf.util.encodeString(str);
  }) : values;
  var id = this.write(stored, shape, dtype);
  return {dataId: id, shape: shape, dtype: dtype};
};
MathBackendCPU2.prototype.incRef = function(dataId) {
  // Bump the reference count of an existing entry (throws if the dataId is unknown).
  this.data.get(dataId).refCount += 1;
};
MathBackendCPU2.prototype.decRef = function(dataId) {
  // Lower the reference count; unknown dataIds are ignored. No disposal happens here.
  if (!this.data.has(dataId)) {
    return;
  }
  this.data.get(dataId).refCount -= 1;
};
MathBackendCPU2.prototype.move = function(dataId, values, shape, dtype) {
  // (Re)register data under an existing dataId with a fresh refCount of 1.
  var entry = {values: values, dtype: dtype, refCount: 1};
  this.data.set(dataId, entry);
};
MathBackendCPU2.prototype.numDataIds = function() {
  // Delegates to the storage's own count.
  var storage = this.data;
  return storage.numDataIds();
};
MathBackendCPU2.prototype.read = function(dataId) {
  // Async wrapper over readSync. The read happens synchronously inside the executor,
  // so a throw from readSync becomes a rejected promise.
  var backend = this;
  return new Promise(function(resolve) {
    resolve(backend.readSync(dataId));
  });
};
MathBackendCPU2.prototype.readSync = function(dataId) {
  // Returns the stored values. complex64 data is rebuilt by interleaving its real/imag component tensors.
  var entry = this.data.get(dataId);
  var parts = entry.complexTensorInfos;
  if (entry.dtype !== "complex64") {
    return this.data.get(dataId).values;
  }
  var re = this.readSync(parts.real.dataId);
  var im = this.readSync(parts.imag.dataId);
  return tf.backend_util.mergeRealAndImagArrays(re, im);
};
MathBackendCPU2.prototype.bufferSync = function(t) {
  // Wraps a tensor's data in a TensorBuffer; string tensors are decoded from bytes back to JS strings.
  var raw = this.readSync(t.dataId);
  if (t.dtype !== "string") {
    return tf.buffer(t.shape, t.dtype, raw);
  }
  var decoded;
  try {
    decoded = raw.map(function(bytes) {
      return tf.util.decodeString(bytes);
    });
  } catch (_err) {
    throw new Error("Failed to decode encoded string bytes into utf-8");
  }
  return tf.buffer(t.shape, t.dtype, decoded);
};
MathBackendCPU2.prototype.makeOutput = function(values, shape, dtype) {
  // Store the values, then ask the engine for a Tensor backed by the new dataId on this backend.
  var id = this.write(values, shape, dtype);
  var engine = tf.engine();
  return engine.makeTensorFromDataId(id, shape, dtype, this);
};
MathBackendCPU2.prototype.disposeData = function(dataId) {
  // Remove an entry, recursively removing the real/imag parts of complex data. Ignores refCount.
  if (!this.data.has(dataId)) {
    return;
  }
  var parts = this.data.get(dataId).complexTensorInfos;
  if (parts != null) {
    this.disposeData(parts.real.dataId);
    this.disposeData(parts.imag.dataId);
  }
  this.data.delete(dataId);
};
MathBackendCPU2.prototype.disposeIntermediateTensorInfo = function(tensorInfo) {
  // Drop one reference; free the data once no references remain.
  var id = tensorInfo.dataId;
  if (!this.data.has(id)) {
    return;
  }
  var entry = this.data.get(id);
  entry.refCount -= 1;
  if (entry.refCount < 1) {
    this.disposeData(id);
  }
};
MathBackendCPU2.prototype.time = function(f) {
  // Runs `f` synchronously and resolves with its wall-clock duration; a throw from `f` rejects.
  return new Promise(function(resolve) {
    var begin = tf.util.now();
    f();
    resolve({kernelMs: tf.util.now() - begin});
  });
};
MathBackendCPU2.prototype.memory = function() {
  // The CPU backend cannot measure its memory precisely, so it always reports "unreliable".
  var explanation = "The reported memory is an upper bound. Due to automatic garbage collection, the true allocated memory may be less.";
  return {unreliable: true, reasons: [explanation]};
};
MathBackendCPU2.prototype.stridedSlice = function(x, begin, end, strides) {
  // Strided slice: output element at `loc` comes from input location loc * strides + begin.
  assertNotComplex(x, "stridedSlice");
  var outShape = tf.slice_util.computeOutShape(begin, end, strides);
  if (outShape.some(function(axis) {
    return axis === 0;
  })) {
    // Empty result. Keep the input dtype so it matches the non-empty path
    // (previously this silently produced a float32 tensor).
    return tf.tensor([], outShape, x.dtype);
  }
  var buffer = tf.buffer(outShape, x.dtype);
  var xBuf = this.bufferSync(x);
  for (var i = 0; i < buffer.size; i++) {
    var loc = buffer.indexToLoc(i);
    var newLoc = new Array(loc.length);
    for (var j = 0; j < newLoc.length; j++) {
      newLoc[j] = loc[j] * strides[j] + begin[j];
    }
    buffer.set.apply(buffer, [xBuf.get.apply(xBuf, newLoc)].concat(loc));
  }
  return buffer.toTensor();
};
MathBackendCPU2.prototype.diag = function(x) {
  // Builds an n x n matrix (n = x.size) whose main diagonal holds x's values in order.
  var n = x.size;
  var src = this.readSync(x.dataId);
  var out = tf.buffer([n, n], x.dtype);
  var dst = out.values;
  for (var k = 0; k < src.length; k++) {
    dst[k * n + k] = src[k];
  }
  return out.toTensor();
};
MathBackendCPU2.prototype.unstack = function(x, axis) {
  // Splits x along `axis` into x.shape[axis] tensors, each with that axis removed.
  var outShape = x.shape.filter(function(_dimSize, dim) {
    return dim !== axis;
  });
  var begin = x.shape.map(function() {
    return 0;
  });
  var size = x.shape.slice();
  size[axis] = 1;
  var count = x.shape[axis];
  var pieces = [];
  for (var k = 0; k < count; k++) {
    // `begin` is reused and mutated for each slice, as in the original.
    begin[axis] = k;
    pieces.push(tf.slice(x, begin, size).reshape(outShape));
  }
  return pieces;
};
MathBackendCPU2.prototype.reverse = function(x, axis) {
  // Reverses x along every axis listed in `axis`.
  assertNotComplex(x, "reverse");
  var out = tf.buffer(x.shape, x.dtype);
  var src = this.bufferSync(x);
  // Mirrors each listed coordinate in place (coord -> dimSize - 1 - coord).
  var mirror = function(loc) {
    axis.forEach(function(ax) {
      loc[ax] = x.shape[ax] - 1 - loc[ax];
    });
    return loc;
  };
  for (var flat = 0; flat < out.size; flat++) {
    var dstLoc = out.indexToLoc(flat);
    var srcLoc = mirror(dstLoc.slice());
    out.set.apply(out, [src.get.apply(src, srcLoc)].concat(dstLoc));
  }
  return out.toTensor();
};
MathBackendCPU2.prototype.neg = function(x) {
  assertNotComplex(x, "neg");
  // Negation is expressed as multiplication by the scalar -1.
  var minusOne = tf.scalar(-1);
  return tf.mul(minusOne, x);
};
MathBackendCPU2.prototype.addN = function(tensors) {
  // Element-wise sum of same-shaped tensors; output shape and dtype come from tensors[0].
  assertNotComplex(tensors, "addN");
  var self = this;
  var inputs = tensors.map(function(t) {
    return self.readSync(t.dataId);
  });
  var first = tensors[0];
  var out = tf.buffer(first.shape, first.dtype);
  var acc = out.values;
  inputs.forEach(function(vals) {
    for (var k = 0; k < acc.length; k++) {
      acc[k] += vals[k];
    }
  });
  return out.toTensor();
};
MathBackendCPU2.prototype.softmax = function(logits, dim) {
  // Numerically stable softmax: subtract the per-slice max before exponentiating.
  var axes = tf.util.parseAxisParam([dim], logits.shape);
  var maxLogit = tf.max(logits, axes);
  var keepDimsShape = tf.backend_util.expandShapeToKeepDim(maxLogit.shape, axes);
  var exps = tf.exp(tf.sub(logits, maxLogit.reshape(keepDimsShape)));
  var denominator = this.sum(exps, axes).reshape(keepDimsShape);
  return tf.div(exps, denominator);
};
MathBackendCPU2.prototype.pow = function(a, b) {
  assertNotComplex([a, b], "pow");
  // Output keeps a's dtype.
  var power = function(base, exponent) {
    return Math.pow(base, exponent);
  };
  return this.broadcastedBinaryOp(a, b, a.dtype, power);
};
MathBackendCPU2.prototype.floorDiv = function(a, b) {
  assertNotComplex([a, b], "floorDiv");
  // Division rounded toward -Infinity; the result is always int32.
  return this.broadcastedBinaryOp(a, b, "int32", function(numerator, denominator) {
    return Math.floor(numerator / denominator);
  });
};
MathBackendCPU2.prototype.sum = function(x, axes) {
  // Sums over `axes`, which must be the innermost dimensions, so each output element
  // reduces one contiguous run of reduceSize input values.
  assertNotComplex(x, "sum");
  tf.backend_util.assertAxesAreInnerMostDims("sum", axes, x.rank);
  var shapes = tf.backend_util.computeOutAndReduceShapes(x.shape, axes);
  var result = tf.zeros(shapes[0], tf.upcastType(x.dtype, "int32"));
  var reduceSize = tf.util.sizeFromShape(shapes[1]);
  var out = this.readSync(result.dataId);
  var input = this.readSync(x.dataId);
  for (var row = 0; row < out.length; ++row) {
    var start = row * reduceSize;
    var total = 0;
    for (var k = start; k < start + reduceSize; ++k) {
      total += input[k];
    }
    out[row] = total;
  }
  return result;
};
MathBackendCPU2.prototype.prod = function(x, axes) {
  // Product over `axes`. Each output element multiplies one contiguous run of reduceSize input values.
  // The complex64 guard now names this op "prod" (it previously said "sum", a copy-paste slip).
  assertNotComplex(x, "prod");
  var _a = tf.backend_util.computeOutAndReduceShapes(x.shape, axes), outShape = _a[0], reduceShape = _a[1];
  var resultDtype = tf.upcastType(x.dtype, "int32");
  var result = tf.zeros(outShape, resultDtype);
  var reduceSize = tf.util.sizeFromShape(reduceShape);
  var vals = this.readSync(result.dataId);
  var aVals = this.readSync(x.dataId);
  for (var i = 0; i < vals.length; ++i) {
    var offset = i * reduceSize;
    var prod = 1;
    for (var j = 0; j < reduceSize; ++j) {
      prod *= aVals[offset + j];
    }
    vals[i] = prod;
  }
  return result;
};
MathBackendCPU2.prototype.unsortedSegmentSum = function(x, segmentIds, numSegments) {
  // For each segment id s in [0, numSegments), sums the rows of x whose id equals s, then stacks the sums.
  assertNotComplex(x, "unsortedSegmentSum");
  var ids = segmentIds;
  // Pad ids with trailing unit dims so the equality mask broadcasts against x.
  var extraDims = x.rank - segmentIds.rank;
  for (var d = 0; d < extraDims; ++d) {
    ids = ids.expandDims(d + 1);
  }
  var perSegment = [];
  for (var seg = 0; seg < numSegments; ++seg) {
    var hits = tf.equal(tf.scalar(seg, "int32"), ids).asType("float32");
    perSegment.push(hits.mul(x).sum(0));
  }
  return tf.stack(perSegment);
};
MathBackendCPU2.prototype.argMin = function(x, axis) {
  // Index of the smallest value along the innermost `axis`; ties keep the first occurrence.
  assertNotComplex(x, "argMin");
  var axes = [axis];
  tf.backend_util.assertAxesAreInnerMostDims("argMin", axes, x.rank);
  var shapes = tf.backend_util.computeOutAndReduceShapes(x.shape, axes);
  var result = tf.zeros(shapes[0], "int32");
  var reduceSize = tf.util.sizeFromShape(shapes[1]);
  var out = this.readSync(result.dataId);
  var input = this.readSync(x.dataId);
  for (var row = 0; row < out.length; ++row) {
    var base = row * reduceSize;
    var best = 0;
    for (var k = 1; k < reduceSize; ++k) {
      if (input[base + k] < input[base + best]) {
        best = k;
      }
    }
    out[row] = best;
  }
  return result;
};
MathBackendCPU2.prototype.argMax = function(x, axis) {
  // Index of the largest value along the innermost `axis`; ties keep the first occurrence.
  assertNotComplex(x, "argMax");
  var axes = [axis];
  tf.backend_util.assertAxesAreInnerMostDims("argMax", axes, x.rank);
  var shapes = tf.backend_util.computeOutAndReduceShapes(x.shape, axes);
  var result = tf.zeros(shapes[0], "int32");
  var reduceSize = tf.util.sizeFromShape(shapes[1]);
  var out = this.readSync(result.dataId);
  var input = this.readSync(x.dataId);
  for (var row = 0; row < out.length; ++row) {
    var base = row * reduceSize;
    var best = 0;
    for (var k = 1; k < reduceSize; ++k) {
      if (input[base + k] > input[base + best]) {
        best = k;
      }
    }
    out[row] = best;
  }
  return result;
};
MathBackendCPU2.prototype.cumsum = function(x, axis, exclusive, reverse) {
  // Cumulative sum along the innermost axis only. `exclusive` shifts by one (the first element becomes 0);
  // `reverse` accumulates from the end of each row.
  assertNotComplex(x, "cumsum");
  if (axis !== x.rank - 1) {
    throw new Error("backend.cumsum in CPU expects an inner-most axis=" + (x.rank - 1) + " " + ("but got axis=" + axis));
  }
  var result = tf.zeros(x.shape, tf.upcastType(x.dtype, "int32"));
  var out = this.readSync(result.dataId);
  var input = this.readSync(x.dataId);
  var width = x.shape[x.rank - 1];
  for (var rowStart = 0; rowStart < input.length; rowStart += width) {
    var prev = -1;
    for (var step = 0; step < width; step++) {
      var pos = reverse ? rowStart + width - step - 1 : rowStart + step;
      if (step === 0) {
        out[pos] = exclusive ? 0 : input[pos];
      } else {
        out[pos] = exclusive ? input[prev] + out[prev] : input[pos] + out[prev];
      }
      prev = pos;
    }
  }
  return result;
};
MathBackendCPU2.prototype.equal = function(a, b) {
  assertNotComplex([a, b], "equal");
  // 1 where a === b, else 0; bool output.
  var isSame = function(lhs, rhs) {
    return Number(lhs === rhs);
  };
  return this.broadcastedBinaryOp(a, b, "bool", isSame);
};
MathBackendCPU2.prototype.notEqual = function(a, b) {
  assertNotComplex([a, b], "notEqual");
  // 1 where a !== b, else 0; bool output.
  var differs = function(lhs, rhs) {
    return Number(lhs !== rhs);
  };
  return this.broadcastedBinaryOp(a, b, "bool", differs);
};
MathBackendCPU2.prototype.less = function(a, b) {
  assertNotComplex([a, b], "less");
  // 1 where a < b, else 0; bool output.
  var isLess = function(lhs, rhs) {
    return Number(lhs < rhs);
  };
  return this.broadcastedBinaryOp(a, b, "bool", isLess);
};
MathBackendCPU2.prototype.lessEqual = function(a, b) {
  assertNotComplex([a, b], "lessEqual");
  // 1 where a <= b, else 0; bool output.
  var isLessOrEqual = function(lhs, rhs) {
    return Number(lhs <= rhs);
  };
  return this.broadcastedBinaryOp(a, b, "bool", isLessOrEqual);
};
MathBackendCPU2.prototype.greater = function(a, b) {
  assertNotComplex([a, b], "greater");
  // 1 where a > b, else 0; bool output.
  var isGreater = function(lhs, rhs) {
    return Number(lhs > rhs);
  };
  return this.broadcastedBinaryOp(a, b, "bool", isGreater);
};
MathBackendCPU2.prototype.greaterEqual = function(a, b) {
  assertNotComplex([a, b], "greaterEqual");
  // 1 where a >= b, else 0; bool output.
  var isGreaterOrEqual = function(lhs, rhs) {
    return Number(lhs >= rhs);
  };
  return this.broadcastedBinaryOp(a, b, "bool", isGreaterOrEqual);
};
MathBackendCPU2.prototype.logicalAnd = function(a, b) {
  assertNotComplex([a, b], "logicalAnd");
  // Same result as `a && b`: the first falsy operand, otherwise the second.
  return this.broadcastedBinaryOp(a, b, "bool", function(lhs, rhs) {
    if (!lhs) {
      return lhs;
    }
    return rhs;
  });
};
MathBackendCPU2.prototype.logicalOr = function(a, b) {
  assertNotComplex([a, b], "logicalOr");
  // Same result as `a || b`: the first truthy operand, otherwise the second.
  return this.broadcastedBinaryOp(a, b, "bool", function(lhs, rhs) {
    if (lhs) {
      return lhs;
    }
    return rhs;
  });
};
MathBackendCPU2.prototype.select = function(condition, a, b) {
  // Element-wise choice between a and b: output takes a's element where condition is 1, else b's.
  // A rank-1 condition with rank>1 operands selects whole rows: each condition entry covers `offset`
  // consecutive elements. Every other layout maps one condition entry to one element (offset = 1).
  assertNotComplex([condition, a, b], "select");
  var values = this.readSync(condition.dataId);
  var aValues = this.readSync(a.dataId);
  var bValues = this.readSync(b.dataId);
  var result = tf.zeros(a.shape, tf.upcastType(a.dtype, b.dtype));
  var newValues = this.readSync(result.dataId);
  var index = 0;
  var offset = condition.rank === 0 || condition.rank > 1 || a.rank === 1 ? 1 : tf.util.sizeFromShape(a.shape.slice(1));
  for (var i = 0; i < values.length; i++) {
    var source = values[i] === 1 ? aValues : bValues;
    for (var j = 0; j < offset; j++, index++) {
      // Copy the element at the *output* position. The old code read source[i], which repeated
      // a single element across an entire row whenever offset > 1.
      newValues[index] = source[index];
    }
  }
  return result;
};
MathBackendCPU2.prototype.where = function(condition) {
  assertNotComplex([condition], "where");
  // Delegates to the shared kernel implementation (coordinates of true entries).
  var flags = this.readSync(condition.dataId);
  var shape = condition.shape;
  return whereImpl(shape, flags);
};
MathBackendCPU2.prototype.topk = function(x, k, sorted) {
  assertNotComplex(x, "topk");
  // Delegates to the shared top-k kernel over the raw values.
  var data = this.readSync(x.dataId);
  return topkImpl(data, x.shape, x.dtype, k, sorted);
};
MathBackendCPU2.prototype.min = function(x, axes) {
  // Minimum over innermost `axes`; the output keeps x's dtype.
  assertNotComplex(x, "min");
  tf.backend_util.assertAxesAreInnerMostDims("min", axes, x.rank);
  var shapes = tf.backend_util.computeOutAndReduceShapes(x.shape, axes);
  var result = tf.zeros(shapes[0], x.dtype);
  var reduceSize = tf.util.sizeFromShape(shapes[1]);
  var out = this.readSync(result.dataId);
  var input = this.readSync(x.dataId);
  for (var row = 0; row < out.length; ++row) {
    var base = row * reduceSize;
    var smallest = input[base];
    for (var k = 1; k < reduceSize; ++k) {
      var candidate = input[base + k];
      if (candidate < smallest) {
        smallest = candidate;
      }
    }
    out[row] = smallest;
  }
  return result;
};
MathBackendCPU2.prototype.minimum = function(a, b) {
  assertNotComplex([a, b], "minimum");
  // Element-wise Math.min; output keeps a's dtype.
  var smaller = function(lhs, rhs) {
    return Math.min(lhs, rhs);
  };
  return this.broadcastedBinaryOp(a, b, a.dtype, smaller);
};
MathBackendCPU2.prototype.mod = function(a, b) {
  assertNotComplex([a, b], "mod");
  // Floored modulo: the result takes the divisor's sign (JS `%` takes the dividend's sign).
  return this.broadcastedBinaryOp(a, b, a.dtype, function(dividend, divisor) {
    var rem = dividend % divisor;
    var signsDiffer = (dividend < 0 || divisor < 0) && !(dividend < 0 && divisor < 0);
    return signsDiffer ? (rem + divisor) % divisor : rem;
  });
};
MathBackendCPU2.prototype.maximum = function(a, b) {
  assertNotComplex([a, b], "maximum");
  // Element-wise Math.max; output keeps a's dtype.
  var larger = function(lhs, rhs) {
    return Math.max(lhs, rhs);
  };
  return this.broadcastedBinaryOp(a, b, a.dtype, larger);
};
MathBackendCPU2.prototype.all = function(x, axes) {
  // Logical AND over innermost `axes` (folded with `&&`); the output keeps x's dtype.
  assertNotComplex(x, "all");
  tf.backend_util.assertAxesAreInnerMostDims("all", axes, x.rank);
  var shapes = tf.backend_util.computeOutAndReduceShapes(x.shape, axes);
  var result = tf.zeros(shapes[0], x.dtype);
  var reduceSize = tf.util.sizeFromShape(shapes[1]);
  var out = this.readSync(result.dataId);
  var input = this.readSync(x.dataId);
  for (var row = 0; row < out.length; ++row) {
    var base = row * reduceSize;
    var acc = input[base];
    for (var k = 1; k < reduceSize; ++k) {
      acc = acc && input[base + k];
    }
    out[row] = acc;
  }
  return result;
};
MathBackendCPU2.prototype.any = function(x, axes) {
  // Logical OR over innermost `axes` (folded with `||`); the output keeps x's dtype.
  assertNotComplex(x, "any");
  tf.backend_util.assertAxesAreInnerMostDims("any", axes, x.rank);
  var shapes = tf.backend_util.computeOutAndReduceShapes(x.shape, axes);
  var result = tf.zeros(shapes[0], x.dtype);
  var reduceSize = tf.util.sizeFromShape(shapes[1]);
  var out = this.readSync(result.dataId);
  var input = this.readSync(x.dataId);
  for (var row = 0; row < out.length; ++row) {
    var base = row * reduceSize;
    var acc = input[base];
    for (var k = 1; k < reduceSize; ++k) {
      acc = acc || input[base + k];
    }
    out[row] = acc;
  }
  return result;
};
MathBackendCPU2.prototype.squaredDifference = function(a, b) {
  assertNotComplex([a, b], "squaredDifference");
  // (a - b)^2 element-wise; output keeps a's dtype.
  var squaredGap = function(lhs, rhs) {
    var gap = lhs - rhs;
    return gap * gap;
  };
  return this.broadcastedBinaryOp(a, b, a.dtype, squaredGap);
};
MathBackendCPU2.prototype.eluDer = function(dy, y) {
  // Gradient of ELU expressed in terms of its output y:
  //   y >= 0 (input x >= 0): dL/dx = dy
  //   y <  0 (input x <  0): dL/dx = dy * exp(x) = dy * (y + 1)
  // The previous test `y >= 1` wrongly scaled the gradient for outputs in (0, 1).
  assertNotComplex([dy, y], "eluDer");
  var resultValues = new Float32Array(y.size);
  var values = this.readSync(y.dataId);
  var dyValues = this.readSync(dy.dataId);
  for (var i = 0; i < values.length; ++i) {
    var v = values[i];
    resultValues[i] = v < 0 ? dyValues[i] * (v + 1) : dyValues[i];
  }
  return this.makeOutput(resultValues, y.shape, "float32");
};
MathBackendCPU2.prototype.atan2 = function(a, b) {
  assertNotComplex([a, b], "atan2");
  // Element-wise Math.atan2(a, b); output keeps a's dtype.
  var angle = function(lhs, rhs) {
    return Math.atan2(lhs, rhs);
  };
  return this.broadcastedBinaryOp(a, b, a.dtype, angle);
};
MathBackendCPU2.prototype.tile = function(x, reps) {
  assertNotComplex(x, "tile");
  // The shared kernel implementation works on a TensorBuffer.
  var source = this.bufferSync(x);
  return tile(source, reps);
};
MathBackendCPU2.prototype.gather = function(x, indices, axis) {
  // Picks slices of x along `axis` at the positions listed in `indices` (read as a flat list).
  assertNotComplex([x, indices], "gather");
  var outShape = x.shape.slice();
  var picked = this.readSync(indices.dataId);
  outShape[axis] = picked.length;
  var out = tf.buffer(outShape, x.dtype);
  var src = this.bufferSync(x);
  for (var flat = 0; flat < out.size; ++flat) {
    var loc = out.indexToLoc(flat);
    var srcLoc = loc.slice();
    srcLoc[axis] = picked[loc[axis]];
    out.values[flat] = src.values[src.locToIndex(srcLoc)];
  }
  return out.toTensor();
};
MathBackendCPU2.prototype.batchToSpaceND = function(x, blockShape, crops) {
  // Implemented as reshape -> transpose -> reshape -> crop, with shapes computed by backend_util helpers.
  assertNotComplex([x], "batchToSpaceND");
  var bu = tf.backend_util;
  var blockSize = blockShape.reduce(function(acc, dim) {
    return acc * dim;
  });
  var spatialRank = blockShape.length;
  var reshaped = bu.getReshaped(x.shape, blockShape, blockSize);
  var permuted = bu.getPermuted(reshaped.length, spatialRank);
  var reshapedPermuted = bu.getReshapedPermuted(x.shape, blockShape, blockSize);
  var sliceBegin = bu.getSliceBeginCoords(crops, spatialRank);
  var sliceSize = bu.getSliceSize(reshapedPermuted, crops, spatialRank);
  var transposed = tf.transpose(x.reshape(reshaped), permuted);
  return transposed.reshape(reshapedPermuted).slice(sliceBegin, sliceSize);
};
// Generic 3-D pooling kernel shared by maxPool3d / avgPool3d; `poolType` is expected to be "max" or "avg".
// x is indexed as [batch, depth, row, col, channel] via x.strides (channels-last); the output uses convInfo.outShape.
// Each window honours stride, dilation and the leading padding (front/top/left) from convInfo.padInfo.
MathBackendCPU2.prototype.pool3d = function(x, convInfo, poolType) {
assertNotComplex(x, "pool3d");
var strideDepth = convInfo.strideDepth;
var strideHeight = convInfo.strideHeight;
var strideWidth = convInfo.strideWidth;
var dilationDepth = convInfo.dilationDepth;
var dilationHeight = convInfo.dilationHeight;
var dilationWidth = convInfo.dilationWidth;
var effectiveFilterDepth = convInfo.effectiveFilterDepth;
var effectiveFilterHeight = convInfo.effectiveFilterHeight;
var effectiveFilterWidth = convInfo.effectiveFilterWidth;
var padFront = convInfo.padInfo.front;
var padTop = convInfo.padInfo.top;
var padLeft = convInfo.padInfo.left;
// Starting value for max pooling; avg pooling accumulates separately in avgValue/count.
var initialValue = poolType === "max" ? Number.NEGATIVE_INFINITY : Number.POSITIVE_INFINITY;
var xValues = this.readSync(x.dataId);
var output = tf.buffer(convInfo.outShape, x.dtype);
var outputVals = output.values;
// Row-major strides of the 5-D output buffer.
var outputBatchStrides = convInfo.outShape[1] * convInfo.outShape[2] * convInfo.outShape[3] * convInfo.outShape[4];
var outputDepthStrides = convInfo.outShape[2] * convInfo.outShape[3] * convInfo.outShape[4];
var outputRowStrides = convInfo.outShape[3] * convInfo.outShape[4];
var outputColStrides = convInfo.outShape[4];
for (var batch = 0; batch < convInfo.batchSize; ++batch) {
var outputBatchOffset = batch * outputBatchStrides;
var inputBatchOffset = batch * x.strides[0];
for (var channel = 0; channel < convInfo.inChannels; ++channel) {
for (var yDepth = 0; yDepth < convInfo.outDepth; ++yDepth) {
var xDepthCorner = yDepth * strideDepth - padFront;
// Skip taps that fall into the leading padding, stepping by the dilation so taps stay aligned.
var xDepthMin = xDepthCorner;
while (xDepthMin < 0) {
xDepthMin += dilationDepth;
}
var xDepthMax = Math.min(convInfo.inDepth, effectiveFilterDepth + xDepthCorner);
var outputDepthOffset = outputBatchOffset + yDepth * outputDepthStrides;
for (var yRow = 0; yRow < convInfo.outHeight; ++yRow) {
var xRowCorner = yRow * strideHeight - padTop;
var xRowMin = xRowCorner;
while (xRowMin < 0) {
xRowMin += dilationHeight;
}
var xRowMax = Math.min(convInfo.inHeight, effectiveFilterHeight + xRowCorner);
var outputRowOffset = outputDepthOffset + yRow * outputRowStrides;
for (var yCol = 0; yCol < convInfo.outWidth; ++yCol) {
var xColCorner = yCol * strideWidth - padLeft;
var xColMin = xColCorner;
while (xColMin < 0) {
xColMin += dilationWidth;
}
var xColMax = Math.min(convInfo.inWidth, effectiveFilterWidth + xColCorner);
var outputColOffset = outputRowOffset + yCol * outputColStrides;
var minMaxValue = initialValue;
var avgValue = 0;
var count = 0;
for (var xDepth = xDepthMin; xDepth < xDepthMax; xDepth += dilationDepth) {
var xDepthOffset = inputBatchOffset + xDepth * x.strides[1];
for (var xRow = xRowMin; xRow < xRowMax; xRow += dilationHeight) {
var xRowOffset = xDepthOffset + xRow * x.strides[2];
for (var xCol = xColMin; xCol < xColMax; xCol += dilationWidth) {
var xColOffset = xRowOffset + xCol * x.strides[3];
var pixel = xValues[xColOffset + channel];
if (poolType === "max" && pixel > minMaxValue) {
minMaxValue = pixel;
} else if (poolType === "avg") {
avgValue += pixel;
count++;
}
// NOTE(review): these NaN early-exits look unreachable. minMaxValue starts at +/-Infinity and only changes
// when `pixel > minMaxValue`, which is false for NaN, so NaN pixels are skipped rather than propagated in max mode.
if (isNaN(minMaxValue)) {
break;
}
}
if (isNaN(minMaxValue)) {
break;
}
}
if (isNaN(minMaxValue)) {
break;
}
}
var outputOffset = outputColOffset + channel;
// For "avg", a window with no in-bounds taps gives count === 0 and therefore NaN.
outputVals[outputOffset] = poolType === "avg" ? avgValue / count : minMaxValue;
}
}
}
}
}
return output.toTensor();
};
MathBackendCPU2.prototype.avgPool3d = function(x, convInfo) {
  assertNotComplex(x, "avgPool3d");
  // Average pooling via the shared 3-D kernel, cast to float.
  var pooled = this.pool3d(x, convInfo, "avg");
  return pooled.toFloat();
};
// Gradient of avgPool3d w.r.t. its input. For every input position (dx), sums the dy entries of all pooling
// windows that covered it, then scales by 1 / (filterDepth * filterHeight * filterWidth).
// dy is read with 5-D coordinates (batch, depth, row, col, channel); the result is a float32 tensor shaped like x.
MathBackendCPU2.prototype.avgPool3dBackprop = function(dy, x, convInfo) {
assertNotComplex([dy, x], "avgPool3dBackprop");
var strideDepth = convInfo.strideDepth;
var strideHeight = convInfo.strideHeight;
var strideWidth = convInfo.strideWidth;
var filterDepth = convInfo.filterDepth;
var filterHeight = convInfo.filterHeight;
var filterWidth = convInfo.filterWidth;
var dilationDepth = convInfo.dilationDepth;
var dilationHeight = convInfo.dilationHeight;
var dilationWidth = convInfo.dilationWidth;
var effectiveFilterDepth = convInfo.effectiveFilterDepth;
var effectiveFilterHeight = convInfo.effectiveFilterHeight;
var effectiveFilterWidth = convInfo.effectiveFilterWidth;
// Padding for the transposed traversal: (effective filter extent - 1 - forward padding) on each axis.
var padFront = effectiveFilterDepth - 1 - convInfo.padInfo.front;
var padLeft = effectiveFilterWidth - 1 - convInfo.padInfo.left;
var padTop = effectiveFilterHeight - 1 - convInfo.padInfo.top;
var dx = tf.buffer(x.shape, "float32");
// The divisor is the full (undilated) filter volume, regardless of how many taps were in bounds.
var avgMultiplier = 1 / (filterDepth * filterHeight * filterWidth);
var dyBuf = this.bufferSync(dy);
for (var batch = 0; batch < convInfo.batchSize; ++batch) {
for (var channel = 0; channel < convInfo.inChannels; ++channel) {
for (var dxDepth = 0; dxDepth < convInfo.inDepth; ++dxDepth) {
for (var dxRow = 0; dxRow < convInfo.inHeight; ++dxRow) {
for (var dxCol = 0; dxCol < convInfo.inWidth; ++dxCol) {
var dyDepthCorner = dxDepth - padFront;
var dyRowCorner = dxRow - padTop;
var dyColCorner = dxCol - padLeft;
var dotProd = 0;
for (var wDepth = 0; wDepth < effectiveFilterDepth; wDepth += dilationDepth) {
// Candidate output index; it only counts if it lands in range on an integer (stride-aligned) position.
var dyDepth = (dyDepthCorner + wDepth) / strideDepth;
if (dyDepth < 0 || dyDepth >= convInfo.outDepth || Math.floor(dyDepth) !== dyDepth) {
continue;
}
for (var wRow = 0; wRow < effectiveFilterHeight; wRow += dilationHeight) {
var dyRow = (dyRowCorner + wRow) / strideHeight;
if (dyRow < 0 || dyRow >= convInfo.outHeight || Math.floor(dyRow) !== dyRow) {
continue;
}
for (var wCol = 0; wCol < effectiveFilterWidth; wCol += dilationWidth) {
var dyCol = (dyColCorner + wCol) / strideWidth;
if (dyCol < 0 || dyCol >= convInfo.outWidth || Math.floor(dyCol) !== dyCol) {
continue;
}
var pixel = dyBuf.get(batch, dyDepth, dyRow, dyCol, channel);
dotProd += pixel;
}
}
}
dx.set(dotProd * avgMultiplier, batch, dxDepth, dxRow, dxCol, channel);
}
}
}
}
}
return dx.toTensor();
};
/**
 * 3-D max pooling on the CPU backend.
 * Rejects complex input, runs the shared `pool3d` routine in "max" mode and
 * returns its result passed through `toFloat()`.
 */
MathBackendCPU2.prototype.maxPool3d = function(x, convInfo) {
  assertNotComplex(x, "maxPool3d");
  const pooled = this.pool3d(x, convInfo, "max");
  return pooled.toFloat();
};
/**
 * For each 3-D max-pool output cell, records which filter tap produced the maximum.
 *
 * The position is a flat index into the effective (dilated) filter window:
 *   wDepth * effectiveFilterHeight * effectiveFilterWidth + wRow * effectiveFilterWidth + wCol
 * This is the same flattening `maxPool3dBackprop` uses when it decodes positions
 * (its `curPos`). Ties resolve to the last tap visited (`>=`). If no tap is visited,
 * or no value compares >= the running maximum (e.g. all NaN), the position stays -1.
 *
 * @param x 5-D input read as (batch, depth, row, col, channel).
 * @param convInfo 3-D pooling geometry.
 * @returns int32 tensor with shape `convInfo.outShape`.
 */
MathBackendCPU2.prototype.maxPool3dPositions = function(x, convInfo) {
  var maxPositions = tf.buffer(convInfo.outShape, "int32");
  var strideDepth = convInfo.strideDepth;
  var strideHeight = convInfo.strideHeight;
  var strideWidth = convInfo.strideWidth;
  var dilationDepth = convInfo.dilationDepth;
  var dilationHeight = convInfo.dilationHeight;
  var dilationWidth = convInfo.dilationWidth;
  var effectiveFilterDepth = convInfo.effectiveFilterDepth;
  var effectiveFilterHeight = convInfo.effectiveFilterHeight;
  var effectiveFilterWidth = convInfo.effectiveFilterWidth;
  var padFront = convInfo.padInfo.front;
  var padTop = convInfo.padInfo.top;
  var padLeft = convInfo.padInfo.left;
  var xBuf = this.bufferSync(x);
  for (var batch = 0; batch < convInfo.batchSize; ++batch) {
    for (var channel = 0; channel < convInfo.inChannels; ++channel) {
      for (var yDepth = 0; yDepth < convInfo.outDepth; ++yDepth) {
        var xDepthCorner = yDepth * strideDepth - padFront;
        var xDepthMin = xDepthCorner;
        // Step over leading padding in whole dilation steps so taps stay on the dilated grid.
        while (xDepthMin < 0) {
          xDepthMin += dilationDepth;
        }
        var xDepthMax = Math.min(convInfo.inDepth, effectiveFilterDepth + xDepthCorner);
        for (var yRow = 0; yRow < convInfo.outHeight; ++yRow) {
          var xRowCorner = yRow * strideHeight - padTop;
          var xRowMin = xRowCorner;
          while (xRowMin < 0) {
            xRowMin += dilationHeight;
          }
          var xRowMax = Math.min(convInfo.inHeight, effectiveFilterHeight + xRowCorner);
          for (var yCol = 0; yCol < convInfo.outWidth; ++yCol) {
            var xColCorner = yCol * strideWidth - padLeft;
            var xColMin = xColCorner;
            while (xColMin < 0) {
              xColMin += dilationWidth;
            }
            var xColMax = Math.min(convInfo.inWidth, effectiveFilterWidth + xColCorner);
            var maxValue = Number.NEGATIVE_INFINITY;
            var maxPosition = -1;
            for (var xDepth = xDepthMin; xDepth < xDepthMax; xDepth += dilationDepth) {
              var wDepth = xDepth - xDepthCorner;
              for (var xRow = xRowMin; xRow < xRowMax; xRow += dilationHeight) {
                var wRow = xRow - xRowCorner;
                for (var xCol = xColMin; xCol < xColMax; xCol += dilationWidth) {
                  var wCol = xCol - xColCorner;
                  var pixel = xBuf.get(batch, xDepth, xRow, xCol, channel);
                  if (pixel >= maxValue) {
                    maxValue = pixel;
                    // Within one depth slice the row stride is the window WIDTH; this must
                    // match the decoding in maxPool3dBackprop for non-square windows.
                    maxPosition = wDepth * effectiveFilterHeight * effectiveFilterWidth + wRow * effectiveFilterWidth + wCol;
                  }
                }
              }
            }
            maxPositions.set(maxPosition, batch, yDepth, yRow, yCol, channel);
          }
        }
      }
    }
  }
  return maxPositions.toTensor();
};
/**
 * Gradient of 3-D max pooling with respect to the pooled input `x`.
 *
 * Recomputes the argmax of every pooling window with `maxPool3dPositions`. Then, for
 * each input cell, it sums `dy` over the output cells whose recorded argmax is that cell.
 * Window positions are compared as flat indices
 *   wDepth * effectiveFilterHeight * effectiveFilterWidth + wRow * effectiveFilterWidth + wCol.
 * The window is walked in mirrored order here, so a stored position p is converted to
 * (D * H * W - 1 - p) before the comparison.
 * NOTE(review): this only lines up if maxPool3dPositions flattens rows with the same
 * stride (effectiveFilterWidth) — keep the two in sync.
 * NOTE(review): `y` is only passed to assertNotComplex, and the `maxPositions` tensor is
 * not disposed here — confirm an enclosing scope releases it.
 *
 * @param dy upstream gradient, read as (batch, depth, row, col, channel).
 * @param x forward input.
 * @param y forward output.
 * @param convInfo 3-D pooling geometry.
 * @returns float32 tensor with the shape of `x`.
 */
MathBackendCPU2.prototype.maxPool3dBackprop = function(dy, x, y, convInfo) {
assertNotComplex([x, y], "maxPool3dBackprop");
var maxPositions = this.maxPool3dPositions(x, convInfo);
var strideDepth = convInfo.strideDepth;
var strideHeight = convInfo.strideHeight;
var strideWidth = convInfo.strideWidth;
var dilationDepth = convInfo.dilationDepth;
var dilationHeight = convInfo.dilationHeight;
var dilationWidth = convInfo.dilationWidth;
var effectiveFilterDepth = convInfo.effectiveFilterDepth;
var effectiveFilterHeight = convInfo.effectiveFilterHeight;
var effectiveFilterWidth = convInfo.effectiveFilterWidth;
// Mirrored padding: walking w upwards from the input side enumerates candidate output origins.
var padFront = effectiveFilterDepth - 1 - convInfo.padInfo.front;
var padLeft = effectiveFilterWidth - 1 - convInfo.padInfo.left;
var padTop = effectiveFilterHeight - 1 - convInfo.padInfo.top;
var dx = tf.buffer(x.shape, "float32");
var maxPosBuf = this.bufferSync(maxPositions);
var dyBuf = this.bufferSync(dy);
for (var batch = 0; batch < convInfo.batchSize; ++batch) {
for (var channel = 0; channel < convInfo.inChannels; ++channel) {
for (var dxDepth = 0; dxDepth < convInfo.inDepth; ++dxDepth) {
for (var dxRow = 0; dxRow < convInfo.inHeight; ++dxRow) {
for (var dxCol = 0; dxCol < convInfo.inWidth; ++dxCol) {
var dyDepthCorner = dxDepth - padFront;
var dyRowCorner = dxRow - padTop;
var dyColCorner = dxCol - padLeft;
var dotProd = 0;
for (var wDepth = 0; wDepth < effectiveFilterDepth; wDepth += dilationDepth) {
var dyDepth = (dyDepthCorner + wDepth) / strideDepth;
// Only exact stride multiples inside the output correspond to a real output cell.
if (dyDepth < 0 || dyDepth >= convInfo.outDepth || Math.floor(dyDepth) !== dyDepth) {
continue;
}
for (var wRow = 0; wRow < effectiveFilterHeight; wRow += dilationHeight) {
var dyRow = (dyRowCorner + wRow) / strideHeight;
if (dyRow < 0 || dyRow >= convInfo.outHeight || Math.floor(dyRow) !== dyRow) {
continue;
}
for (var wCol = 0; wCol < effectiveFilterWidth; wCol += dilationWidth) {
var dyCol = (dyColCorner + wCol) / strideWidth;
if (dyCol < 0 || dyCol >= convInfo.outWidth || Math.floor(dyCol) !== dyCol) {
continue;
}
// Reverse the stored forward position because the window is traversed mirrored here.
var maxPos = effectiveFilterDepth * effectiveFilterHeight * effectiveFilterWidth - 1 - maxPosBuf.get(batch, dyDepth, dyRow, dyCol, channel);
var curPos = wDepth * effectiveFilterHeight * effectiveFilterWidth + wRow * effectiveFilterWidth + wCol;
var mask = maxPos === curPos ? 1 : 0;
if (mask === 0) {
continue;
}
var pixel = dyBuf.get(batch, dyDepth, dyRow, dyCol, channel);
dotProd += pixel * mask;
}
}
}
dx.set(dotProd, batch, dxDepth, dxRow, dxCol, channel);
}
}
}
}
}
return dx.toTensor();
};
/**
 * Bilinear resize of a 4-D tensor read as (batch, height, width, channels).
 *
 * Output pixel (r, c) samples the input at (r * rowRatio, c * colRatio), where each
 * ratio is effectiveInputSize / effectiveOutputSize. With `alignCorners` (and an output
 * extent > 1) both sizes are reduced by one, so the corner pixels of input and output
 * coincide. The floor/ceil neighbours are blended first along columns, then along rows.
 * Values are written into a Float32Array and no dtype is passed to tf.tensor.
 */
MathBackendCPU2.prototype.resizeBilinear = function(x, newHeight, newWidth, alignCorners) {
assertNotComplex(x, "resizeBilinear");
var _a = x.shape, batch = _a[0], oldHeight = _a[1], oldWidth = _a[2], numChannels = _a[3];
var xValues = this.readSync(x.dataId);
var result = new Float32Array(tf.util.sizeFromShape([batch, newHeight, newWidth, numChannels]));
var effectiveInputSize = [
alignCorners && newHeight > 1 ? oldHeight - 1 : oldHeight,
alignCorners && newWidth > 1 ? oldWidth - 1 : oldWidth
];
var effectiveOutputSize = [
alignCorners && newHeight > 1 ? newHeight - 1 : newHeight,
alignCorners && newWidth > 1 ? newWidth - 1 : newWidth
];
var outputIdx = 0;
var effectiveRowSizeRatio = effectiveInputSize[0] / effectiveOutputSize[0];
var effectiveColSizeRatio = effectiveInputSize[1] / effectiveOutputSize[1];
for (var b = 0; b < batch; b++) {
for (var r = 0; r < newHeight; r++) {
var sourceFracRow = effectiveRowSizeRatio * r;
var sourceRowFloor = Math.floor(sourceFracRow);
var rowFrac = sourceFracRow - sourceRowFloor;
// Clamp so the last output row never reads past the final input row.
var sourceRowCeil = Math.min(oldHeight - 1, Math.ceil(sourceFracRow));
var topRowOffset = b * x.strides[0] + sourceRowFloor * x.strides[1];
var botRowOffset = b * x.strides[0] + sourceRowCeil * x.strides[1];
for (var c = 0; c < newWidth; c++) {
var sourceFracCol = effectiveColSizeRatio * c;
var sourceColFloor = Math.floor(sourceFracCol);
var colFrac = sourceFracCol - sourceColFloor;
var sourceColCeil = Math.min(oldWidth - 1, Math.ceil(sourceFracCol));
var topLeftOffest = topRowOffset + sourceColFloor * x.strides[2];
var botLeftOffset = botRowOffset + sourceColFloor * x.strides[2];
var topRightOffset = topRowOffset + sourceColCeil * x.strides[2];
var botRightOffest = botRowOffset + sourceColCeil * x.strides[2];
for (var d = 0; d < numChannels; d++) {
var topLeft = xValues[topLeftOffest + d];
var bottomLeft = xValues[botLeftOffset + d];
var topRight = xValues[topRightOffset + d];
var bottomRight = xValues[botRightOffest + d];
// Interpolate horizontally on both rows, then vertically between them.
var top_1 = topLeft + (topRight - topLeft) * colFrac;
var bottom = bottomLeft + (bottomRight - bottomLeft) * colFrac;
var newValue = top_1 + (bottom - top_1) * rowFrac;
result[outputIdx++] = newValue;
}
}
}
}
return tf.tensor(result, [batch, newHeight, newWidth, numChannels]);
};
/**
 * Gradient of bilinear resize with respect to the resized input `x`.
 *
 * Each upstream value dy[b, r, c, d] is distributed to the (up to) four input pixels
 * that the forward pass blends for output pixel (r, c), using the same bilinear
 * weights. The scale factors mirror the forward pass: with `alignCorners` (and an
 * output extent > 1) both extents are reduced by one.
 *
 * @param dy upstream gradient, read as (batch, yHeight, yWidth, depth). Its values are
 *   consumed in flat order, so it is assumed to be densely laid out in that order.
 * @param x forward input, read as (batch, xHeight, xWidth, depth).
 * @param alignCorners must match the forward call.
 * @returns tensor with the same shape as `x` and dtype `x.dtype`.
 */
MathBackendCPU2.prototype.resizeBilinearBackprop = function(dy, x, alignCorners) {
  assertNotComplex([dy, x], "resizeBilinearBackprop");
  var _a = x.shape, batch = _a[0], xHeight = _a[1], xWidth = _a[2], depth = _a[3];
  var _b = dy.shape, yHeight = _b[1], yWidth = _b[2];
  var output = new Float32Array(batch * xHeight * xWidth * depth);
  var effectiveXSize = [
    alignCorners && yHeight > 1 ? xHeight - 1 : xHeight,
    alignCorners && yWidth > 1 ? xWidth - 1 : xWidth
  ];
  var effectiveYSize = [
    alignCorners && yHeight > 1 ? yHeight - 1 : yHeight,
    alignCorners && yWidth > 1 ? yWidth - 1 : yWidth
  ];
  var heightScale = effectiveXSize[0] / effectiveYSize[0];
  var widthScale = effectiveXSize[1] / effectiveYSize[1];
  var dyValues = this.readSync(dy.dataId);
  var offset = 0;
  for (var b = 0; b < batch; b++) {
    var bOffset = b * x.strides[0];
    for (var r = 0; r < yHeight; r++) {
      var dxR = r * heightScale;
      var topDxRIndex = Math.floor(dxR);
      var bottomDxRIndex = Math.min(Math.ceil(dxR), xHeight - 1);
      var topDxROffset = bOffset + topDxRIndex * x.strides[1];
      var bottomDxROffset = bOffset + bottomDxRIndex * x.strides[1];
      var dxRLerp = dxR - topDxRIndex;
      var inverseDxRLerp = 1 - dxRLerp;
      for (var c = 0; c < yWidth; c++) {
        var dxC = c * widthScale;
        var leftDxCIndex = Math.floor(dxC);
        var rightDxCIndex = Math.min(Math.ceil(dxC), xWidth - 1);
        var dxCLerp = dxC - leftDxCIndex;
        var inverseDxCLerp = 1 - dxCLerp;
        var topLeftRCOffset = topDxROffset + leftDxCIndex * x.strides[2];
        var topRightRCOffset = topDxROffset + rightDxCIndex * x.strides[2];
        var bottomLeftRCOffset = bottomDxROffset + leftDxCIndex * x.strides[2];
        var bottomRightRCOffset = bottomDxROffset + rightDxCIndex * x.strides[2];
        // Bilinear weights of the four contributing input pixels (they sum to 1).
        var inverseDxRLerpTimesInverseDxCLerp = inverseDxRLerp * inverseDxCLerp;
        var inverseDxRLerpTimesDxCLerp = inverseDxRLerp * dxCLerp;
        var dxRLerpTimesInverseDxCLerp = dxRLerp * inverseDxCLerp;
        var dxRLerpTimesDxCLerp = dxRLerp * dxCLerp;
        for (var d = 0; d < depth; d++) {
          var dyVal = dyValues[offset++];
          output[topLeftRCOffset + d] += dyVal * inverseDxRLerpTimesInverseDxCLerp;
          output[topRightRCOffset + d] += dyVal * inverseDxRLerpTimesDxCLerp;
          output[bottomLeftRCOffset + d] += dyVal * dxRLerpTimesInverseDxCLerp;
          output[bottomRightRCOffset + d] += dyVal * dxRLerpTimesDxCLerp;
        }
      }
    }
  }
  // The gradient has exactly x's layout: height before width (previously the two
  // were swapped here, which mis-shaped the result for non-square inputs).
  return tf.tensor4d(output, [batch, xHeight, xWidth, depth], x.dtype);
};
/**
 * Nearest-neighbour resize of a 4-D tensor read as (batch, height, width, channels).
 *
 * The source coordinate of output index i is i * (effectiveIn / effectiveOut). When
 * `alignCorners` is set (and the output extent is > 1), both sizes are reduced by one.
 * The fractional coordinate is rounded when `alignCorners` is truthy and floored
 * otherwise, then clamped to the last valid input row/column.
 *
 * @returns tensor of shape [batch, newHeight, newWidth, channels] with dtype `x.dtype`.
 */
MathBackendCPU2.prototype.resizeNearestNeighbor = function(x, newHeight, newWidth, alignCorners) {
  assertNotComplex(x, "resizeNearestNeighbor");
  const [batch, oldHeight, oldWidth, numChannels] = x.shape;
  const xValues = this.readSync(x.dataId);
  const output = new Float32Array(batch * newHeight * newWidth * numChannels);
  const alignRows = alignCorners && newHeight > 1;
  const alignCols = alignCorners && newWidth > 1;
  const rowRatio = (alignRows ? oldHeight - 1 : oldHeight) / (alignRows ? newHeight - 1 : newHeight);
  const colRatio = (alignCols ? oldWidth - 1 : oldWidth) / (alignCols ? newWidth - 1 : newWidth);
  const toSourceIndex = alignCorners ? Math.round : Math.floor;
  let writeIdx = 0;
  for (let b = 0; b < batch; b++) {
    const batchBase = b * x.strides[0];
    for (let row = 0; row < newHeight; row++) {
      const srcRow = Math.min(oldHeight - 1, toSourceIndex(rowRatio * row));
      const rowBase = batchBase + srcRow * x.strides[1];
      for (let col = 0; col < newWidth; col++) {
        const srcCol = Math.min(oldWidth - 1, toSourceIndex(colRatio * col));
        const pixelBase = rowBase + srcCol * x.strides[2];
        for (let ch = 0; ch < numChannels; ch++) {
          output[writeIdx++] = xValues[pixelBase + ch];
        }
      }
    }
  }
  return tf.tensor(output, [batch, newHeight, newWidth, numChannels], x.dtype);
};
/**
 * Gradient of nearest-neighbour resize with respect to the resized input `x`.
 *
 * Input pixel (r, c) receives the sum of `dy` over every output pixel whose nearest
 * source pixel is (r, c). The source pixel is computed with the same rounding and
 * clamping as the forward resize. Only a window of ceil(1 / scale) * 2 + 2 output
 * rows/cols around the approximate preimage is examined, rather than all of `dy`.
 *
 * @param dy upstream gradient, read as (batch, yHeight, yWidth, depth).
 * @param x forward input, read as (batch, xHeight, xWidth, depth).
 * @param alignCorners must match the forward call.
 * @returns tensor with `x.shape` and `x.dtype`.
 */
MathBackendCPU2.prototype.resizeNearestNeighborBackprop = function(dy, x, alignCorners) {
  assertNotComplex([dy, x], "resizeNearestNeighborBackprop");
  var _a = x.shape, batch = _a[0], xHeight = _a[1], xWidth = _a[2], depth = _a[3];
  var _b = dy.shape, yHeight = _b[1], yWidth = _b[2];
  var output = new Float32Array(batch * xHeight * xWidth * depth);
  var dyValues = this.readSync(dy.dataId);
  var effectiveXSize = [
    alignCorners && yHeight > 1 ? xHeight - 1 : xHeight,
    alignCorners && yWidth > 1 ? xWidth - 1 : xWidth
  ];
  var effectiveYSize = [
    alignCorners && yHeight > 1 ? yHeight - 1 : yHeight,
    alignCorners && yWidth > 1 ? yWidth - 1 : yWidth
  ];
  var heightScale = effectiveXSize[0] / effectiveYSize[0];
  var widthScale = effectiveXSize[1] / effectiveYSize[1];
  var invHeightScale = 1 / heightScale;
  var invWidthScale = 1 / widthScale;
  // Search window (in dy rows/cols) wide enough to contain every dy pixel mapping here.
  var winHeight = Math.ceil(invHeightScale) * 2 + 2;
  var winWidth = Math.ceil(invWidthScale) * 2 + 2;
  for (var b = 0; b < batch; b++) {
    var batchOffset = b * x.strides[0];
    // dy has its own batch stride (yHeight * yWidth * depth). Using x's batch stride here
    // read the wrong dy elements for every batch after the first.
    var dyBatchOffset = b * dy.strides[0];
    for (var r = 0; r < xHeight; r++) {
      var rowOffset = batchOffset + r * x.strides[1];
      var startRLerp = Math.floor(r * invHeightScale);
      var startDyR = Math.floor(startRLerp - winHeight / 2);
      for (var c = 0; c < xWidth; c++) {
        var colOffset = rowOffset + c * x.strides[2];
        var startCLerp = Math.floor(c * invWidthScale);
        var startDyC = Math.floor(startCLerp - winWidth / 2);
        for (var d = 0; d < depth; d++) {
          var accum = 0;
          for (var dyRIndex = 0; dyRIndex < winHeight; dyRIndex++) {
            var dyR = dyRIndex + startDyR;
            if (dyR < 0 || dyR >= yHeight) {
              continue;
            }
            var dyROffset = dyBatchOffset + dyR * dy.strides[1];
            var sourceFracRow = dyR * heightScale;
            var sourceNearestRow = Math.min(xHeight - 1, alignCorners ? Math.round(sourceFracRow) : Math.floor(sourceFracRow));
            if (r !== sourceNearestRow) {
              continue;
            }
            for (var dyCIndex = 0; dyCIndex < winWidth; dyCIndex++) {
              var dyC = dyCIndex + startDyC;
              if (dyC < 0 || dyC >= yWidth) {
                continue;
              }
              var dyCOffset = dyROffset + dyC * dy.strides[2];
              var sourceFracCol = dyC * widthScale;
              var sourceNearestCol = Math.min(xWidth - 1, alignCorners ? Math.round(sourceFracCol) : Math.floor(sourceFracCol));
              if (c === sourceNearestCol) {
                accum += dyValues[dyCOffset + d];
              }
            }
          }
          output[colOffset + d] = accum;
        }
      }
    }
  }
  return tf.tensor4d(output, x.shape, x.dtype);
};
/**
 * Local response normalization across the channel (last) axis of a 4-D tensor.
 *
 * out[i] = x[i] * (bias + alpha * sum(x[k]^2))^-beta, where k runs over the channels
 * within `depthRadius` of i's channel (clipped to the valid channel range) at the
 * same pixel.
 */
MathBackendCPU2.prototype.localResponseNormalization4D = function(x, depthRadius, bias, alpha, beta) {
  assertNotComplex(x, "localResponseNormalization4D");
  const channels = x.shape[3];
  const lastChannel = channels - 1;
  const xValues = this.readSync(x.dataId);
  const size = x.size;
  const result = new Float32Array(size);
  for (let idx = 0; idx < size; idx++) {
    const ch = idx % channels;
    const pixelStart = idx - ch;
    const lo = pixelStart + Math.max(0, ch - depthRadius);
    const hi = pixelStart + Math.min(ch + depthRadius, lastChannel);
    let sqSum = 0;
    for (let k = lo; k <= hi; k++) {
      const v = xValues[k];
      sqSum += v * v;
    }
    result[idx] = xValues[idx] * Math.pow(bias + alpha * sqSum, -beta);
  }
  return tf.tensor4d(result, x.shape);
};
/**
 * Gradient of local response normalization.
 *
 * For each element `offset` (channel c) it computes
 *   norm = bias + alpha * sum_{k in window} input[k]^2
 * over channels [max(0, c - depthRadius), min(channels - 1, c + depthRadius)] of the
 * same pixel. It then adds to every input element k in that window:
 *   dy[offset] * (delta(offset, k) * norm^-beta - 2 * alpha * beta * input[k] * output[offset] / norm)
 * NOTE(review): only `dy` goes through assertNotComplex.
 *
 * @returns float32 tensor with the shape of `dy`.
 */
MathBackendCPU2.prototype.LRNGrad = function(dy, inputImage, outputImage, depthRadius, bias, alpha, beta) {
assertNotComplex(dy, "LRNGrad");
var channels = dy.shape[3];
var dyValues = this.readSync(dy.dataId);
var inputImageValues = this.readSync(inputImage.dataId);
var outputImageValues = this.readSync(outputImage.dataId);
var result = new Float32Array(dy.size);
var size = dy.size;
for (var offset = 0; offset < size; offset++) {
var currentChannel = offset % channels;
// [depthBegin, depthEnd) is the channel window of this pixel, as flat offsets.
var depthBegin = offset - currentChannel + Math.max(0, currentChannel - depthRadius);
var depthEnd = offset - currentChannel + Math.min(channels, currentChannel + depthRadius + 1);
var norm = 0;
for (var k = depthBegin; k < depthEnd; k++) {
norm += Math.pow(inputImageValues[k], 2);
}
norm = alpha * norm + bias;
for (var k = depthBegin; k < depthEnd; k++) {
var dyi = -2 * alpha * beta * inputImageValues[k] * outputImageValues[offset] / norm;
// The direct (diagonal) term exists only for the element being normalized.
if (offset === k) {
dyi += Math.pow(norm, -beta);
}
dyi *= dyValues[offset];
result[k] += dyi;
}
}
return tf.tensor4d(result, dy.shape);
};
/**
 * Draws `numSamples` event indices per batch row from a categorical distribution.
 *
 * When `normalized` is false, `logits` are turned into probabilities with tf.softmax.
 * For each row, a running CDF over the first numEvents - 1 probabilities is built.
 * Each draw r = random() maps to the first event whose CDF value exceeds r, or to the
 * last event (index numEvents - 1) when none does, so the final probability is implied.
 *
 * NOTE(review): the generator is re-created from `seed` for every batch row, so all
 * rows consume the same random sequence; `seed` must be non-null (toString is called).
 * NOTE(review): the tensor returned by tf.softmax is not disposed here — confirm an
 * enclosing scope releases it.
 *
 * @returns int32 tensor of shape [batchSize, numSamples].
 */
MathBackendCPU2.prototype.multinomial = function(logits, normalized, numSamples, seed) {
assertNotComplex(logits, "multinomial");
var probabilities = normalized ? logits : tf.softmax(logits);
var batchSize = probabilities.shape[0];
var numEvents = probabilities.shape[1];
var res = tf.zeros([batchSize, numSamples], "int32");
var resVals = this.readSync(res.dataId);
var probVals = this.readSync(probabilities.dataId);
for (var b = 0; b < batchSize; ++b) {
var offset = b * numEvents;
// The last event's CDF entry is omitted; it is the fallback when no entry exceeds r.
var cdf = new Float32Array(numEvents - 1);
cdf[0] = probVals[offset];
for (var event_1 = 1; event_1 < cdf.length; ++event_1) {
cdf[event_1] = cdf[event_1 - 1] + probVals[offset + event_1];
}
var random = seedrandom.alea(seed.toString());
var outOffset = b * numSamples;
for (var sampleId = 0; sampleId < numSamples; ++sampleId) {
var r = random();
resVals[outOffset + sampleId] = cdf.length;
for (var event_2 = 0; event_2 < cdf.length; event_2++) {
if (r < cdf[event_2]) {
resVals[outOffset + sampleId] = event_2;
break;
}
}
}
}
return res;
};
/**
 * One-hot encodes `indices` into a [indices.size, depth] int32 tensor.
 * Every cell holds `offValue`, except cell (row, indices[row]), which holds `onValue`.
 * Indices outside [0, depth) leave their row entirely `offValue`.
 */
MathBackendCPU2.prototype.oneHot = function(indices, depth, onValue, offValue) {
  assertNotComplex(indices, "oneHot");
  const numRows = indices.size;
  const encoded = new Float32Array(numRows * depth).fill(offValue);
  const indexValues = this.readSync(indices.dataId);
  for (let row = 0; row < numRows; ++row) {
    const hot = indexValues[row];
    const inRange = hot >= 0 && hot < depth;
    if (!inRange) {
      continue;
    }
    encoded[row * depth + hot] = onValue;
  }
  return tf.tensor2d(encoded, [numRows, depth], "int32");
};
/**
 * Non-max suppression on the CPU backend. Reads the box and score values
 * synchronously and delegates to the shared `nonMaxSuppressionV3Impl`.
 */
MathBackendCPU2.prototype.nonMaxSuppression = function(boxes, scores, maxOutputSize, iouThreshold, scoreThreshold) {
  assertNotComplex(boxes, "nonMaxSuppression");
  const [boxValues, scoreValues] = [boxes, scores].map((t) => this.readSync(t.dataId));
  return nonMaxSuppressionV3Impl(boxValues, scoreValues, maxOutputSize, iouThreshold, scoreThreshold);
};
/**
 * Rearranges depth into spatial blocks (NHWC only on the CPU backend).
 *
 * Output pixel (h, w) takes its channels from input pixel
 * (floor(h / blockSize), floor(w / blockSize)), starting at depth offset
 * ((h % blockSize) * blockSize + (w % blockSize)) * outputDepth.
 * The input values are indexed as a dense NHWC array.
 *
 * @returns float32 tensor of shape
 *   [batch, height * blockSize, width * blockSize, depth / blockSize^2].
 */
MathBackendCPU2.prototype.depthToSpace = function(x, blockSize, dataFormat) {
  tf.util.assert(dataFormat === "NHWC", () => "Only NHWC dataFormat supported on CPU for depthToSpace. Got " + dataFormat);
  tf.util.assert(blockSize > 1, () => "blockSize should be > 1 for depthToSpace, but was: " + blockSize);
  const [batchSize, inputHeight, inputWidth, inputDepth] = x.shape;
  const outputHeight = inputHeight * blockSize;
  const outputWidth = inputWidth * blockSize;
  const outputDepth = inputDepth / (blockSize * blockSize);
  const xValues = this.readSync(x.dataId);
  const result = new Float32Array(batchSize * outputHeight * outputWidth * outputDepth);
  let writeIdx = 0;
  for (let b = 0; b < batchSize; ++b) {
    for (let h = 0; h < outputHeight; ++h) {
      const srcH = Math.floor(h / blockSize);
      const subH = h % blockSize;
      for (let w = 0; w < outputWidth; ++w) {
        const srcW = Math.floor(w / blockSize);
        const subW = w % blockSize;
        // The channels of one output pixel are a contiguous run of the source pixel's depth.
        const depthBase = (subH * blockSize + subW) * outputDepth;
        const pixelBase = inputDepth * (srcW + inputWidth * (srcH + inputHeight * b));
        for (let d = 0; d < outputDepth; ++d) {
          result[writeIdx++] = xValues[pixelBase + depthBase + d];
        }
      }
    }
  }
  return tf.tensor4d(result, [batchSize, outputHeight, outputWidth, outputDepth]);
};
/**
 * Applies the element-wise binary function `op` to `a` and `b` with broadcasting.
 *
 * If neither operand has broadcast dimensions, a flat loop is used. Each operand is
 * read modulo its own length, which covers a lower-size operand such as a rank-0
 * scalar. Otherwise every output index is converted to a location, the broadcast
 * dims of each operand are zeroed, and the operands are read at those locations.
 *
 * @returns tensor of the broadcast shape with dtype `dtype`.
 */
MathBackendCPU2.prototype.broadcastedBinaryOp = function(a, b, dtype, op) {
  const newShape = tf.backend_util.assertAndGetBroadcastShape(a.shape, b.shape);
  const result = tf.buffer(newShape, dtype);
  const aVals = this.readSync(a.dataId);
  const bVals = this.readSync(b.dataId);
  const aBroadcastDims = tf.backend_util.getBroadcastDims(a.shape, newShape);
  const bBroadcastDims = tf.backend_util.getBroadcastDims(b.shape, newShape);
  const resVals = result.values;
  const needsBroadcast = aBroadcastDims.length + bBroadcastDims.length !== 0;
  if (!needsBroadcast) {
    for (let i = 0; i < resVals.length; ++i) {
      resVals[i] = op(aVals[i % aVals.length], bVals[i % bVals.length]);
    }
    return result.toTensor();
  }
  const aBuf = this.bufferSync(a);
  const bBuf = this.bufferSync(b);
  // Trailing-aligned location in the operand, with its broadcast axes pinned to 0.
  const collapse = (loc, rank, dims) => {
    const opLoc = loc.slice(-rank);
    dims.forEach((d) => {
      opLoc[d] = 0;
    });
    return opLoc;
  };
  for (let i = 0; i < resVals.length; ++i) {
    const loc = result.indexToLoc(i);
    const aIndex = aBuf.locToIndex(collapse(loc, a.rank, aBroadcastDims));
    const bIndex = bBuf.locToIndex(collapse(loc, b.rank, bBroadcastDims));
    resVals[i] = op(aVals[aIndex], bVals[bIndex]);
  }
  return result.toTensor();
};
// Splits `x` along `axis` into pieces sized by `sizeSplits`; delegates to the
// module-level `split` helper defined elsewhere in this bundle.
MathBackendCPU2.prototype.split = function(x, sizeSplits, axis) {
return split(x, sizeSplits, axis);
};
// Intentionally a no-op: this backend releases nothing extra on dispose.
MathBackendCPU2.prototype.dispose = function() {
};
// Reports 32-bit float precision for this backend.
MathBackendCPU2.prototype.floatPrecision = function() {
return 32;
};
// Defers to the base-class (`_super`, the extended KernelBackend) epsilon.
MathBackendCPU2.prototype.epsilon = function() {
return _super.prototype.epsilon.call(this);
};
/**
 * Crops boxes out of a batch of images and resizes each crop to `cropSize`.
 *
 * @param images tensor read as (batch, imageHeight, imageWidth, channels).
 * @param boxes [numBoxes, 4] tensor of (y1, x1, y2, x2). Coordinates are multiplied by
 *   (size - 1), so 0 and 1 map to the first and last pixel. A crop extent of 1 samples
 *   the box centre.
 * @param boxIndex per-box index into the image batch.
 * @param cropSize [cropHeight, cropWidth].
 * @param method "bilinear" for bilinear sampling; any other value uses nearest neighbour.
 * @param extrapolationValue value written for samples that fall outside the image.
 * @returns float32 tensor of shape [numBoxes, cropHeight, cropWidth, channels].
 *   Boxes whose batch index is outside [0, batch) are skipped and keep the buffer's
 *   initial contents (presumably zeros — tf.buffer is assumed to zero-initialise).
 */
MathBackendCPU2.prototype.cropAndResize = function(images, boxes, boxIndex, cropSize, method, extrapolationValue) {
  var _a = images.shape, batch = _a[0], imageHeight = _a[1], imageWidth = _a[2], numChannels = _a[3];
  var numBoxes = boxes.shape[0];
  var cropHeight = cropSize[0], cropWidth = cropSize[1];
  var output = tf.buffer([numBoxes, cropHeight, cropWidth, numChannels], "float32");
  var boxVals = this.readSync(boxes.dataId);
  var boxIndVals = this.readSync(boxIndex.dataId);
  var imageVals = this.readSync(images.dataId);
  var inStride = images.strides;
  var outStride = output.strides;
  for (var b = 0; b < numBoxes; b++) {
    var startInd = b * 4;
    var y1 = boxVals[startInd];
    var x1 = boxVals[startInd + 1];
    var y2 = boxVals[startInd + 2];
    var x2 = boxVals[startInd + 3];
    var bInd = boxIndVals[b];
    // Skip boxes that point outside the image batch. A negative index would otherwise
    // read before the start of imageVals (undefined values, i.e. NaN output).
    if (bInd < 0 || bInd >= batch) {
      continue;
    }
    var heightScale = cropHeight > 1 ? (y2 - y1) * (imageHeight - 1) / (cropHeight - 1) : 0;
    var widthScale = cropWidth > 1 ? (x2 - x1) * (imageWidth - 1) / (cropWidth - 1) : 0;
    for (var y = 0; y < cropHeight; y++) {
      var yInd = cropHeight > 1 ? y1 * (imageHeight - 1) + y * heightScale : 0.5 * (y1 + y2) * (imageHeight - 1);
      if (yInd < 0 || yInd > imageHeight - 1) {
        // The whole output row falls outside the image.
        for (var x = 0; x < cropWidth; x++) {
          for (var c = 0; c < numChannels; c++) {
            var ind = c + x * outStride[2] + y * outStride[1] + b * outStride[0];
            output.values[ind] = extrapolationValue;
          }
        }
        continue;
      }
      if (method === "bilinear") {
        var topInd = Math.floor(yInd);
        var bottomInd = Math.ceil(yInd);
        var yLerp = yInd - topInd;
        for (var x = 0; x < cropWidth; x++) {
          var xInd = cropWidth > 1 ? x1 * (imageWidth - 1) + x * widthScale : 0.5 * (x1 + x2) * (imageWidth - 1);
          if (xInd < 0 || xInd > imageWidth - 1) {
            for (var c = 0; c < numChannels; c++) {
              var ind = c + x * outStride[2] + y * outStride[1] + b * outStride[0];
              output.values[ind] = extrapolationValue;
            }
            continue;
          }
          var leftInd = Math.floor(xInd);
          var rightInd = Math.ceil(xInd);
          var xLerp = xInd - leftInd;
          for (var c = 0; c < numChannels; c++) {
            var ind = c + leftInd * inStride[2] + topInd * inStride[1] + bInd * inStride[0];
            var topLeft = imageVals[ind];
            ind = c + rightInd * inStride[2] + topInd * inStride[1] + bInd * inStride[0];
            var topRight = imageVals[ind];
            ind = c + leftInd * inStride[2] + bottomInd * inStride[1] + bInd * inStride[0];
            var bottomLeft = imageVals[ind];
            ind = c + rightInd * inStride[2] + bottomInd * inStride[1] + bInd * inStride[0];
            var bottomRight = imageVals[ind];
            var top_2 = topLeft + (topRight - topLeft) * xLerp;
            var bottom = bottomLeft + (bottomRight - bottomLeft) * xLerp;
            ind = c + x * outStride[2] + y * outStride[1] + b * outStride[0];
            output.values[ind] = top_2 + (bottom - top_2) * yLerp;
          }
        }
      } else {
        // Nearest-neighbour sampling.
        for (var x = 0; x < cropWidth; ++x) {
          var xInd = cropWidth > 1 ? x1 * (imageWidth - 1) + x * widthScale : 0.5 * (x1 + x2) * (imageWidth - 1);
          if (xInd < 0 || xInd > imageWidth - 1) {
            for (var c = 0; c < numChannels; c++) {
              var ind = c + x * outStride[2] + y * outStride[1] + b * outStride[0];
              output.values[ind] = extrapolationValue;
            }
            continue;
          }
          var closestX = Math.round(xInd);
          var closestY = Math.round(yInd);
          for (var c = 0; c < numChannels; c++) {
            var inInd = c + closestX * inStride[2] + closestY * inStride[1] + bInd * inStride[0];
            var outInd = c + x * outStride[2] + y * outStride[1] + b * outStride[0];
            output.values[outInd] = imageVals[inInd];
          }
        }
      }
    }
  }
  return output.toTensor();
};
/**
 * Scatters `sparseValues` at `sparseIndices` into a dense tensor of `outputShape`
 * whose other cells hold `defaultValue`. Duplicate indices overwrite rather than sum.
 */
MathBackendCPU2.prototype.sparseToDense = function(sparseIndices, sparseValues, outputShape, defaultValue) {
  const shapes = tf.backend_util.calculateShapes(sparseValues, sparseIndices, outputShape);
  const sumDuplicates = false;
  return this.scatter(
    sparseIndices,
    sparseValues,
    outputShape,
    shapes.outputSize,
    shapes.sliceSize,
    shapes.numUpdates,
    shapes.sliceRank,
    shapes.strides,
    defaultValue,
    sumDuplicates
  );
};
/**
 * Gathers slices of `x` addressed by the last axis of `indices`.
 *
 * Each index row is flattened with the strides from prepareAndValidate into a slice
 * number, and that slice (`sliceSize` contiguous values) is copied to the output.
 *
 * @throws Error when an index row addresses a slice outside `x`.
 */
MathBackendCPU2.prototype.gatherND = function(x, indices) {
  const sliceRank = indices.shape[indices.shape.length - 1];
  const [resultShape, numSlices, sliceSize, strides] = tf.backend_util.prepareAndValidate(x, indices);
  if (numSlices === 0) {
    return tf.tensor([], resultShape, x.dtype);
  }
  const buffer = new tf.TensorBuffer([numSlices, sliceSize], x.dtype);
  const indicesData = this.readSync(indices.dataId);
  const xData = this.readSync(x.dataId);
  const sliceCount = x.size / sliceSize;
  for (let i = 0; i < numSlices; i++) {
    const coords = [];
    let sliceNumber = 0;
    for (let j = 0; j < sliceRank; j++) {
      const coord = indicesData[i * sliceRank + j];
      coords.push(coord);
      sliceNumber += coord * strides[j];
    }
    if (sliceNumber < 0 || sliceNumber >= sliceCount) {
      throw new Error("Invalid indices: " + coords + " does not index into " + x.shape);
    }
    const src = sliceNumber * sliceSize;
    const dst = i * sliceSize;
    for (let k = 0; k < sliceSize; k++) {
      buffer.values[dst + k] = xData[src + k];
    }
  }
  return buffer.toTensor().reshape(resultShape);
};
/**
 * scatterND on the CPU backend. Starts from a tensor of `shape` filled with zeros and
 * adds each slice of `updates` at the location given by the matching row of `indices`;
 * duplicate indices accumulate.
 *
 * The zero fill value is a temporary scalar tensor. `scatter` only reads it
 * synchronously, so it is disposed as soon as `scatter` returns or throws. Without
 * this, it stays allocated unless an enclosing tidy() scope happens to clean it up.
 */
MathBackendCPU2.prototype.scatterND = function(indices, updates, shape) {
  var _a = tf.backend_util.calculateShapes(updates, indices, shape), sliceRank = _a.sliceRank, numUpdates = _a.numUpdates, sliceSize = _a.sliceSize, strides = _a.strides, outputSize = _a.outputSize;
  var defaultValue = tf.scalar(0);
  var sumDupeIndices = true;
  try {
    return this.scatter(indices, updates, shape, outputSize, sliceSize, numUpdates, sliceRank, strides, defaultValue, sumDupeIndices);
  } finally {
    defaultValue.dispose();
  }
};
/**
 * Returns a tensor shaped and typed like `x`, filled with ones.
 * @throws Error for string tensors.
 */
MathBackendCPU2.prototype.onesLike = function(x) {
  if (x.dtype === "string") {
    throw new Error("onesLike is not supported for string tensors");
  }
  return tf.fill(x.shape, 1, x.dtype);
};
/**
 * Returns an output shaped and typed like `x`, backed by a freshly allocated array
 * from tf.util.getArrayFromDType (zero-filled for numeric dtypes — presumably).
 */
MathBackendCPU2.prototype.zerosLike = function(x) {
  const { shape, dtype } = x;
  const zeros = tf.util.getArrayFromDType(dtype, tf.util.sizeFromShape(shape));
  return this.makeOutput(zeros, shape, dtype);
};
// Evenly spaced values from `start` to `stop` (`num` of them); delegates to
// tf.backend_util.linspaceImpl.
MathBackendCPU2.prototype.linspace = function(start, stop, num) {
return tf.backend_util.linspaceImpl(start, stop, num);
};
/**
 * Shared implementation behind sparseToDense and scatterND.
 *
 * Builds an output viewed as [outputSize / sliceSize, sliceSize] and fills it with the
 * first value of `defaultValue`. For each of `numUpdates` index rows it then computes
 * the slice number sum(index[j] * strides[j]) and writes (or, when `sumDupeIndices` is
 * true, adds) the matching slice of `updates`.
 * A rank-0 `updates` is broadcast to every target only in the overwrite path.
 * NOTE(review): in the summing path a rank-0 `updates` would read past index 0.
 *
 * @throws Error when an index row addresses a slice outside the output.
 * @returns tensor reshaped to `shape`, with the dtype of `updates`.
 */
MathBackendCPU2.prototype.scatter = function(indices, updates, shape, outputSize, sliceSize, numUpdates, sliceRank, strides, defaultValue, sumDupeIndices) {
var flattenShape = [outputSize / sliceSize, sliceSize];
var indicesData = this.readSync(indices.dataId);
var updatesData = this.readSync(updates.dataId);
if (outputSize === 0) {
return tf.tensor([], shape, updates.dtype);
}
var buffer = new tf.TensorBuffer(flattenShape, updates.dtype);
buffer.values.fill(this.readSync(defaultValue.dataId)[0]);
for (var i = 0; i < numUpdates; i++) {
var index = [];
var flattenIndex = 0;
for (var j = 0; j < sliceRank; j++) {
var dim = indicesData[i * sliceRank + j];
index.push(dim);
flattenIndex += dim * strides[j];
}
if (flattenIndex < 0 || flattenIndex >= outputSize / sliceSize) {
throw new Error("Invalid indices: " + index + " does not index into " + shape);
}
for (var k = 0; k < sliceSize; k++) {
if (sumDupeIndices) {
buffer.values[flattenIndex * sliceSize + k] += updatesData[i * sliceSize + k];
} else {
buffer.values[flattenIndex * sliceSize + k] = updates.rank === 0 ? updatesData[0] : updatesData[i * sliceSize + k];
}
}
}
return buffer.toTensor().reshape(shape);
};
return MathBackendCPU2;
}(tf.KernelBackend);
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the License);
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an AS IS BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Element-wise absolute value.
 * @param vals array-like of numbers.
 * @returns new Float32Array of |vals[i]|; the output is float32 regardless of the input type.
 */
function simpleAbsImpl(vals) {
  return Float32Array.from(vals, (v) => Math.abs(v));
}
/**
 * CPU kernel for tf.Abs.
 * For real dtypes it applies simpleAbsImpl to the stored values. For complex64 it
 * returns the magnitude hypot(re, im) of each element. The output is always float32.
 */
var abs = function(args) {
  const x = args.inputs.x;
  const cpuBackend = args.backend;
  if (x.dtype !== "complex64") {
    const { values } = cpuBackend.data.get(x.dataId);
    return cpuBackend.makeOutput(simpleAbsImpl(values), x.shape, "float32");
  }
  // complex64 tensors keep their real and imaginary parts as separate tensor infos.
  const { real, imag } = cpuBackend.data.get(x.dataId).complexTensorInfos;
  const realVals = cpuBackend.data.get(real.dataId).values;
  const imagVals = cpuBackend.data.get(imag.dataId).values;
  const magnitudes = new Float32Array(tf.util.sizeFromShape(x.shape));
  realVals.forEach((re, i) => {
    magnitudes[i] = Math.hypot(re, imagVals[i]);
  });
  return cpuBackend.makeOutput(magnitudes, x.shape, "float32");
};
// Kernel config binding the `abs` kernel function to the tf.Abs kernel on the "cpu"
// backend (presumably handed to a kernel registry elsewhere in the bundle).
var absConfig = {
kernelName: tf.Abs,
backendName: "cpu",
kernelFunc: abs
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Builds a broadcasting element-wise binary kernel around the scalar function `op`.
 *
 * The returned function takes (aShape, bShape, aVals, bVals, dtype) and returns
 * [resultValues, resultShape]. resultValues is a typed array of `dtype` over the
 * broadcast shape. Without broadcast dims the operands are read modulo their length;
 * otherwise each output location is mapped back to each operand with its broadcast
 * axes pinned to 0.
 */
function createSimpleBinaryKernelImpl(op) {
  return function(aShape, bShape, aVals, bVals, dtype) {
    const { backend_util, util } = tf;
    const newShape = backend_util.assertAndGetBroadcastShape(aShape, bShape);
    const resultRank = newShape.length;
    const resultStrides = util.computeStrides(newShape);
    const result = util.getTypedArrayFromDType(dtype, util.sizeFromShape(newShape));
    const aRank = aShape.length;
    const bRank = bShape.length;
    const aStrides = util.computeStrides(aShape);
    const bStrides = util.computeStrides(bShape);
    const aBroadcastDims = backend_util.getBroadcastDims(aShape, newShape);
    const bBroadcastDims = backend_util.getBroadcastDims(bShape, newShape);
    if (aBroadcastDims.length === 0 && bBroadcastDims.length === 0) {
      for (let i = 0; i < result.length; ++i) {
        result[i] = op(aVals[i % aVals.length], bVals[i % bVals.length]);
      }
      return [result, newShape];
    }
    // Flat index into an operand for a given output location.
    const sourceIndex = (loc, rank, dims, strides) => {
      const opLoc = loc.slice(-rank);
      dims.forEach((d) => {
        opLoc[d] = 0;
      });
      return util.locToIndex(opLoc, rank, strides);
    };
    for (let i = 0; i < result.length; ++i) {
      const loc = util.indexToLoc(i, resultRank, resultStrides);
      const aIndex = sourceIndex(loc, aRank, aBroadcastDims, aStrides);
      const bIndex = sourceIndex(loc, bRank, bBroadcastDims, bStrides);
      result[i] = op(aVals[aIndex], bVals[bIndex]);
    }
    return [result, newShape];
  };
}
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * CPU kernel for `Complex`: combines float32 real and imaginary inputs into a
 * complex64 TensorInfo.
 *
 * The result's backend record gets a `complexTensorInfos` entry holding two
 * float32 TensorInfos built from the input value arrays.
 *
 * @param {{inputs: {real: Object, imag: Object}, backend: Object}} args
 * @returns {Object} complex64 TensorInfo with the shape of `inputs.real`.
 */
function complex(args) {
  const {real: realInput, imag: imagInput} = args.inputs;
  const backend = args.backend;
  const realValues = backend.data.get(realInput.dataId).values;
  const imagValues = backend.data.get(imagInput.dataId).values;
  const resultInfo = backend.makeTensorInfo(realInput.shape, "complex64");
  // The record is fetched before the parts are created (LHS evaluates first).
  backend.data.get(resultInfo.dataId).complexTensorInfos = {
    real: backend.makeTensorInfo(realInput.shape, "float32", realValues),
    imag: backend.makeTensorInfo(imagInput.shape, "float32", imagValues)
  };
  return resultInfo;
}
// Kernel registration record binding the CPU `complex` kernel to tf.Complex.
var complexConfig = {kernelName: tf.Complex, backendName: "cpu", kernelFunc: complex};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * CPU kernel for `Identity`: returns a new TensorInfo that shares the input's
 * dataId, after incrementing that dataId's reference count on the backend.
 *
 * @param {{inputs: {x: Object}, backend: Object}} args
 * @returns {{dataId: Object, shape: number[], dtype: string}}
 */
function identity(args) {
  const {x} = args.inputs;
  args.backend.incRef(x.dataId);
  const {dataId, shape, dtype} = x;
  return {dataId, shape, dtype};
}
// Kernel registration record binding the CPU `identity` kernel to tf.Identity.
var identityConfig = {kernelName: tf.Identity, backendName: "cpu", kernelFunc: identity};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * CPU kernel for `Real`: extracts the real component of a complex64 input.
 *
 * Reads the input record's `complexTensorInfos.real` entry and creates a new
 * TensorInfo with that entry's shape, dtype and values.
 *
 * @param {{inputs: {input: Object}, backend: Object}} args
 * @returns {Object} TensorInfo for the real component.
 */
function real(args) {
  const {inputs: {input}, backend} = args;
  const realPart = backend.data.get(input.dataId).complexTensorInfos.real;
  const realValues = backend.data.get(realPart.dataId).values;
  return backend.makeTensorInfo(realPart.shape, realPart.dtype, realValues);
}
// Kernel registration record binding the CPU `real` kernel to tf.Real.
var realConfig = {kernelName: tf.Real, backendName: "cpu", kernelFunc: real};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * CPU kernel for `Cast`: converts `inputs.x` to `attrs.dtype`.
 *
 * - To complex64: a complex64 input is passed through `identity`. Otherwise
 *   the real part is `x` cast to float32 and the imaginary part is
 *   `tf.zeros(x.shape)`.
 * - From complex64: the real component is extracted and cast recursively.
 * - When `tf.util.hasEncodingLoss` reports no loss, the input data is aliased
 *   under the new dtype.
 * - To int32: values are copied with `Int32Array.from`.
 * - To bool: each value is compared with 0 (different -> 1, equal -> 0).
 *
 * @param {{inputs: {x: Object}, backend: Object, attrs: {dtype: string}}} args
 * @returns {Object} TensorInfo with the converted data.
 * @throws {Error} When no conversion path exists for the dtype pair.
 */
function cast(args) {
  const {inputs, backend, attrs} = args;
  const x = inputs.x;
  const targetDtype = attrs.dtype;

  if (targetDtype === "complex64") {
    if (x.dtype === "complex64") {
      return identity({inputs: {x}, backend});
    }
    const zerosTensor = tf.zeros(x.shape);
    const realPart = cast({inputs: {x}, backend, attrs: {dtype: "float32"}});
    const complexResult = complex({inputs: {real: realPart, imag: zerosTensor}, backend});
    zerosTensor.dispose();
    backend.disposeIntermediateTensorInfo(realPart);
    return complexResult;
  }

  if (x.dtype === "complex64") {
    // Drop the imaginary component, then cast what remains.
    const realPart = real({inputs: {input: x}, backend});
    const castResult = cast({inputs: {x: realPart}, backend, attrs: {dtype: targetDtype}});
    backend.disposeIntermediateTensorInfo(realPart);
    return castResult;
  }

  if (!tf.util.hasEncodingLoss(x.dtype, targetDtype)) {
    const aliased = identity({inputs: {x}, backend});
    return {dataId: aliased.dataId, shape: aliased.shape, dtype: targetDtype};
  }

  switch (targetDtype) {
    case "int32": {
      const values = backend.data.get(x.dataId).values;
      return backend.makeTensorInfo(x.shape, "int32", Int32Array.from(values));
    }
    case "bool": {
      const values = backend.data.get(x.dataId).values;
      const zero = tf.util.toTypedArray([0], x.dtype);
      const notEqualZero = createSimpleBinaryKernelImpl((a, b) => (a !== b ? 1 : 0));
      const [boolData, boolShape] = notEqualZero(x.shape, [], values, zero, "bool");
      return backend.makeTensorInfo(boolShape, "bool", boolData);
    }
    default:
      throw new Error("Error in Cast: failed to cast " + x.dtype + " to " + targetDtype);
  }
}
// Kernel registration record binding the CPU `cast` kernel to tf.Cast.
var castConfig = {kernelName: tf.Cast, backendName: "cpu", kernelFunc: cast};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Wraps element-wise implementations into a CPU kernel function for a binary op.
 *
 * @param {string} name Kernel name, passed to `assertNotComplex`.
 * @param {function} simpleImpl `(aShape, bShape, aVals, bVals, dtype) => [values, shape]`
 *     used for real-valued operands.
 * @param {?function} complexImpl `(aShape, bShape, aRe, aIm, bRe, bIm) => [re, im, shape]`.
 *     When null/undefined, the kernel asserts that neither operand is complex.
 * @param {string=} dtype Output dtype for the real path; defaults to `a.dtype`.
 * @returns {function({inputs: {a: Object, b: Object}, backend: Object}): Object}
 *     Kernel function returning a TensorInfo.
 */
function binaryKernelFunc(name, simpleImpl, complexImpl, dtype) {
  // Runs `simpleImpl` over the operands' values and wraps the result.
  function runRealKernel(a, b, cpuBackend) {
    const aVals = cpuBackend.data.get(a.dataId).values;
    const bVals = cpuBackend.data.get(b.dataId).values;
    const outDtype = dtype || a.dtype;
    const [resultData, resultShape] = simpleImpl(a.shape, b.shape, aVals, bVals, outDtype);
    return cpuBackend.makeTensorInfo(resultShape, outDtype, resultData);
  }

  // Casts `t` to complex64; returns the cast TensorInfo and its real/imag values.
  function toComplexPlanes(t, cpuBackend) {
    const info = cast({inputs: {x: t}, backend: cpuBackend, attrs: {dtype: "complex64"}});
    const parts = cpuBackend.data.get(info.dataId).complexTensorInfos;
    const realVals = cpuBackend.data.get(parts.real.dataId).values;
    const imagVals = cpuBackend.data.get(parts.imag.dataId).values;
    return {info, realVals, imagVals};
  }

  if (complexImpl == null) {
    return function(kernelArgs) {
      const {a, b} = kernelArgs.inputs;
      assertNotComplex([a, b], name);
      return runRealKernel(a, b, kernelArgs.backend);
    };
  }

  return function(kernelArgs) {
    const {a, b} = kernelArgs.inputs;
    const cpuBackend = kernelArgs.backend;
    if (a.dtype !== "complex64" && b.dtype !== "complex64") {
      return runRealKernel(a, b, cpuBackend);
    }
    // At least one operand is complex: promote both and combine plane-wise.
    const aPlanes = toComplexPlanes(a, cpuBackend);
    const bPlanes = toComplexPlanes(b, cpuBackend);
    const [realData, imagData, resultShape] = complexImpl(
        a.shape, b.shape, aPlanes.realVals, aPlanes.imagVals, bPlanes.realVals, bPlanes.imagVals);
    const resultReal = cpuBackend.makeTensorInfo(resultShape, "float32", realData);
    const resultImag = cpuBackend.makeTensorInfo(resultShape, "float32", imagData);
    const result = complex({inputs: {real: resultReal, imag: resultImag}, backend: cpuBackend});
    cpuBackend.disposeIntermediateTensorInfo(aPlanes.info);
    cpuBackend.disposeIntermediateTensorInfo(bPlanes.info);
    cpuBackend.disposeIntermediateTensorInfo(resultReal);
    cpuBackend.disposeIntermediateTensorInfo(resultImag);
    return result;
  };
}
/**
 * Builds a broadcasting, element-wise binary kernel implementation for complex
 * operands stored as separate real and imaginary planes.
 *
 * @param {function(number, number, number, number): {real: number, imag: number}} op
 *     Scalar op called as `op(aReal, aImag, bReal, bImag)`.
 * @returns {function} `(aShape, bShape, aRealVals, aImagVals, bRealVals, bImagVals)
 *     => [resultRealVals, resultImagVals, resultShape]`, where both result
 *     planes are float32 arrays laid out over the broadcast shape.
 */
function createComplexBinaryKernelImpl(op) {
  return function(aShape, bShape, aRealVals, aImagVals, bRealVals, bImagVals) {
    const resultShape = tf.backend_util.assertAndGetBroadcastShape(aShape, bShape);
    const resultSize = tf.util.sizeFromShape(resultShape);
    const resultRank = resultShape.length;
    const resultStrides = tf.util.computeStrides(resultShape);
    const resultRealVals = tf.util.getTypedArrayFromDType("float32", resultSize);
    const resultImagVals = tf.util.getTypedArrayFromDType("float32", resultSize);
    const aBroadcastDims = tf.backend_util.getBroadcastDims(aShape, resultShape);
    const bBroadcastDims = tf.backend_util.getBroadcastDims(bShape, resultShape);
    // Interleaved [re0, im0, re1, im1, ...]: complex element k sits at 2k / 2k + 1.
    const aVals = tf.backend_util.mergeRealAndImagArrays(aRealVals, aImagVals);
    const bVals = tf.backend_util.mergeRealAndImagArrays(bRealVals, bImagVals);
    const aRank = aShape.length;
    const aStrides = tf.util.computeStrides(aShape);
    const bRank = bShape.length;
    const bStrides = tf.util.computeStrides(bShape);
    if (aBroadcastDims.length + bBroadcastDims.length === 0) {
      // No size-1 axis is stretched, but an operand may still have lower rank
      // than the result (e.g. [3] with [2, 3]) and then repeats cyclically.
      // The modulus must be the operand's ELEMENT count. Using the interleaved
      // buffer length (twice as long) indexes past its end.
      const aCount = aRealVals.length;
      const bCount = bRealVals.length;
      for (let i = 0; i < resultSize; i++) {
        const aIdx = i % aCount;
        const bIdx = i % bCount;
        const result = op(aVals[aIdx * 2], aVals[aIdx * 2 + 1], bVals[bIdx * 2], bVals[bIdx * 2 + 1]);
        resultRealVals[i] = result.real;
        resultImagVals[i] = result.imag;
      }
    } else {
      for (let i = 0; i < resultSize; i++) {
        const loc = tf.util.indexToLoc(i, resultRank, resultStrides);
        // Align trailing axes with each operand and pin its broadcast axes to 0.
        const aLoc = loc.slice(-aRank);
        aBroadcastDims.forEach((d) => {
          aLoc[d] = 0;
        });
        const aIndex = tf.util.locToIndex(aLoc, aRank, aStrides);
        const bLoc = loc.slice(-bRank);
        bBroadcastDims.forEach((d) => {
          bLoc[d] = 0;
        });
        const bIndex = tf.util.locToIndex(bLoc, bRank, bStrides);
        const opResult = op(aVals[aIndex * 2], aVals[aIndex * 2 + 1], bVals[bIndex * 2], bVals[bIndex * 2 + 1]);
        resultRealVals[i] = opResult.real;
        resultImagVals[i] = opResult.imag;
      }
    }
    return [resultRealVals, resultImagVals, resultShape];
  };
}
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
// Element-wise addition. For complex operands, the real and imaginary parts
// are summed independently.
var addImpl = createSimpleBinaryKernelImpl((a, b) => a + b);
var addComplexImpl = createComplexBinaryKernelImpl((aReal, aImag, bReal, bImag) => ({
  real: aReal + bReal,
  imag: aImag + bImag
}));
var add = binaryKernelFunc(tf.Add, addImpl, addComplexImpl);
var addConfig = {kernelName: tf.Add, backendName: "cpu", kernelFunc: add};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Builds an element-wise unary implementation over a flat value array.
 *
 * @param {function(number, Object=): number} op Scalar op, called as `op(value, attrs)`.
 * @returns {function(ArrayLike<number>, string, Object=)} `(values, dtype, attrs)`
 *     returning a new typed array of `dtype` with `op` applied to every value.
 */
function createSimpleUnaryImpl(op) {
  return function(values, dtype, attrs) {
    const count = values.length;
    const out = tf.util.getTypedArrayFromDType(dtype, count);
    let i = 0;
    while (i < count) {
      out[i] = op(values[i], attrs);
      i++;
    }
    return out;
  };
}
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Wraps a scalar op into a CPU kernel function for a unary op.
 *
 * @param {string} name Kernel name, passed to `assertNotComplex`.
 * @param {function(number, Object): number} op Scalar op, called as `op(value, attrs)`.
 * @param {string=} dtype Output dtype; defaults to the input dtype.
 * @returns {function({inputs: {x: Object}, attrs: Object, backend: Object}): Object}
 * @throws {Error} When the input or the output dtype is "string".
 */
function unaryKernelFunc(name, op, dtype) {
  return function(kernelArgs) {
    const {inputs, attrs, backend} = kernelArgs;
    const x = inputs.x;
    assertNotComplex(x, name);
    if (x.dtype === "string" || dtype === "string") {
      throw new Error("unaryKernelFunc does not support string input/output");
    }
    const values = backend.data.get(x.dataId).values;
    const size = tf.util.sizeFromShape(x.shape);
    const outDtype = dtype || x.dtype;
    const out = tf.util.getArrayFromDType(outDtype, size);
    for (let k = 0; k < size; ++k) {
      out[k] = op(values[k], attrs);
    }
    return backend.makeTensorInfo(x.shape, outDtype, out);
  };
}
/**
 * Wraps a whole-array unary implementation (e.g. one built by
 * `createSimpleUnaryImpl`) into a CPU kernel function.
 *
 * @param {string} name Kernel name, passed to `assertNotComplex`.
 * @param {function(ArrayLike<number>, string, Object): ArrayLike<number>} unaryImpl
 *     Called as `unaryImpl(values, outDtype, attrs)`.
 * @param {string=} dtype Output dtype; defaults to the input dtype.
 * @returns {function({inputs: {x: Object}, attrs: Object, backend: Object}): Object}
 * @throws {Error} When the input or the output dtype is "string".
 */
function unaryKernelFuncFromImpl(name, unaryImpl, dtype) {
  return function(kernelArgs) {
    const {inputs, attrs, backend} = kernelArgs;
    const x = inputs.x;
    assertNotComplex(x, name);
    const isStringIO = x.dtype === "string" || dtype === "string";
    if (isStringIO) {
      throw new Error("unaryKernelFunc does not support string input/output");
    }
    const outDtype = dtype || x.dtype;
    const outValues = unaryImpl(backend.data.get(x.dataId).values, outDtype, attrs);
    return backend.makeTensorInfo(x.shape, outDtype, outValues);
  };
}
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the License);
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an AS IS BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
// Element-wise ceiling. Math.ceil ignores the extra `attrs` argument it receives.
var ceilImpl = createSimpleUnaryImpl(Math.ceil);
var ceil = unaryKernelFuncFromImpl(tf.Ceil, ceilImpl);
var ceilConfig = {kernelName: tf.Ceil, backendName: "cpu", kernelFunc: ceil};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the License);
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an AS IS BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
// Element-wise e^x. Math.exp ignores the extra `attrs` argument it receives.
var expImpl = createSimpleUnaryImpl(Math.exp);
var exp = unaryKernelFuncFromImpl(tf.Exp, expImpl);
var expConfig = {kernelName: tf.Exp, backendName: "cpu", kernelFunc: exp};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the License);
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an AS IS BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
// Element-wise e^x - 1. Math.expm1 ignores the extra `attrs` argument it receives.
var expm1Impl = createSimpleUnaryImpl(Math.expm1);
var expm1 = unaryKernelFuncFromImpl(tf.Expm1, expm1Impl);
var expm1Config = {kernelName: tf.Expm1, backendName: "cpu", kernelFunc: expm1};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the License);
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an AS IS BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
// Element-wise floor. Math.floor ignores the extra `attrs` argument it receives.
var floorImpl = createSimpleUnaryImpl(Math.floor);
var floor = unaryKernelFuncFromImpl(tf.Floor, floorImpl);
var floorConfig = {kernelName: tf.Floor, backendName: "cpu", kernelFunc: floor};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the License);
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an AS IS BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
// Element-wise natural logarithm. Math.log ignores the extra `attrs` argument it receives.
var logImpl = createSimpleUnaryImpl(Math.log);
var log = unaryKernelFuncFromImpl(tf.Log, logImpl);
var logConfig = {kernelName: tf.Log, backendName: "cpu", kernelFunc: log};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Reduces consecutive runs of `reduceSize` values to their maximum.
 *
 * Output element i is the max of `aVals[i * reduceSize .. (i + 1) * reduceSize)`.
 * The first element of each run seeds the max and only strictly greater values
 * replace it, so a NaN seed is kept and later NaNs are ignored.
 *
 * @param {ArrayLike<number>} aVals Flat input values.
 * @param {number} reduceSize Length of each reduced run.
 * @param {number[]} outShape Output shape; its size is the number of runs.
 * @param {string} dtype Output dtype.
 * @returns {TypedArray} Maxima, one per run.
 */
function maxImpl(aVals, reduceSize, outShape, dtype) {
  const out = tf.util.getTypedArrayFromDType(dtype, tf.util.sizeFromShape(outShape));
  for (let outIdx = 0, start = 0; outIdx < out.length; ++outIdx, start += reduceSize) {
    let best = aVals[start];
    // Comparing the seed with itself can never replace it, so start at 1.
    for (let k = 1; k < reduceSize; ++k) {
      const candidate = aVals[start + k];
      if (candidate > best) {
        best = candidate;
      }
    }
    out[outIdx] = best;
  }
  return out;
}
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
// Element-wise product; complex operands use (a+bi)(c+di) = (ac-bd) + (ad+bc)i.
var multiplyImpl = createSimpleBinaryKernelImpl((lhs, rhs) => lhs * rhs);
var multiplyComplexImpl = createComplexBinaryKernelImpl((aReal, aImag, bReal, bImag) => ({
  real: aReal * bReal - aImag * bImag,
  imag: aReal * bImag + aImag * bReal
}));
var multiply = binaryKernelFunc(tf.Multiply, multiplyImpl, multiplyComplexImpl);
var multiplyConfig = {kernelName: tf.Multiply, backendName: "cpu", kernelFunc: multiply};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
// Element-wise inequality producing a bool tensor (1 where a !== b, else 0).
// Complex operands are rejected (no complex implementation is supplied).
var notEqualImpl = createSimpleBinaryKernelImpl((a, b) => (a === b ? 0 : 1));
var notEqual = binaryKernelFunc(tf.NotEqual, notEqualImpl, null, "bool");
var notEqualConfig = {kernelName: tf.NotEqual, backendName: "cpu", kernelFunc: notEqual};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the License);
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an AS IS BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
// Element-wise reciprocal square root, 1 / sqrt(x).
var rsqrtImpl = createSimpleUnaryImpl((xi) => 1 / Math.sqrt(xi));
var rsqrt = unaryKernelFuncFromImpl(tf.Rsqrt, rsqrtImpl);
var rsqrtConfig = {kernelName: tf.Rsqrt, backendName: "cpu", kernelFunc: rsqrt};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Extracts the region `[begin, begin + size)` from a row-major flat buffer.
 *
 * @param {TypedArray|Array} vals Flat values of the source tensor. Plain
 *     arrays are accepted; presumably these are string tensors — confirm
 *     against callers.
 * @param {number[]} begin Start coordinate per axis.
 * @param {number[]} size Extent of the slice per axis.
 * @param {number[]} shape Shape of the source tensor.
 * @param {string} dtype Dtype of the source tensor.
 * @returns {TypedArray|Array} Flat values of the slice. For a contiguous slice
 *     of a typed array this is a `subarray` view sharing memory with `vals`.
 *     Otherwise it is a new buffer.
 */
function sliceImpl(vals, begin, size, shape, dtype) {
  const isContinous = tf.slice_util.isSliceContinous(shape, begin, size);
  const length = tf.util.sizeFromShape(size);
  const xStrides = tf.util.computeStrides(shape);
  if (isContinous) {
    const flatOffset = tf.slice_util.computeFlatOffset(begin, xStrides);
    // Plain arrays have no `subarray`; copy the range instead.
    if (Array.isArray(vals)) {
      return vals.slice(flatOffset, flatOffset + length);
    }
    return vals.subarray(flatOffset, flatOffset + length);
  }
  // getTypedArrayFromDType has no typed-array representation for strings.
  const outVals = dtype === "string" ? new Array(length) : tf.util.getTypedArrayFromDType(dtype, length);
  // Rank and strides of the output region are loop-invariant; compute once.
  const outRank = size.length;
  const outStrides = tf.util.computeStrides(size);
  const xRank = shape.length;
  for (let i = 0; i < length; ++i) {
    const loc = tf.util.indexToLoc(i, outRank, outStrides);
    const xLoc = loc.map((idx, j) => idx + begin[j]);
    outVals[i] = vals[tf.util.locToIndex(xLoc, xRank, xStrides)];
  }
  return outVals;
}
/**
 * CPU kernel for `Slice`: normalizes and validates `begin`/`size` via
 * tf.slice_util, then copies the region with `sliceImpl`.
 *
 * @param {{inputs: {x: Object}, backend: Object, attrs: {begin: *, size: *}}} args
 * @returns {Object} TensorInfo whose shape is the normalized size.
 */
function slice(args) {
  const {inputs, backend, attrs} = args;
  const x = inputs.x;
  const {begin, size} = attrs;
  assertNotComplex(x, "slice");
  const [beginIdx, sliceSize] = tf.slice_util.parseSliceParams(x, begin, size);
  tf.slice_util.assertParamsValid(x, beginIdx, sliceSize);
  const sourceVals = backend.data.get(x.dataId).values;
  const outVals = sliceImpl(sourceVals, beginIdx, sliceSize, x.shape, x.dtype);
  return backend.makeTensorInfo(sliceSize, x.dtype, outVals);
}
// Kernel registration record binding the CPU `slice` kernel to tf.Slice.
var sliceConfig = {kernelName: tf.Slice, backendName: "cpu", kernelFunc: slice};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
// Element-wise (a - b)^2. Complex operands are rejected (no complex impl).
var squaredDifferenceImpl = createSimpleBinaryKernelImpl((a, b) => {
  const delta = a - b;
  return delta * delta;
});
var squaredDifference = binaryKernelFunc(tf.SquaredDifference, squaredDifferenceImpl);
var squaredDifferenceConfig = {kernelName: tf.SquaredDifference, backendName: "cpu", kernelFunc: squaredDifference};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
// Element-wise subtraction. For complex operands, the real and imaginary parts
// are subtracted independently.
var subImpl = createSimpleBinaryKernelImpl((lhs, rhs) => lhs - rhs);
var subComplexImpl = createComplexBinaryKernelImpl((aReal, aImag, bReal, bImag) => ({
  real: aReal - bReal,
  imag: aImag - bImag
}));
var sub = binaryKernelFunc(tf.Sub, subImpl, subComplexImpl);
var subConfig = {kernelName: tf.Sub, backendName: "cpu", kernelFunc: sub};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Permutes the axes of a row-major flat buffer.
 *
 * For every input element, output axis k takes its coordinate from input axis
 * `perm[k]`.
 *
 * @param {ArrayLike<number>} xVals Flat input values.
 * @param {number[]} xShape Input shape.
 * @param {string} dtype Output dtype.
 * @param {number[]} perm Axis permutation.
 * @param {number[]} newShape Permuted (output) shape.
 * @returns {TypedArray} Flat values of the transposed tensor.
 */
function transposeImpl(xVals, xShape, dtype, perm, newShape) {
  const rank = xShape.length;
  const total = tf.util.sizeFromShape(xShape);
  const inStrides = tf.util.computeStrides(xShape);
  const outStrides = tf.util.computeStrides(newShape);
  const out = tf.util.getTypedArrayFromDType(dtype, tf.util.sizeFromShape(newShape));
  for (let flat = 0; flat < total; ++flat) {
    const srcLoc = tf.util.indexToLoc(flat, rank, inStrides);
    const dstLoc = srcLoc.map((_, k) => srcLoc[perm[k]]);
    out[tf.util.locToIndex(dstLoc, rank, outStrides)] = xVals[flat];
  }
  return out;
}
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Finds the unique slices of a tensor along one axis, in first-occurrence order.
 *
 * The input is viewed as a 3-D buffer `[outer, shape[axis], inner]`. Slice i
 * along `axis` is keyed by `values[i].toString()` when outer and inner are
 * both 1, otherwise by the comma-joined values of that slice.
 *
 * @param {ArrayLike<*>} values Flat input values.
 * @param {number|number[]} axis Axis (resolved with tf.util.parseAxisParam; the
 *     first resolved axis is used).
 * @param {number[]} shape Input shape.
 * @param {string} dtype Input dtype.
 * @returns {{outputValues: ArrayLike<*>, outputShape: number[], indices: Int32Array}}
 *     The unique slices, their shape, and for every input slice the index of
 *     its unique slice.
 */
function uniqueImpl(values, axis, shape, dtype) {
  const $axis = tf.util.parseAxisParam(axis, shape)[0];
  // Collapse the axes before/after `axis` into the outer/inner dimensions.
  const newShape = [1, shape[0], 1];
  for (let i = 0; i < $axis; i++) {
    newShape[0] *= shape[i];
  }
  newShape[1] = shape[$axis];
  for (let i = $axis + 1; i < shape.length; i++) {
    newShape[2] *= shape[i];
  }
  // A Map gives O(1) `size` (instead of Object.keys per new element) and does
  // not confuse keys such as "constructor" with Object.prototype members.
  const uniqueElements = new Map();
  const indices = new Int32Array(shape[$axis]);
  const inputBuffer = new tf.TensorBuffer(newShape, dtype, values);
  const uniqueIndices = [];
  const is1DTensor = newShape[0] === 1 && newShape[2] === 1;
  for (let i = 0; i < shape[$axis]; i++) {
    let element;
    if (is1DTensor) {
      element = values[i].toString();
    } else {
      const axisValues = [];
      for (let m = 0; m < newShape[0]; m++) {
        for (let n = 0; n < newShape[2]; n++) {
          axisValues.push(inputBuffer.get(m, i, n));
        }
      }
      element = axisValues.join(",");
    }
    const existingIndex = uniqueElements.get(element);
    if (existingIndex !== undefined) {
      indices[i] = existingIndex;
    } else {
      const uniqueIndex = uniqueElements.size;
      uniqueElements.set(element, uniqueIndex);
      indices[i] = uniqueIndex;
      uniqueIndices.push(i);
    }
  }
  const outputTmpShape = newShape.slice();
  outputTmpShape[1] = uniqueElements.size;
  const outputBuffer = new tf.TensorBuffer(outputTmpShape, dtype);
  uniqueIndices.forEach((uniqueElementIndex, outIndex) => {
    for (let m = 0; m < newShape[0]; m++) {
      for (let n = 0; n < newShape[2]; n++) {
        outputBuffer.set(inputBuffer.get(m, uniqueElementIndex, n), m, outIndex, n);
      }
    }
  });
  const outputShape = shape.slice();
  outputShape[$axis] = outputTmpShape[1];
  return {
    outputValues: outputBuffer.values,
    outputShape,
    indices
  };
}
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
// Kernel implementations exported for reuse, collected on a prototype-less object.
var shared = Object.assign(Object.create(null), {
  simpleAbsImpl,
  addImpl,
  ceilImpl,
  expImpl,
  expm1Impl,
  floorImpl,
  logImpl,
  maxImpl,
  multiplyImpl,
  notEqualImpl,
  rsqrtImpl,
  sliceImpl,
  squaredDifferenceImpl,
  subImpl,
  transposeImpl,
  uniqueImpl
});
/** @license See the LICENSE file. */
// Version string of this bundled backend. Presumably the tfjs-backend-cpu
// package version; confirm against the package it was bundled from.
var version = "2.7.0";
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
// Register the CPU backend under the name "cpu" (third argument: priority 1).
// The factory builds a new MathBackendCPU each time it is invoked.
tf.registerBackend("cpu", () => new MathBackendCPU(), 1);
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the License);
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an AS IS BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * ELU activation kernel (tf.Elu) for the CPU backend.
 *
 *   elu(x) = x          for x >= 0
 *   elu(x) = e^x - 1    otherwise
 *
 * The negative branch uses Math.expm1 rather than `Math.exp(xi) - 1`. For
 * |x| near 0 the subtraction cancels almost all significant digits; at
 * x = -1e-15 the old form is off by about 11%. For NaN and -Infinity both
 * forms agree (NaN and -1).
 */
var elu = unaryKernelFunc(tf.Elu, function(xi) {
  return xi >= 0 ? xi : Math.expm1(xi);
});
var eluConfig = {
  kernelName: tf.Elu,
  backendName: "cpu",
  kernelFunc: elu
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the License);
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an AS IS BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Parametric ReLU (tf.Prelu) for the CPU backend.
 * The scalar rule keeps non-negative x as-is and scales negative x by alpha.
 * preluImpl is the binary values helper built from that rule by
 * createSimpleBinaryKernelImpl (which presumably handles shape broadcasting).
 */
var preluImpl = createSimpleBinaryKernelImpl((xValue, aValue) => (xValue < 0 ? aValue * xValue : xValue));
function prelu(args) {
  const { inputs, backend } = args;
  const { x, alpha } = inputs;
  assertNotComplex([x, alpha], "prelu");
  const xValues = backend.data.get(x.dataId).values;
  const alphaValues = backend.data.get(alpha.dataId).values;
  // preluImpl yields [values, shape].
  const computed = preluImpl(x.shape, alpha.shape, xValues, alphaValues, x.dtype);
  return backend.makeTensorInfo(computed[1], x.dtype, computed[0]);
}
var preluConfig = { kernelName: tf.Prelu, backendName: "cpu", kernelFunc: prelu };
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the License);
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an AS IS BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/** ReLU kernel (tf.Relu): the scalar op max(0, x) handed to unaryKernelFunc. */
var relu = unaryKernelFunc(tf.Relu, (xi) => Math.max(0, xi));
var reluConfig = { kernelName: tf.Relu, backendName: "cpu", kernelFunc: relu };
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the License);
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an AS IS BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/** ReLU6 kernel (tf.Relu6): the scalar op clamps x into [0, 6]. */
var relu6 = unaryKernelFunc(tf.Relu6, (xi) => Math.min(Math.max(0, xi), 6));
var relu6Config = { kernelName: tf.Relu6, backendName: "cpu", kernelFunc: relu6 };
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Dispatches a fused-op activation by name to the matching CPU kernel.
 *
 * @param backend                CPU backend forwarded to the kernel.
 * @param x                      TensorInfo to activate.
 * @param activation             "linear", "relu", "elu", "relu6" or "prelu".
 * @param preluActivationWeights alpha tensor, used only for "prelu".
 * @returns the kernel's TensorInfo result.
 * @throws Error for any other activation name.
 */
function applyActivation(backend, x, activation, preluActivationWeights) {
  const inputs = { x };
  switch (activation) {
    case "linear":
      return identity({ inputs, backend });
    case "relu":
      return relu({ inputs, backend });
    case "elu":
      return elu({ inputs, backend });
    case "relu6":
      return relu6({ inputs, backend });
    case "prelu":
      return prelu({ inputs: { x, alpha: preluActivationWeights }, backend });
    default:
      throw new Error("Activation " + activation + " has not been implemented for the CPU backend.");
  }
}
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Reshape kernel for the CPU backend. No data is copied: the result reuses
 * x.dataId after incrementing its reference count.
 *
 * @param args.attrs.shape target shape; may contain one implicit dimension,
 *     resolved by tf.util.inferFromImplicitShape.
 * @returns {dataId, shape, dtype} sharing x's data.
 * @throws (via tf.util.assert) when the element counts differ.
 * Side effect: for complex inputs, the stored real/imag TensorInfos have their
 * `shape` overwritten with the new shape.
 */
function reshape(args) {
  const { inputs, backend, attrs } = args;
  const { x } = inputs;
  const oldSize = tf.util.sizeFromShape(x.shape);
  const newShape = tf.util.inferFromImplicitShape(attrs.shape, oldSize);
  const newSize = tf.util.sizeFromShape(newShape);
  tf.util.assert(oldSize === newSize, () =>
    "The new shape (" + newShape + ") has " + newSize + " elements and the old shape (" + x.shape +
    ") has " + oldSize + " elements. The new shape and old shape must have the same number of elements.");
  backend.incRef(x.dataId);
  const { complexTensorInfos } = backend.data.get(x.dataId);
  if (complexTensorInfos != null) {
    complexTensorInfos.real.shape = newShape;
    complexTensorInfos.imag.shape = newShape;
  }
  return { dataId: x.dataId, shape: newShape, dtype: x.dtype };
}
var reshapeConfig = { kernelName: tf.Reshape, backendName: "cpu", kernelFunc: reshape };
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the License);
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an AS IS BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Batched matrix multiplication kernel (tf.BatchMatMul) for the CPU backend.
 *
 * Both operands are reshaped to rank 3, [batch, rows, cols], and multiplied
 * with a tiled loop nest whose tile edge is backend.blockSize. When one
 * operand's flattened batch size is 1, that operand is reused for every batch
 * of the other (the batch index is clamped to 0).
 *
 * @param args.inputs.a, args.inputs.b   operands (rank >= 2).
 * @param args.attrs.transposeA/transposeB  use the transpose of the last two
 *     dimensions of a / b.
 * @returns TensorInfo of shape [...outer batch dims, outer dim of a, outer dim of b].
 * @throws (via tf.util.assert) when a rank is < 2, the batch sizes are
 *     incompatible, or the inner dimensions differ.
 */
function batchMatMul(args) {
  var inputs = args.inputs, backend = args.backend, attrs = args.attrs;
  var a = inputs.a, b = inputs.b;
  var transposeA = attrs.transposeA, transposeB = attrs.transposeB;
  assertNotComplex([a, b], "matMul");
  var aRank = a.shape.length;
  var bRank = b.shape.length;
  var innerShapeA = transposeA ? a.shape[aRank - 2] : a.shape[aRank - 1];
  var innerShapeB = transposeB ? b.shape[bRank - 1] : b.shape[bRank - 2];
  var outerShapeA = transposeA ? a.shape[aRank - 1] : a.shape[aRank - 2];
  var outerShapeB = transposeB ? b.shape[bRank - 2] : b.shape[bRank - 1];
  var outerDimsA = a.shape.slice(0, -2);
  var outerDimsB = b.shape.slice(0, -2);
  var batchDimA = tf.util.sizeFromShape(outerDimsA);
  var batchDimB = tf.util.sizeFromShape(outerDimsB);
  var batchDimsCompatible = batchDimA === batchDimB || batchDimA === 1 || batchDimB === 1;
  tf.util.assert(aRank >= 2 && bRank >= 2 && batchDimsCompatible, function() {
    return "Error in matMul: the input batch dimensions must either be the same or at least one input batch dimension must be 1. Got input " + ("batch dimensions of (" + outerDimsA + ") and (" + outerDimsB + ").");
  });
  // Leading output dims come from the operand with the larger batch.
  var outShapeOuterDims = batchDimA > batchDimB ? a.shape.slice(0, -2) : b.shape.slice(0, -2);
  var outShape = outShapeOuterDims.concat([outerShapeA, outerShapeB]);
  tf.util.assert(innerShapeA === innerShapeB, function() {
    return "Error in matMul: inner shapes (" + innerShapeA + ") and (" + (innerShapeB + ") of Tensors with shapes " + a.shape + " and ") + (b.shape + " and transposeA=" + transposeA) + (" and transposeB=" + transposeB + " must match.");
  });
  var a3dShape = transposeA ? [batchDimA, innerShapeA, outerShapeA] : [batchDimA, outerShapeA, innerShapeA];
  var b3dShape = transposeB ? [batchDimB, outerShapeB, innerShapeB] : [batchDimB, innerShapeB, outerShapeB];
  var a3d = reshape({inputs: {x: a}, backend, attrs: {shape: a3dShape}});
  var b3d = reshape({inputs: {x: b}, backend, attrs: {shape: b3dShape}});
  var sharedDim = transposeA ? a3d.shape[1] : a3d.shape[2];
  var leftDim = transposeA ? a3d.shape[2] : a3d.shape[1];
  var rightDim = transposeB ? b3d.shape[1] : b3d.shape[2];
  var batchDim = Math.max(batchDimA, batchDimB);
  var a3dValues = backend.data.get(a3d.dataId).values;
  var b3dValues = backend.data.get(b3d.dataId).values;
  var a3dStrides = tf.util.computeStrides(a3d.shape);
  var b3dStrides = tf.util.computeStrides(b3d.shape);
  // Flat-index steps into a3d/b3d. Transposition only swaps which stride
  // walks the shared (inner) dimension and which walks the outer one.
  var _a = transposeA ? [a3dStrides[0], 1, a3dStrides[1]] : [a3dStrides[0], a3dStrides[1], 1], aBatch = _a[0], aOuterStep = _a[1], aInnerStep = _a[2];
  var _b = transposeB ? [1, b3dStrides[1], b3dStrides[0]] : [b3dStrides[1], 1, b3dStrides[0]], bInnerStep = _b[0], bOuterStep = _b[1], bBatch = _b[2];
  var size = leftDim * rightDim;
  var result = tf.buffer([batchDim, leftDim, rightDim], a3d.dtype);
  var resVals = result.values;
  var blockSize = backend.blockSize;
  for (var bi = 0; bi < batchDim; bi++) {
    // These offsets depend only on bi, so they are computed once per batch
    // instead of once per multiply-add. A batch size of 1 is broadcast.
    var batchOffsetA = Math.min(bi, batchDimA - 1) * aBatch;
    var batchOffsetB = Math.min(bi, batchDimB - 1) * bBatch;
    var resBatchOffset = bi * size;
    for (var i0 = 0; i0 < leftDim; i0 += blockSize) {
      // Tile bounds are clamped for when blockSize does not divide the dim.
      var iBlock = Math.min(i0 + blockSize, leftDim);
      for (var j0 = 0; j0 < rightDim; j0 += blockSize) {
        var jBlock = Math.min(j0 + blockSize, rightDim);
        for (var k0 = 0; k0 < sharedDim; k0 += blockSize) {
          var kBlock = Math.min(k0 + blockSize, sharedDim);
          for (var i = i0; i < iBlock; i++) {
            var aRowOffset = batchOffsetA + i * aOuterStep;
            var resRowOffset = resBatchOffset + i * rightDim;
            for (var j = j0; j < jBlock; j++) {
              var bColOffset = j * bOuterStep + batchOffsetB;
              var sum = 0;
              for (var k = k0; k < kBlock; k++) {
                sum += a3dValues[aRowOffset + k * aInnerStep] * b3dValues[k * bInnerStep + bColOffset];
              }
              // Each k-tile adds its partial dot product to the output cell.
              resVals[resRowOffset + j] += sum;
            }
          }
        }
      }
    }
  }
  backend.disposeIntermediateTensorInfo(a3d);
  backend.disposeIntermediateTensorInfo(b3d);
  return backend.makeTensorInfo(outShape, result.dtype, result.values);
}
var batchMatMulConfig = {
  kernelName: tf.BatchMatMul,
  backendName: "cpu",
  kernelFunc: batchMatMul
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the License);
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an AS IS BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Fused matMul + optional bias add + optional activation (tf._FusedMatMul).
 *
 * Each stage's input becomes an intermediate that is disposed once the final
 * result exists. The result of the last stage that ran is returned.
 *
 * @param args.inputs  {a, b, bias?, preluActivationWeights?}
 * @param args.attrs   {transposeA, transposeB, activation?}
 */
function _fusedMatMul(args) {
  const { inputs, backend, attrs } = args;
  const { a, b, bias, preluActivationWeights } = inputs;
  const { transposeA, transposeB, activation } = attrs;
  const toDispose = [];
  let result = batchMatMul({ inputs: { a, b }, attrs: { transposeA, transposeB }, backend });
  if (bias) {
    const biased = add({ inputs: { a: result, b: bias }, backend });
    toDispose.push(result);
    result = biased;
  }
  if (activation) {
    const activated = applyActivation(backend, result, activation, preluActivationWeights);
    toDispose.push(result);
    result = activated;
  }
  toDispose.forEach((info) => {
    backend.disposeIntermediateTensorInfo(info);
  });
  return result;
}
var _fusedMatMulConfig = { kernelName: tf._FusedMatMul, backendName: "cpu", kernelFunc: _fusedMatMul };
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the License);
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an AS IS BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/** Arc-cosine kernel (tf.Acos): the scalar op Math.acos handed to unaryKernelFunc. */
var acos = unaryKernelFunc(tf.Acos, (xi) => Math.acos(xi));
var acosConfig = { kernelName: tf.Acos, backendName: "cpu", kernelFunc: acos };
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the License);
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an AS IS BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/** Inverse hyperbolic cosine kernel (tf.Acosh): scalar op Math.acosh. */
var acosh = unaryKernelFunc(tf.Acosh, (xi) => Math.acosh(xi));
var acoshConfig = { kernelName: tf.Acosh, backendName: "cpu", kernelFunc: acosh };
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the License);
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an AS IS BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/** Arc-sine kernel (tf.Asin): scalar op Math.asin. */
var asin = unaryKernelFunc(tf.Asin, (xi) => Math.asin(xi));
var asinConfig = { kernelName: tf.Asin, backendName: "cpu", kernelFunc: asin };
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the License);
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an AS IS BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/** Inverse hyperbolic sine kernel (tf.Asinh): scalar op Math.asinh. */
var asinh = unaryKernelFunc(tf.Asinh, (xi) => Math.asinh(xi));
var asinhConfig = { kernelName: tf.Asinh, backendName: "cpu", kernelFunc: asinh };
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the License);
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an AS IS BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/** Arc-tangent kernel (tf.Atan): scalar op Math.atan. */
var atan = unaryKernelFunc(tf.Atan, (xi) => Math.atan(xi));
var atanConfig = { kernelName: tf.Atan, backendName: "cpu", kernelFunc: atan };
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the License);
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an AS IS BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/** Inverse hyperbolic tangent kernel (tf.Atanh): scalar op Math.atanh. */
var atanh = unaryKernelFunc(tf.Atanh, (xi) => Math.atanh(xi));
var atanhConfig = { kernelName: tf.Atanh, backendName: "cpu", kernelFunc: atanh };
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Shared 2D pooling loop (max or average) over channels-last data.
 *
 * @param xValues  flat input values.
 * @param xShape   input shape; not read here (kept for the call signature).
 * @param dtype    dtype of the returned buffer.
 * @param strides  flat-index strides of the input: strides[0] per batch,
 *                 strides[1] per row, strides[2] per column; channels are
 *                 contiguous.
 * @param convInfo pooling geometry (strides, dilations, effective filter size,
 *                 padding, input/output sizes), e.g. from
 *                 tf.backend_util.computePool2DInfo.
 * @param poolType "max" or "avg".
 * @returns a tf.buffer of shape convInfo.outShape with the pooled values.
 */
function pool(xValues, xShape, dtype, strides, convInfo, poolType) {
  var strideHeight = convInfo.strideHeight;
  var strideWidth = convInfo.strideWidth;
  var dilationHeight = convInfo.dilationHeight;
  var dilationWidth = convInfo.dilationWidth;
  var effectiveFilterHeight = convInfo.effectiveFilterHeight;
  var effectiveFilterWidth = convInfo.effectiveFilterWidth;
  var padTop = convInfo.padInfo.top;
  var padLeft = convInfo.padInfo.left;
  // Starting value of the max reduction; the "avg" path reports avgValue / count instead.
  var initialValue = poolType === "max" ? Number.NEGATIVE_INFINITY : Number.POSITIVE_INFINITY;
  var output = tf.buffer(convInfo.outShape, dtype);
  var outputVals = output.values;
  var outputBatchStrides = convInfo.outShape[1] * convInfo.outShape[2] * convInfo.outShape[3];
  var outputRowStrides = convInfo.outShape[2] * convInfo.outShape[3];
  var outputColStrides = convInfo.outShape[3];
  for (var b = 0; b < convInfo.batchSize; ++b) {
    var outputBatchOffset = b * outputBatchStrides;
    var inputBatchOffset = b * strides[0];
    for (var d = 0; d < convInfo.inChannels; ++d) {
      for (var yR = 0; yR < convInfo.outHeight; ++yR) {
        var xRCorner = yR * strideHeight - padTop;
        // First in-bounds row on this window's dilation grid
        // (xRCorner + k * dilationHeight). Clamping with Math.max(0, xRCorner)
        // would put every tap off the grid when dilationHeight > 1. With
        // dilation 1 both forms give the same row.
        var xRMin = xRCorner >= 0 ? xRCorner : xRCorner + Math.ceil(-xRCorner / dilationHeight) * dilationHeight;
        var xRMax = Math.min(convInfo.inHeight, effectiveFilterHeight + xRCorner);
        var outputRowOffset = outputBatchOffset + yR * outputRowStrides;
        for (var yC = 0; yC < convInfo.outWidth; ++yC) {
          var xCCorner = yC * strideWidth - padLeft;
          // Same dilation-grid alignment as for rows.
          var xCMin = xCCorner >= 0 ? xCCorner : xCCorner + Math.ceil(-xCCorner / dilationWidth) * dilationWidth;
          var xCMax = Math.min(convInfo.inWidth, effectiveFilterWidth + xCCorner);
          var minMaxValue = initialValue;
          var avgValue = 0;
          var count = 0;
          // minMaxValue starts at +/-Infinity and only takes values that
          // compared greater, so it is never NaN; NaN pixels are skipped by
          // max. (An unreachable isNaN(minMaxValue) early exit was removed.)
          for (var xR = xRMin; xR < xRMax; xR += dilationHeight) {
            var xROffset = inputBatchOffset + xR * strides[1];
            for (var xC = xCMin; xC < xCMax; xC += dilationWidth) {
              var xCOffset = xROffset + xC * strides[2];
              var pixel = xValues[xCOffset + d];
              if (poolType === "max" && pixel > minMaxValue) {
                minMaxValue = pixel;
              } else if (poolType === "avg") {
                avgValue += pixel;
                count++;
              }
            }
          }
          var outputOffset = outputRowOffset + yC * outputColStrides + d;
          outputVals[outputOffset] = poolType === "avg" ? avgValue / count : minMaxValue;
        }
      }
    }
  }
  return output;
}
/**
 * For every max-pooling window, records where the winning input value came from.
 *
 * @param xValues   flat input values, wrapped as a buffer of shape xShape and
 *                  read as [batch, row, col, channel].
 * @param xShape    input shape.
 * @param dtype     input dtype.
 * @param convInfo  pooling geometry (see tf.backend_util.computePool2DInfo).
 * @param flattenPositions  false (default): position inside the window,
 *                  wR * effectiveFilterWidth + wC. true: flat index into the input.
 * @param includeBatchInIndex  with flattenPositions, also fold the batch
 *                  index into the flat index (default false).
 * @returns an "int32" buffer of shape convInfo.outShape. -1 marks a window in
 *          which no value compared greater than -Infinity.
 */
function maxPoolPositions(xValues, xShape, dtype, convInfo, flattenPositions = false, includeBatchInIndex = false) {
  const positions = tf.buffer(convInfo.outShape, "int32");
  const {
    strideHeight,
    strideWidth,
    dilationHeight,
    dilationWidth,
    effectiveFilterHeight,
    effectiveFilterWidth,
  } = convInfo;
  const padTop = convInfo.padInfo.top;
  const padLeft = convInfo.padInfo.left;
  const input = tf.buffer(xShape, dtype, xValues);
  const { batchSize, inChannels, inHeight, inWidth, outHeight, outWidth } = convInfo;
  for (let b = 0; b < batchSize; ++b) {
    for (let d = 0; d < inChannels; ++d) {
      for (let yR = 0; yR < outHeight; ++yR) {
        const rowCorner = yR * strideHeight - padTop;
        // Step forward along the dilation grid until the row is in bounds.
        let rowStart = rowCorner;
        while (rowStart < 0) {
          rowStart += dilationHeight;
        }
        const rowEnd = Math.min(inHeight, effectiveFilterHeight + rowCorner);
        for (let yC = 0; yC < outWidth; ++yC) {
          const colCorner = yC * strideWidth - padLeft;
          let colStart = colCorner;
          while (colStart < 0) {
            colStart += dilationWidth;
          }
          const colEnd = Math.min(inWidth, effectiveFilterWidth + colCorner);
          let best = Number.NEGATIVE_INFINITY;
          let bestPosition = -1;
          for (let xR = rowStart; xR < rowEnd; xR += dilationHeight) {
            for (let xC = colStart; xC < colEnd; xC += dilationWidth) {
              const pixel = input.get(b, xR, xC, d);
              if (!(pixel > best)) {
                continue;
              }
              best = pixel;
              if (!flattenPositions) {
                bestPosition = (xR - rowCorner) * effectiveFilterWidth + (xC - colCorner);
              } else if (includeBatchInIndex) {
                bestPosition = ((b * inHeight + xR) * inWidth + xC) * inChannels + d;
              } else {
                bestPosition = (xR * inWidth + xC) * inChannels + d;
              }
            }
          }
          positions.set(bestPosition, b, yR, yC, d);
        }
      }
    }
  }
  return positions;
}
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * 2D average pooling kernel (tf.AvgPool) for the CPU backend. Dilation is
 * fixed at 1.
 *
 * A 1x1 filter whose output shape equals the input shape is a no-op, so the
 * identity kernel is returned. Otherwise the shared `pool` loop runs in "avg" mode.
 *
 * @param args.attrs  {filterSize, strides, pad, dimRoundingMode}
 * @throws (via tf.util.assert) unless strides or dilations are 1.
 */
function avgPool(args) {
  const { inputs, backend, attrs } = args;
  const { x } = inputs;
  assertNotComplex(x, "avgPool");
  const { filterSize, strides, pad, dimRoundingMode } = attrs;
  const dilations = 1;
  tf.util.assert(tf.backend_util.eitherStridesOrDilationsAreOne(strides, dilations), () =>
    "Error in avgPool: Either strides or dilations must be 1. " + ("Got strides " + strides + " and dilations '" + dilations + "'"));
  const convInfo = tf.backend_util.computePool2DInfo(x.shape, filterSize, strides, dilations, pad, dimRoundingMode);
  const isNoOp = convInfo.filterWidth === 1 && convInfo.filterHeight === 1 && tf.util.arraysEqual(convInfo.inShape, convInfo.outShape);
  if (isNoOp) {
    return identity({ inputs: { x }, backend });
  }
  const xValues = backend.data.get(x.dataId).values;
  const pooled = pool(xValues, x.shape, x.dtype, tf.util.computeStrides(x.shape), convInfo, "avg");
  return backend.makeTensorInfo(convInfo.outShape, x.dtype, pooled.values);
}
var avgPoolConfig = { kernelName: tf.AvgPool, backendName: "cpu", kernelFunc: avgPool };
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Gradient of 2D average pooling w.r.t. its input (tf.AvgPoolBackprop).
 *
 * Each input cell (dxR, dxC) sums the upstream gradient of every output
 * window that covered it, then scales the sum by 1 / (filterHeight * filterWidth).
 * A filter offset contributes only when it lands on a whole, in-range output
 * coordinate; this is how the pooling stride is accounted for.
 *
 * @param args.inputs.dy     upstream gradient.
 * @param args.inputs.input  forward-pass input; only its shape is used.
 * @param args.attrs         {filterSize, strides, pad}; dilation is fixed at 1.
 * @returns float32 TensorInfo with the input's shape.
 */
function avgPoolBackprop(args) {
  const { inputs, backend, attrs } = args;
  const { dy, input } = inputs;
  assertNotComplex([dy, input], "avgPoolBackprop");
  const { filterSize, strides, pad } = attrs;
  const convInfo = tf.backend_util.computePool2DInfo(input.shape, filterSize, strides, 1, pad);
  const {
    strideHeight,
    strideWidth,
    filterHeight,
    filterWidth,
    dilationHeight,
    dilationWidth,
    effectiveFilterHeight,
    effectiveFilterWidth,
  } = convInfo;
  // Forward padding re-expressed as the offset that maps input coordinates
  // back onto output windows in the loops below.
  const padLeft = effectiveFilterWidth - 1 - convInfo.padInfo.left;
  const padTop = effectiveFilterHeight - 1 - convInfo.padInfo.top;
  const dx = tf.buffer(input.shape, "float32");
  const avgMultiplier = 1 / (filterHeight * filterWidth);
  const dyValues = backend.data.get(dy.dataId).values;
  const dyBuf = tf.buffer(dy.shape, "float32", dyValues);
  // True when v is a whole number in [0, limit).
  const isOutputIndex = (v, limit) => v >= 0 && v < limit && Math.floor(v) === v;
  for (let b = 0; b < convInfo.batchSize; ++b) {
    for (let d = 0; d < convInfo.inChannels; ++d) {
      for (let dxR = 0; dxR < convInfo.inHeight; ++dxR) {
        for (let dxC = 0; dxC < convInfo.inWidth; ++dxC) {
          const dyRCorner = dxR - padTop;
          const dyCCorner = dxC - padLeft;
          let total = 0;
          for (let wR = 0; wR < effectiveFilterHeight; wR += dilationHeight) {
            const dyR = (dyRCorner + wR) / strideHeight;
            if (!isOutputIndex(dyR, convInfo.outHeight)) {
              continue;
            }
            for (let wC = 0; wC < effectiveFilterWidth; wC += dilationWidth) {
              const dyC = (dyCCorner + wC) / strideWidth;
              if (isOutputIndex(dyC, convInfo.outWidth)) {
                total += dyBuf.get(b, dyR, dyC, d);
              }
            }
          }
          dx.set(total * avgMultiplier, b, dxR, dxC, d);
        }
      }
    }
  }
  return backend.makeTensorInfo(dx.shape, dx.dtype, dx.values);
}
var avgPoolBackpropConfig = { kernelName: tf.AvgPoolBackprop, backendName: "cpu", kernelFunc: avgPoolBackprop };
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Batch normalization kernel (tf.FusedBatchNorm) for the CPU backend:
 *
 *   out[i] = offset + (x[i] - mean) * scale / sqrt(variance + varianceEpsilon)
 *
 * mean, variance, scale and offset are read cyclically (index i modulo their
 * length). A per-channel vector whose length equals x's last dimension
 * therefore lines up channel-by-channel. scale defaults to 1, offset to 0
 * and varianceEpsilon to 1e-3.
 *
 * @throws (via tf.util.assert) when mean, variance, offset or scale ranks differ.
 * @returns TensorInfo with x's shape and dtype, backed by a Float32Array.
 */
function batchNorm(args) {
  const { inputs, backend, attrs } = args;
  const { x, scale: scaleInput, offset, mean, variance } = inputs;
  const meanRank = mean.shape.length;
  tf.util.assert(meanRank === variance.shape.length, () => "Batch normalization gradient requires mean and variance to have equal ranks.");
  tf.util.assert(offset == null || meanRank === offset.shape.length, () => "Batch normalization gradient requires mean and offset to have equal ranks.");
  tf.util.assert(scaleInput == null || meanRank === scaleInput.shape.length, () => "Batch normalization gradient requires mean and scale to have equal ranks.");
  assertNotComplex([x, mean, variance, scaleInput, offset], "batchNorm");
  const epsilonAttr = attrs.varianceEpsilon;
  const varianceEpsilon = epsilonAttr == null ? 1e-3 : epsilonAttr;
  const read = (t) => backend.data.get(t.dataId).values;
  const xVals = read(x);
  const mVals = read(mean);
  const varVals = read(variance);
  const sVals = scaleInput ? read(scaleInput) : new Float32Array([1]);
  const offVals = offset ? read(offset) : new Float32Array([0]);
  const outVals = new Float32Array(xVals.length);
  for (let i = 0; i < xVals.length; ++i) {
    outVals[i] = offVals[i % offVals.length] + (xVals[i] - mVals[i % mVals.length]) * sVals[i % sVals.length] / Math.sqrt(varVals[i % varVals.length] + varianceEpsilon);
  }
  return backend.makeTensorInfo(x.shape, x.dtype, outVals);
}
var batchNormConfig = { kernelName: tf.FusedBatchNorm, backendName: "cpu", kernelFunc: batchNorm };
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the License);
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an AS IS BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * ClipByValue kernel: clamps x into [attrs.clipValueMin, attrs.clipValueMax].
 * The upper bound is checked first, so when min > max every value above max
 * becomes max. NaN passes through unchanged.
 */
var clip = unaryKernelFunc(tf.ClipByValue, (xi, attrs) => {
  const { clipValueMin, clipValueMax } = attrs;
  if (xi > clipValueMax) {
    return clipValueMax;
  }
  return xi < clipValueMin ? clipValueMin : xi;
});
var clipConfig = { kernelName: tf.ClipByValue, backendName: "cpu", kernelFunc: clip };
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * CPU kernel for Imag: returns the imaginary component of a complex tensor.
 *
 * The complex tensor's backend record holds `complexTensorInfos`; the
 * imaginary part's stored values are wrapped in a new TensorInfo with that
 * part's shape and dtype (the values array is passed through, not copied).
 *
 * @param {{inputs: {input}, backend}} args
 * @returns TensorInfo for the imaginary component.
 */
function imag(args) {
  const { inputs: { input }, backend } = args;
  const { imag: imagPart } = backend.data.get(input.dataId).complexTensorInfos;
  const { values } = backend.data.get(imagPart.dataId);
  return backend.makeTensorInfo(imagPart.shape, imagPart.dtype, values);
}
// Kernel config binding `imag` to the Imag kernel name on the "cpu" backend.
var imagConfig = {
kernelName: tf.Imag,
backendName: "cpu",
kernelFunc: imag
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * CPU kernel for Concat: joins `inputs` along `attrs.axis`.
 *
 * Special cases are handled in this order:
 * - An empty overall result yields an empty TensorInfo.
 * - Size-0 operands are dropped. If exactly one non-empty operand remains, it
 *   is returned as-is (no copy).
 * - complex64 operands are split into real/imag parts. The parts are
 *   concatenated separately and then recombined.
 *
 * Otherwise every operand is reshaped to a 2D [outer, inner] view, where
 * `inner` is the product of the dimensions from `axis` onwards. Concatenation
 * then means placing column blocks side by side, or doing a single contiguous
 * copy when there is only one row.
 *
 * @param {{inputs: Array, backend, attrs: {axis: number}}} args
 * @returns TensorInfo holding the concatenated values.
 */
function concat(args) {
  const { inputs, backend, attrs } = args;
  const shapeOf = (info) => info.shape;
  const axisIdx = tf.util.parseAxisParam(attrs.axis, inputs[0].shape)[0];

  const fullShape = tf.backend_util.computeOutShape(inputs.map(shapeOf), axisIdx);
  if (tf.util.sizeFromShape(fullShape) === 0) {
    return backend.makeTensorInfo(fullShape, inputs[0].dtype, []);
  }

  const nonEmpty = inputs.filter((info) => tf.util.sizeFromShape(info.shape) > 0);
  if (nonEmpty.length === 1) {
    return nonEmpty[0];
  }
  tf.backend_util.assertParamsConsistent(nonEmpty.map(shapeOf), axisIdx);

  if (nonEmpty[0].dtype === "complex64") {
    const realParts = nonEmpty.map((info) => real({ inputs: { input: info }, backend }));
    const imagParts = nonEmpty.map((info) => imag({ inputs: { input: info }, backend }));
    const joinedReal = concat({ inputs: realParts, backend, attrs: { axis: axisIdx } });
    const joinedImag = concat({ inputs: imagParts, backend, attrs: { axis: axisIdx } });
    const joined = complex({ inputs: { real: joinedReal, imag: joinedImag }, backend });
    // Release every intermediate, in creation order.
    for (const part of [...realParts, ...imagParts, joinedReal, joinedImag]) {
      backend.disposeIntermediateTensorInfo(part);
    }
    return joined;
  }

  // View each operand as [outer, inner] with inner = prod(shape[axis:]).
  const flattened = nonEmpty.map((info) => reshape({
    inputs: { x: info },
    backend,
    attrs: { shape: [-1, tf.util.sizeFromShape(info.shape.slice(axisIdx))] },
  }));
  const flatShape = tf.backend_util.computeOutShape(flattened.map(shapeOf), 1);
  const outVals = tf.util.getTypedArrayFromDType(nonEmpty[0].dtype, tf.util.sizeFromShape(flatShape));

  if (flattened[0].shape[0] === 1) {
    // Single row: operands are laid out back to back.
    let writePos = 0;
    for (const info of flattened) {
      outVals.set(backend.data.get(info.dataId).values, writePos);
      writePos += tf.util.sizeFromShape(info.shape);
    }
  } else {
    // Several rows: each operand fills a column block in every row.
    const rowWidth = flatShape[1];
    let colStart = 0;
    for (const info of flattened) {
      const src = backend.data.get(info.dataId).values;
      const [rows, cols] = info.shape;
      let readPos = 0;
      for (let r = 0; r < rows; r++) {
        const base = r * rowWidth + colStart;
        for (let c = 0; c < cols; c++) {
          outVals[base + c] = src[readPos++];
        }
      }
      colStart += cols;
    }
  }

  const resultShape = tf.backend_util.computeOutShape(nonEmpty.map(shapeOf), axisIdx);
  const result = backend.makeTensorInfo(resultShape, inputs[0].dtype, outVals);
  flattened.forEach((info) => backend.disposeIntermediateTensorInfo(info));
  return result;
}
// Kernel config binding `concat` to the Concat kernel name on the "cpu" backend.
var concatConfig = {
kernelName: tf.Concat,
backendName: "cpu",
kernelFunc: concat
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * CPU kernel for Conv2D.
 *
 * Both data formats are handled: the row/column/channel strides used for x and
 * y are chosen from `isChannelsLast`. The filter is indexed as
 * [filterHeight, filterWidth, inChannels, outChannels]. Input taps that fall
 * outside the input are skipped, which is equivalent to zero padding.
 *
 * @param args {inputs: {x, filter}, backend, attrs: {strides, pad, dataFormat, dilations, dimRoundingMode}}
 * @returns TensorInfo with shape convInfo.outShape and x's dtype.
 */
function conv2D(args) {
var inputs = args.inputs, backend = args.backend, attrs = args.attrs;
var x = inputs.x, filter = inputs.filter;
var strides = attrs.strides, pad = attrs.pad, dataFormat = attrs.dataFormat, dilations = attrs.dilations, dimRoundingMode = attrs.dimRoundingMode;
assertNotComplex([x, filter], "conv2d");
var $dataFormat = tf.backend_util.convertConv2DDataFormat(dataFormat);
var convInfo = tf.backend_util.computeConv2DInfo(x.shape, filter.shape, strides, dilations, pad, dimRoundingMode, false, $dataFormat);
var filterHeight = convInfo.filterHeight;
var filterWidth = convInfo.filterWidth;
var dilationHeight = convInfo.dilationHeight;
var dilationWidth = convInfo.dilationWidth;
var padLeft = convInfo.padInfo.left;
var padTop = convInfo.padInfo.top;
var isChannelsLast = convInfo.dataFormat === "channelsLast";
// Output accumulator: the loops add into it with +=, so this assumes a fresh
// TensorBuffer starts zero-filled.
var y = new tf.TensorBuffer(convInfo.outShape, x.dtype);
var xStrides = tf.util.computeStrides(x.shape);
var filterStrides = tf.util.computeStrides(filter.shape);
// channelsLast: [batch, row, col, channel]; otherwise [batch, channel, row, col].
var xBatchStride = xStrides[0];
var xRowStride = isChannelsLast ? xStrides[1] : xStrides[2];
var xColStride = isChannelsLast ? xStrides[2] : 1;
var xChannelStride = isChannelsLast ? 1 : xStrides[1];
var yBatchStride = y.strides[0];
var yRowStride = isChannelsLast ? y.strides[1] : y.strides[2];
var yColStride = isChannelsLast ? y.strides[2] : 1;
var yChannelStride = isChannelsLast ? 1 : y.strides[1];
var xVals = backend.data.get(x.dataId).values;
var wVals = backend.data.get(filter.dataId).values;
var yVals = y.values;
for (var b = 0; b < convInfo.batchSize; ++b) {
var xOffset1 = b * xBatchStride;
var yOffset1 = b * yBatchStride;
for (var yR = 0; yR < convInfo.outHeight; ++yR) {
var yOffset2 = yOffset1 + yR * yRowStride;
// Top-most input row covered by this output row (may be negative = padding).
var xRCorner = yR * convInfo.strideHeight - padTop;
for (var wR = 0; wR < filterHeight; ++wR) {
var xR = xRCorner + wR * dilationHeight;
if (xR < 0 || xR >= convInfo.inHeight) {
continue;
}
var wOffset1 = wR * filterStrides[0];
var xOffset2 = xOffset1 + xR * xRowStride;
for (var yC = 0; yC < convInfo.outWidth; ++yC) {
var yOffset3 = yOffset2 + yC * yColStride;
var xCCorner = yC * convInfo.strideWidth - padLeft;
for (var wC = 0; wC < filterWidth; ++wC) {
var xC = xCCorner + wC * dilationWidth;
if (xC < 0 || xC >= convInfo.inWidth) {
continue;
}
var wOffset2 = wOffset1 + wC * filterStrides[1];
var xOffset3 = xOffset2 + xC * xColStride;
var wOffset3 = wOffset2;
for (var d1 = 0; d1 < convInfo.inChannels; ++d1) {
var xVal = xVals[xOffset3 + d1 * xChannelStride];
for (var d2 = 0; d2 < convInfo.outChannels; ++d2) {
yVals[yOffset3 + d2 * yChannelStride] += xVal * wVals[wOffset3 + d2];
}
// Advance to the next input channel's row of outChannels weights.
wOffset3 += convInfo.outChannels;
}
}
}
}
}
}
return backend.makeTensorInfo(y.shape, y.dtype, yVals);
}
// Kernel config binding `conv2D` to the Conv2D kernel name on the "cpu" backend.
var conv2DConfig = {
kernelName: tf.Conv2D,
backendName: "cpu",
kernelFunc: conv2D
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * CPU kernel for Conv2DBackpropFilter: the gradient of Conv2D with respect to
 * its filter.
 *
 * For every filter tap (wR, wC, d1, d2), it sums x * dy over the batch and over
 * all output positions whose receptive field includes that tap. Dilation is
 * fixed to 1 (the literal `1` passed to computeConv2DInfo).
 *
 * @param args {inputs: {x, dy}, backend, attrs: {strides, pad, dataFormat, dimRoundingMode, filterShape}}
 * @returns float32 TensorInfo of shape convInfo.filterShape.
 */
function conv2DBackpropFilter(args) {
var inputs = args.inputs, backend = args.backend, attrs = args.attrs;
var x = inputs.x, dy = inputs.dy;
var strides = attrs.strides, pad = attrs.pad, dataFormat = attrs.dataFormat, dimRoundingMode = attrs.dimRoundingMode, filterShape = attrs.filterShape;
assertNotComplex([x, dy], "conv2dBackpropFilter");
var $dataFormat = tf.backend_util.convertConv2DDataFormat(dataFormat);
var convInfo = tf.backend_util.computeConv2DInfo(x.shape, filterShape, strides, 1, pad, dimRoundingMode, false, $dataFormat);
var strideHeight = convInfo.strideHeight, strideWidth = convInfo.strideWidth, filterHeight = convInfo.filterHeight, filterWidth = convInfo.filterWidth;
var isChannelsLast = convInfo.dataFormat === "channelsLast";
var dW = new tf.TensorBuffer(convInfo.filterShape, "float32");
var leftPad = convInfo.padInfo.left;
var topPad = convInfo.padInfo.top;
var xVals = backend.data.get(x.dataId).values;
var dyVals = backend.data.get(dy.dataId).values;
var xBuf = new tf.TensorBuffer(x.shape, x.dtype, xVals);
var dyBuf = new tf.TensorBuffer(dy.shape, dy.dtype, dyVals);
for (var wR = 0; wR < filterHeight; ++wR) {
// Output rows yR for which xR = wR + yR * strideHeight - topPad lies in
// [0, inHeight). The upper bound may be fractional; the strict `<` in the
// loop below handles that without flooring.
var yRMin = Math.max(0, Math.ceil((topPad - wR) / strideHeight));
var yRMax = Math.min(convInfo.outHeight, (convInfo.inHeight + topPad - wR) / strideHeight);
for (var wC = 0; wC < filterWidth; ++wC) {
// Same bounds for columns.
var yCMin = Math.max(0, Math.ceil((leftPad - wC) / strideWidth));
var yCMax = Math.min(convInfo.outWidth, (convInfo.inWidth + leftPad - wC) / strideWidth);
for (var d1 = 0; d1 < convInfo.inChannels; ++d1) {
for (var d2 = 0; d2 < convInfo.outChannels; ++d2) {
var dotProd = 0;
for (var b = 0; b < convInfo.batchSize; ++b) {
for (var yR = yRMin; yR < yRMax; ++yR) {
var xR = wR + yR * strideHeight - topPad;
for (var yC = yCMin; yC < yCMax; ++yC) {
var xC = wC + yC * strideWidth - leftPad;
if (isChannelsLast) {
dotProd += xBuf.get(b, xR, xC, d1) * dyBuf.get(b, yR, yC, d2);
} else {
dotProd += xBuf.get(b, d1, xR, xC) * dyBuf.get(b, d2, yR, yC);
}
}
}
}
dW.set(dotProd, wR, wC, d1, d2);
}
}
}
}
return backend.makeTensorInfo(dW.shape, dW.dtype, dW.values);
}
// Kernel config binding `conv2DBackpropFilter` to the Conv2DBackpropFilter
// kernel name on the "cpu" backend.
var conv2DBackpropFilterConfig = {
kernelName: tf.Conv2DBackpropFilter,
backendName: "cpu",
kernelFunc: conv2DBackpropFilter
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * CPU kernel for Conv2DBackpropInput: the gradient of Conv2D with respect to
 * its input.
 *
 * Computed as a correlation of dy with the spatially flipped filter
 * (`filterHeight - 1 - wR`, `filterWidth - 1 - wC`), using the complementary
 * padding `filter size - 1 - original pad`. Dilation is fixed to 1. Both data
 * formats are handled through the isChannelsLast stride selection. The filter
 * is indexed as [filterHeight, filterWidth, inChannels, outChannels].
 *
 * @param args {inputs: {dy, filter}, backend, attrs: {inputShape, strides, pad, dataFormat, dimRoundingMode}}
 * @returns float32 TensorInfo of shape convInfo.inShape.
 */
function conv2DBackpropInput(args) {
var inputs = args.inputs, backend = args.backend, attrs = args.attrs;
var dy = inputs.dy, filter = inputs.filter;
var inputShape = attrs.inputShape, strides = attrs.strides, pad = attrs.pad, dataFormat = attrs.dataFormat, dimRoundingMode = attrs.dimRoundingMode;
assertNotComplex([dy, filter], "conv2dBackpropInput");
var filterStrides = tf.util.computeStrides(filter.shape);
var dyStrides = tf.util.computeStrides(dy.shape);
var $dataFormat = tf.backend_util.convertConv2DDataFormat(dataFormat);
var convInfo = tf.backend_util.computeConv2DInfo(inputShape, filter.shape, strides, 1, pad, dimRoundingMode, false, $dataFormat);
var dx = new tf.TensorBuffer(convInfo.inShape, "float32");
var dxValues = dx.values;
var dyValues = backend.data.get(dy.dataId).values;
var fltValues = backend.data.get(filter.dataId).values;
var fltS0 = filterStrides[0], fltS1 = filterStrides[1], fltS2 = filterStrides[2];
var batchSize = convInfo.batchSize, filterHeight = convInfo.filterHeight, filterWidth = convInfo.filterWidth, inChannels = convInfo.inChannels, inHeight = convInfo.inHeight, inWidth = convInfo.inWidth, outChannels = convInfo.outChannels, outHeight = convInfo.outHeight, outWidth = convInfo.outWidth, strideHeight = convInfo.strideHeight, strideWidth = convInfo.strideWidth;
$dataFormat = convInfo.dataFormat;
// Padding of the equivalent forward correlation over dy.
var topPad = filterHeight - 1 - convInfo.padInfo.top;
var leftPad = filterWidth - 1 - convInfo.padInfo.left;
var isChannelsLast = $dataFormat === "channelsLast";
var xBatchStride = dx.strides[0];
var xRowStride = isChannelsLast ? dx.strides[1] : dx.strides[2];
var xColStride = isChannelsLast ? dx.strides[2] : 1;
var xChannelStride = isChannelsLast ? 1 : dx.strides[1];
var yBatchStride = dyStrides[0];
var yRowStride = isChannelsLast ? dyStrides[1] : dyStrides[2];
var yColStride = isChannelsLast ? dyStrides[2] : 1;
var yChannelStride = isChannelsLast ? 1 : dyStrides[1];
for (var b = 0; b < batchSize; ++b) {
for (var d1 = 0; d1 < inChannels; ++d1) {
for (var xR = 0; xR < inHeight; ++xR) {
var xRCorner = xR - topPad;
// Despite the name, xRMin is the first dy row (yR) that contributes to
// this dx row; yRMax may be fractional (the loop uses strict `<`).
var xRMin = Math.max(0, Math.ceil(xRCorner / strideHeight));
var yRMax = Math.min(outHeight, (filterHeight + xRCorner) / strideHeight);
for (var xC = 0; xC < inWidth; ++xC) {
var xCCorner = xC - leftPad;
// Likewise xCMin is the first contributing dy column.
var xCMin = Math.max(0, Math.ceil(xCCorner / strideWidth));
var yCMax = Math.min(outWidth, (filterWidth + xCCorner) / strideWidth);
var dotProd = 0;
for (var yR = xRMin; yR < yRMax; ++yR) {
var wR = yR * strideHeight - xRCorner;
for (var yC = xCMin; yC < yCMax; ++yC) {
var wC = yC * strideWidth - xCCorner;
var dyOffset = yBatchStride * b + yRowStride * yR + yColStride * yC;
// Flipped filter tap for input channel d1.
var fltOffset = fltS0 * (filterHeight - 1 - wR) + fltS1 * (filterWidth - 1 - wC) + fltS2 * d1;
for (var d2 = 0; d2 < outChannels; ++d2) {
var pixel = dyValues[dyOffset + yChannelStride * d2];
var weight = fltValues[fltOffset + d2];
dotProd += pixel * weight;
}
}
}
var dxOffset = xBatchStride * b + xRowStride * xR + xColStride * xC + xChannelStride * d1;
dxValues[dxOffset] = dotProd;
}
}
}
}
return backend.makeTensorInfo(dx.shape, dx.dtype, dx.values);
}
// Kernel config binding `conv2DBackpropInput` to the Conv2DBackpropInput
// kernel name on the "cpu" backend.
var conv2DBackpropInputConfig = {
kernelName: tf.Conv2DBackpropInput,
backendName: "cpu",
kernelFunc: conv2DBackpropInput
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * CPU kernel for Conv3D.
 *
 * The index arithmetic assumes channelsLast layout:
 * - x is [batch, inDepth, inHeight, inWidth, inChannels].
 * - The filter is [filterDepth, filterHeight, filterWidth, inChannels, outChannels].
 * - The output is [batch, outDepth, outHeight, outWidth, outChannels].
 *
 * Input taps outside the input are skipped (zero padding).
 *
 * @param args {inputs: {x, filter}, backend, attrs: {strides, pad, dilations}}
 * @returns TensorInfo with shape convInfo.outShape and x's dtype.
 */
function conv3D(args) {
var inputs = args.inputs, backend = args.backend, attrs = args.attrs;
var x = inputs.x, filter = inputs.filter;
var strides = attrs.strides, pad = attrs.pad, dilations = attrs.dilations;
assertNotComplex([x, filter], "conv3d");
var convInfo = tf.backend_util.computeConv3DInfo(x.shape, filter.shape, strides, dilations, pad);
var filterDepth = convInfo.filterDepth, filterHeight = convInfo.filterHeight, filterWidth = convInfo.filterWidth, dilationDepth = convInfo.dilationDepth, dilationHeight = convInfo.dilationHeight, dilationWidth = convInfo.dilationWidth, padInfo = convInfo.padInfo;
var padFront = padInfo.front;
var padLeft = padInfo.left;
var padTop = padInfo.top;
// Output accumulator: the loops add into it with +=, so this assumes a fresh
// TensorBuffer starts zero-filled.
var y = new tf.TensorBuffer(convInfo.outShape, x.dtype);
var xVals = backend.data.get(x.dataId).values;
var wVals = backend.data.get(filter.dataId).values;
var yVals = y.values;
var xStrides = tf.util.computeStrides(x.shape);
var filterStrides = tf.util.computeStrides(filter.shape);
for (var b = 0; b < convInfo.batchSize; ++b) {
var xOffset1 = b * xStrides[0];
var yOffset1 = b * y.strides[0];
for (var yF = 0; yF < convInfo.outDepth; ++yF) {
var yOffset2 = yOffset1 + yF * y.strides[1];
var xFCorner = yF * convInfo.strideDepth - padFront;
for (var wF = 0; wF < filterDepth; ++wF) {
var xF = xFCorner + wF * dilationDepth;
if (xF < 0 || xF >= convInfo.inDepth) {
continue;
}
var wOffset1 = wF * filterStrides[0];
var xOffset2 = xOffset1 + xF * xStrides[1];
for (var yR = 0; yR < convInfo.outHeight; ++yR) {
var yOffset3 = yOffset2 + yR * y.strides[2];
var xRCorner = yR * convInfo.strideHeight - padTop;
for (var wR = 0; wR < filterHeight; ++wR) {
var xR = xRCorner + wR * dilationHeight;
if (xR < 0 || xR >= convInfo.inHeight) {
continue;
}
var wOffset2 = wOffset1 + wR * filterStrides[1];
var xOffset3 = xOffset2 + xR * xStrides[2];
for (var yC = 0; yC < convInfo.outWidth; ++yC) {
// Column stride of the output is outChannels (channels innermost).
var yOffset4 = yOffset3 + yC * convInfo.outChannels;
var xCCorner = yC * convInfo.strideWidth - padLeft;
for (var wC = 0; wC < filterWidth; ++wC) {
var xC = xCCorner + wC * dilationWidth;
if (xC < 0 || xC >= convInfo.inWidth) {
continue;
}
var wOffset3 = wOffset2 + wC * filterStrides[2];
var xOffset4 = xOffset3 + xC * convInfo.inChannels;
var wOffset4 = wOffset3;
for (var d1 = 0; d1 < convInfo.inChannels; ++d1) {
var xVal = xVals[xOffset4 + d1];
for (var d2 = 0; d2 < convInfo.outChannels; ++d2) {
yVals[yOffset4 + d2] += xVal * wVals[wOffset4 + d2];
}
wOffset4 += convInfo.outChannels;
}
}
}
}
}
}
}
}
return backend.makeTensorInfo(y.shape, y.dtype, y.values);
}
// Kernel config binding `conv3D` to the Conv3D kernel name on the "cpu" backend.
var conv3DConfig = {
kernelName: tf.Conv3D,
backendName: "cpu",
kernelFunc: conv3D
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * CPU kernel for Conv3DBackpropFilterV2: the gradient of Conv3D with respect to
 * its filter.
 *
 * For each filter tap (wF, wR, wC, d1, d2), it sums x * dy over the batch and
 * the valid output positions. Dilation is fixed to 1 (the literal `1` passed
 * to computeConv3DInfo). The index arithmetic assumes channelsLast layout,
 * with channels innermost in x and dy.
 *
 * @param args {inputs: {x, dy}, backend, attrs: {strides, pad, filterShape}}
 * @returns float32 TensorInfo of shape convInfo.filterShape.
 */
function conv3DBackpropFilterV2(args) {
var inputs = args.inputs, backend = args.backend, attrs = args.attrs;
var x = inputs.x, dy = inputs.dy;
var strides = attrs.strides, pad = attrs.pad, filterShape = attrs.filterShape;
assertNotComplex([x, dy], "conv3dBackpropFilterV2");
var xStrides = tf.util.computeStrides(x.shape);
var dyStrides = tf.util.computeStrides(dy.shape);
var convInfo = tf.backend_util.computeConv3DInfo(x.shape, filterShape, strides, 1, pad);
var strideDepth = convInfo.strideDepth;
var strideHeight = convInfo.strideHeight;
var strideWidth = convInfo.strideWidth;
var filterDepth = convInfo.filterDepth;
var filterHeight = convInfo.filterHeight;
var filterWidth = convInfo.filterWidth;
var dw = new tf.TensorBuffer(convInfo.filterShape, "float32");
var dwValues = dw.values;
var _a = dw.strides, dwS0 = _a[0], dwS1 = _a[1], dwS2 = _a[2], dwS3 = _a[3];
var dyValues = backend.data.get(dy.dataId).values;
var dyS0 = dyStrides[0], dyS1 = dyStrides[1], dyS2 = dyStrides[2], dyS3 = dyStrides[3];
var xValues = backend.data.get(x.dataId).values;
var xS0 = xStrides[0], xS1 = xStrides[1], xS2 = xStrides[2], xS3 = xStrides[3];
var frontPad = convInfo.padInfo.front;
var leftPad = convInfo.padInfo.left;
var topPad = convInfo.padInfo.top;
for (var wF = 0; wF < filterDepth; ++wF) {
// Output depths yF whose input coordinate xF = wF + yF * strideDepth - frontPad
// lies in [0, inDepth); the upper bound may be fractional (strict `<` below).
var yFMin = Math.max(0, Math.ceil((frontPad - wF) / strideDepth));
var yFMax = Math.min(convInfo.outDepth, (convInfo.inDepth + frontPad - wF) / strideDepth);
var wOffset1 = wF * dwS0;
for (var wR = 0; wR < filterHeight; ++wR) {
var yRMin = Math.max(0, Math.ceil((topPad - wR) / strideHeight));
var yRMax = Math.min(convInfo.outHeight, (convInfo.inHeight + topPad - wR) / strideHeight);
var wOffset2 = wR * dwS1 + wOffset1;
for (var wC = 0; wC < filterWidth; ++wC) {
var yCMin = Math.max(0, Math.ceil((leftPad - wC) / strideWidth));
var yCMax = Math.min(convInfo.outWidth, (convInfo.inWidth + leftPad - wC) / strideWidth);
var wOffset3 = wC * dwS2 + wOffset2;
for (var d1 = 0; d1 < convInfo.inChannels; ++d1) {
var wOffset4 = d1 * dwS3 + wOffset3;
for (var d2 = 0; d2 < convInfo.outChannels; ++d2) {
var dotProd = 0;
for (var b = 0; b < convInfo.batchSize; ++b) {
var xOffset1 = b * xS0;
var yOffset1 = b * dyS0;
for (var yF = yFMin; yF < yFMax; ++yF) {
var xF = wF + yF * strideDepth - frontPad;
var xOffset2 = xF * xS1 + xOffset1;
var yOffset2 = yF * dyS1 + yOffset1;
for (var yR = yRMin; yR < yRMax; ++yR) {
var xR = wR + yR * strideHeight - topPad;
var xOffset3 = xR * xS2 + xOffset2;
var yOffset3 = yR * dyS2 + yOffset2;
for (var yC = yCMin; yC < yCMax; ++yC) {
var xC = wC + yC * strideWidth - leftPad;
var xOffset4 = xC * xS3 + xOffset3;
var yOffset4 = yC * dyS3 + yOffset3;
dotProd += xValues[xOffset4 + d1] * dyValues[yOffset4 + d2];
}
}
}
}
dwValues[wOffset4 + d2] = dotProd;
}
}
}
}
}
return backend.makeTensorInfo(dw.shape, dw.dtype, dw.values);
}
// Kernel config binding `conv3DBackpropFilterV2` to the
// Conv3DBackpropFilterV2 kernel name on the "cpu" backend.
var conv3DBackpropFilterV2Config = {
kernelName: tf.Conv3DBackpropFilterV2,
backendName: "cpu",
kernelFunc: conv3DBackpropFilterV2
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * CPU kernel for Conv3DBackpropInputV2: the gradient of Conv3D with respect to
 * its input.
 *
 * Computed as a correlation of dy with the spatially flipped filter, using the
 * complementary padding `filter size - 1 - original pad`. Dilation is fixed
 * to 1. The index arithmetic assumes channelsLast layout:
 * - dy is [batch, outDepth, outHeight, outWidth, outChannels].
 * - The filter is [filterDepth, filterHeight, filterWidth, inChannels, outChannels].
 * - The result is a float32 [batch, inDepth, inHeight, inWidth, inChannels] buffer.
 *
 * Both dy and filter are rejected if complex, matching conv2DBackpropInput.
 *
 * @param args {inputs: {dy, filter}, backend, attrs: {pad, strides, inputShape}}
 * @returns float32 TensorInfo of shape convInfo.inShape.
 */
function conv3DBackpropInputV2(args) {
  var inputs = args.inputs, backend = args.backend, attrs = args.attrs;
  var dy = inputs.dy, filter = inputs.filter;
  var pad = attrs.pad, strides = attrs.strides, inputShape = attrs.inputShape;
  // The filter is read as plain values below, so it must not be complex either.
  assertNotComplex([dy, filter], "conv3dBackpropInputV2");
  var dyStrides = tf.util.computeStrides(dy.shape);
  var filterStrides = tf.util.computeStrides(filter.shape);
  var convInfo = tf.backend_util.computeConv3DInfo(inputShape, filter.shape, strides, 1, pad);
  var dx = new tf.TensorBuffer(convInfo.inShape, "float32");
  var dxValues = dx.values;
  var _a = dx.strides, dxS0 = _a[0], dxS1 = _a[1], dxS2 = _a[2], dxS3 = _a[3];
  var dyValues = backend.data.get(dy.dataId).values;
  var dyS0 = dyStrides[0], dyS1 = dyStrides[1], dyS2 = dyStrides[2], dyS3 = dyStrides[3];
  var fltValues = backend.data.get(filter.dataId).values;
  var fltS0 = filterStrides[0], fltS1 = filterStrides[1], fltS2 = filterStrides[2], fltS3 = filterStrides[3];
  var batchSize = convInfo.batchSize, filterDepth = convInfo.filterDepth, filterHeight = convInfo.filterHeight, filterWidth = convInfo.filterWidth, inChannels = convInfo.inChannels, inDepth = convInfo.inDepth, inHeight = convInfo.inHeight, inWidth = convInfo.inWidth, outChannels = convInfo.outChannels, outDepth = convInfo.outDepth, outHeight = convInfo.outHeight, outWidth = convInfo.outWidth, strideDepth = convInfo.strideDepth, strideHeight = convInfo.strideHeight, strideWidth = convInfo.strideWidth;
  // Padding of the equivalent forward correlation over dy.
  var frontPad = filterDepth - 1 - convInfo.padInfo.front;
  var topPad = filterHeight - 1 - convInfo.padInfo.top;
  var leftPad = filterWidth - 1 - convInfo.padInfo.left;
  for (var b = 0; b < batchSize; ++b) {
    for (var d1 = 0; d1 < inChannels; ++d1) {
      for (var xF = 0; xF < inDepth; ++xF) {
        var xFCorner = xF - frontPad;
        // xFMin / yFMax bound the dy depths (yF) contributing to this dx depth;
        // the upper bound may be fractional (strict `<` below).
        var xFMin = Math.max(0, Math.ceil(xFCorner / strideDepth));
        var yFMax = Math.min(outDepth, (filterDepth + xFCorner) / strideDepth);
        for (var xR = 0; xR < inHeight; ++xR) {
          var xRCorner = xR - topPad;
          var xRMin = Math.max(0, Math.ceil(xRCorner / strideHeight));
          var yRMax = Math.min(outHeight, (filterHeight + xRCorner) / strideHeight);
          for (var xC = 0; xC < inWidth; ++xC) {
            var xCCorner = xC - leftPad;
            var xCMin = Math.max(0, Math.ceil(xCCorner / strideWidth));
            var yCMax = Math.min(outWidth, (filterWidth + xCCorner) / strideWidth);
            var dotProd = 0;
            for (var yF = xFMin; yF < yFMax; ++yF) {
              var wF = yF * strideDepth - xFCorner;
              for (var yR = xRMin; yR < yRMax; ++yR) {
                var wR = yR * strideHeight - xRCorner;
                for (var yC = xCMin; yC < yCMax; ++yC) {
                  var wC = yC * strideWidth - xCCorner;
                  var dyOffset = dyS0 * b + dyS1 * yF + dyS2 * yR + dyS3 * yC;
                  // Flipped filter tap for input channel d1.
                  var fltOffset = fltS0 * (filterDepth - 1 - wF) + fltS1 * (filterHeight - 1 - wR) + fltS2 * (filterWidth - 1 - wC) + fltS3 * d1;
                  for (var d2 = 0; d2 < outChannels; ++d2) {
                    var pixel = dyValues[dyOffset + d2];
                    var weight = fltValues[fltOffset + d2];
                    dotProd += pixel * weight;
                  }
                }
              }
            }
            dxValues[dxS0 * b + dxS1 * xF + dxS2 * xR + dxS3 * xC + d1] = dotProd;
          }
        }
      }
    }
  }
  return backend.makeTensorInfo(dx.shape, dx.dtype, dx.values);
}
// Kernel config binding `conv3DBackpropInputV2` to the Conv3DBackpropInputV2
// kernel name on the "cpu" backend.
var conv3DBackpropInputV2Config = {
kernelName: tf.Conv3DBackpropInputV2,
backendName: "cpu",
kernelFunc: conv3DBackpropInputV2
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
// Element-wise cosine (Math.cos) kernel, bound to the "cpu" backend below.
var cos = unaryKernelFunc(tf.Cos, (xi) => Math.cos(xi));
var cosConfig = {
  kernelName: tf.Cos,
  backendName: "cpu",
  kernelFunc: cos
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the License);
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an AS IS BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
// Element-wise hyperbolic cosine (Math.cosh) kernel, bound to the "cpu"
// backend below.
var cosh = unaryKernelFunc(tf.Cosh, (xi) => Math.cosh(xi));
var coshConfig = {
  kernelName: tf.Cosh,
  backendName: "cpu",
  kernelFunc: cosh
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * CPU kernel for DepthwiseConv2dNative.
 *
 * Each input channel d1 is convolved with its own `chMul` filters. Those
 * filters write output channels d1 * chMul .. d1 * chMul + chMul - 1.
 *
 * The index arithmetic assumes channelsLast layout:
 * - x is [batch, inHeight, inWidth, inChannels].
 * - The filter is [filterHeight, filterWidth, inChannels, chMul].
 * - The output is [batch, outHeight, outWidth, outChannels].
 *
 * Input taps outside the input are skipped (zero padding). Row offsets use the
 * top padding and column offsets the left padding, as in conv2D, so asymmetric
 * padding (top !== left) is applied to the correct axis.
 *
 * @param args {inputs: {x, filter}, backend, attrs: {strides, pad, dilations, dimRoundingMode}}
 * @returns TensorInfo with shape convInfo.outShape and x's dtype.
 */
function depthwiseConv2dNative(args) {
  const { inputs, backend, attrs } = args;
  const { x, filter } = inputs;
  const { strides, pad, dilations, dimRoundingMode } = attrs;
  assertNotComplex([x, filter], "depthwiseConv2DNative");
  const xStrides = tf.util.computeStrides(x.shape);
  const filterStrides = tf.util.computeStrides(filter.shape);
  const $dilations = dilations == null ? [1, 1] : dilations;
  tf.util.assert(tf.backend_util.eitherStridesOrDilationsAreOne(strides, $dilations), function() {
    return "Error in depthwiseConv2d: Either strides or dilations must be " + ("1. Got strides " + strides + " and dilations '" + $dilations + "'");
  });
  const convInfo = tf.backend_util.computeConv2DInfo(x.shape, filter.shape, strides, $dilations, pad, dimRoundingMode, true);
  const { filterHeight, filterWidth, dilationHeight, dilationWidth, padInfo } = convInfo;
  const padLeft = padInfo.left;
  const padTop = padInfo.top;
  const chMul = convInfo.outChannels / convInfo.inChannels;
  // Output accumulator: the loops add into it with +=, so this assumes a fresh
  // TensorBuffer starts zero-filled.
  const y = new tf.TensorBuffer(convInfo.outShape, x.dtype);
  const xVals = backend.data.get(x.dataId).values;
  const wVals = backend.data.get(filter.dataId).values;
  const yVals = y.values;
  for (let b = 0; b < convInfo.batchSize; ++b) {
    const xOffset1 = b * xStrides[0];
    const yOffset1 = b * y.strides[0];
    for (let yR = 0; yR < convInfo.outHeight; ++yR) {
      const yOffset2 = yOffset1 + yR * y.strides[1];
      // Rows are shifted by the TOP padding.
      const xRCorner = yR * convInfo.strideHeight - padTop;
      for (let wR = 0; wR < filterHeight; ++wR) {
        const xR = xRCorner + wR * dilationHeight;
        if (xR < 0 || xR >= convInfo.inHeight) {
          continue;
        }
        const wOffset1 = wR * filterStrides[0];
        const xOffset2 = xOffset1 + xR * xStrides[1];
        for (let yC = 0; yC < convInfo.outWidth; ++yC) {
          const yOffset3 = yOffset2 + yC * y.strides[2];
          // Columns are shifted by the LEFT padding.
          const xCCorner = yC * convInfo.strideWidth - padLeft;
          for (let wC = 0; wC < filterWidth; ++wC) {
            const xC = xCCorner + wC * dilationWidth;
            if (xC < 0 || xC >= convInfo.inWidth) {
              continue;
            }
            const wOffset2 = wOffset1 + wC * filterStrides[1];
            const xOffset3 = xOffset2 + xC * convInfo.inChannels;
            let yOffset4 = yOffset3;
            let wOffset3 = wOffset2;
            for (let d1 = 0; d1 < convInfo.inChannels; ++d1) {
              const xVal = xVals[xOffset3 + d1];
              for (let q = 0; q < chMul; ++q) {
                yVals[yOffset4 + q] += xVal * wVals[wOffset3 + q];
              }
              // Next input channel owns the next chMul outputs / weights.
              yOffset4 += chMul;
              wOffset3 += chMul;
            }
          }
        }
      }
    }
  }
  return backend.makeTensorInfo(y.shape, y.dtype, y.values);
}
// Kernel config binding `depthwiseConv2dNative` to the DepthwiseConv2dNative
// kernel name on the "cpu" backend.
var depthwiseConv2dNativeConfig = {
kernelName: tf.DepthwiseConv2dNative,
backendName: "cpu",
kernelFunc: depthwiseConv2dNative
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * CPU kernel for DepthwiseConv2dNativeBackpropFilter: the gradient of a
 * depthwise conv with respect to its filter.
 *
 * Output channel d2 belongs to input channel d1 = trunc(d2 / chMul) with
 * multiplier index dm = d2 % chMul. The result is written to dW[wR, wC, d1, dm].
 * x and dy are read as [batch, row, col, channel] (channelsLast).
 *
 * NOTE(review): `dilations` is passed to computeConv2DInfo, but the loops
 * below do not apply any dilation factor. Presumably callers guarantee
 * dilations of 1 for gradients; confirm at the call site.
 *
 * @param args {inputs: {x, dy}, backend, attrs: {strides, dilations, pad, dimRoundingMode, filterShape}}
 * @returns float32 TensorInfo of shape convInfo.filterShape.
 */
function depthwiseConv2dNativeBackpropFilter(args) {
var inputs = args.inputs, backend = args.backend, attrs = args.attrs;
var x = inputs.x, dy = inputs.dy;
var strides = attrs.strides, dilations = attrs.dilations, pad = attrs.pad, dimRoundingMode = attrs.dimRoundingMode, filterShape = attrs.filterShape;
assertNotComplex([x, dy], "depthwiseConv2dNativeBackpropFilter");
var convInfo = tf.backend_util.computeConv2DInfo(x.shape, filterShape, strides, dilations, pad, dimRoundingMode, true);
var strideHeight = convInfo.strideHeight, strideWidth = convInfo.strideWidth, filterHeight = convInfo.filterHeight, filterWidth = convInfo.filterWidth;
var dW = new tf.TensorBuffer(convInfo.filterShape, "float32");
var leftPad = convInfo.padInfo.left;
var topPad = convInfo.padInfo.top;
var chMul = convInfo.outChannels / convInfo.inChannels;
var xVals = backend.data.get(x.dataId).values;
var xBuf = new tf.TensorBuffer(x.shape, x.dtype, xVals);
var dyVals = backend.data.get(dy.dataId).values;
var dyBuf = new tf.TensorBuffer(dy.shape, dy.dtype, dyVals);
for (var wR = 0; wR < filterHeight; ++wR) {
// Output rows whose input row xR = wR + yR * strideHeight - topPad lies in
// [0, inHeight); the upper bound may be fractional (strict `<` below).
var yRMin = Math.max(0, Math.ceil((topPad - wR) / strideHeight));
var yRMax = Math.min(convInfo.outHeight, (convInfo.inHeight + topPad - wR) / strideHeight);
for (var wC = 0; wC < filterWidth; ++wC) {
var yCMin = Math.max(0, Math.ceil((leftPad - wC) / strideWidth));
var yCMax = Math.min(convInfo.outWidth, (convInfo.inWidth + leftPad - wC) / strideWidth);
for (var d2 = 0; d2 < convInfo.outChannels; ++d2) {
// Input channel feeding output channel d2, and its multiplier slot.
var d1 = Math.trunc(d2 / chMul);
var dm = d2 % chMul;
var dotProd = 0;
for (var b = 0; b < convInfo.batchSize; ++b) {
for (var yR = yRMin; yR < yRMax; ++yR) {
var xR = wR + yR * strideHeight - topPad;
for (var yC = yCMin; yC < yCMax; ++yC) {
var xC = wC + yC * strideWidth - leftPad;
dotProd += xBuf.get(b, xR, xC, d1) * dyBuf.get(b, yR, yC, d2);
}
}
}
dW.set(dotProd, wR, wC, d1, dm);
}
}
}
return backend.makeTensorInfo(dW.shape, dW.dtype, dW.values);
}
// Kernel config binding `depthwiseConv2dNativeBackpropFilter` to the
// DepthwiseConv2dNativeBackpropFilter kernel name on the "cpu" backend.
var depthwiseConv2dNativeBackpropFilterConfig = {
kernelName: tf.DepthwiseConv2dNativeBackpropFilter,
backendName: "cpu",
kernelFunc: depthwiseConv2dNativeBackpropFilter
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * CPU kernel: gradient of depthwiseConv2dNative with respect to its input.
 *
 * For every input position (b, xR, xC, d1) it sums dy[b, yR, yC, d2] times the
 * spatially flipped filter tap, over the output positions (yR, yC) that the
 * input pixel contributes to and over the chMul output channels
 * d2 = d1 * chMul + dm derived from input channel d1.
 *
 * @param {object} args Kernel arguments: `inputs` {dy, filter}, `backend` (the
 *     CPU backend whose `data` store holds the values), `attrs` {strides,
 *     dilations, pad, dimRoundingMode, inputShape}.
 * @returns TensorInfo with shape convInfo.inShape and dtype "float32".
 */
function depthwiseConv2dNativeBackpropInput(args) {
var inputs = args.inputs, backend = args.backend, attrs = args.attrs;
var dy = inputs.dy, filter = inputs.filter;
var strides = attrs.strides, dilations = attrs.dilations, pad = attrs.pad, dimRoundingMode = attrs.dimRoundingMode, inputShape = attrs.inputShape;
assertNotComplex([dy, filter], "depthwiseConv2DNativeBackpropInput");
// Flat-offset strides. The inner loop adds d2 / dm directly, which assumes the
// last-axis stride is 1 (contiguous row-major values).
var dyStrides = tf.util.computeStrides(dy.shape);
var filterStrides = tf.util.computeStrides(filter.shape);
// NOTE(review): dilations are forwarded here, but the index math below only
// uses strides; presumably callers guarantee a dilation of 1 — TODO confirm.
var convInfo = tf.backend_util.computeConv2DInfo(inputShape, filter.shape, strides, dilations, pad, dimRoundingMode, true);
var dx = new tf.TensorBuffer(convInfo.inShape, "float32");
var dxValues = dx.values;
var _a = dx.strides, dxS0 = _a[0], dxS1 = _a[1], dxS2 = _a[2];
var dyValues = backend.data.get(dy.dataId).values;
var dyS0 = dyStrides[0], dyS1 = dyStrides[1], dyS2 = dyStrides[2];
var fltValues = backend.data.get(filter.dataId).values;
var fltS0 = filterStrides[0], fltS1 = filterStrides[1], fltS2 = filterStrides[2];
var batchSize = convInfo.batchSize, filterHeight = convInfo.filterHeight, filterWidth = convInfo.filterWidth, inChannels = convInfo.inChannels, inHeight = convInfo.inHeight, inWidth = convInfo.inWidth, outChannels = convInfo.outChannels, outHeight = convInfo.outHeight, outWidth = convInfo.outWidth, strideHeight = convInfo.strideHeight, strideWidth = convInfo.strideWidth;
// The filter is read flipped below (filterHeight - 1 - wR, filterWidth - 1 - wC),
// so the forward pads are replaced by their complements here.
var topPad = filterHeight - 1 - convInfo.padInfo.top;
var leftPad = filterWidth - 1 - convInfo.padInfo.left;
var chMul = outChannels / inChannels;
for (var b = 0; b < batchSize; ++b) {
for (var d1 = 0; d1 < inChannels; ++d1) {
for (var xR = 0; xR < inHeight; ++xR) {
var xRCorner = xR - topPad;
// Despite its name, xRMin is the lower bound of the output-row loop (yR).
var xRMin = Math.max(0, Math.ceil(xRCorner / strideHeight));
var yRMax = Math.min(outHeight, (filterHeight + xRCorner) / strideHeight);
for (var xC = 0; xC < inWidth; ++xC) {
var xCCorner = xC - leftPad;
// Likewise xCMin bounds the output-column loop (yC).
var xCMin = Math.max(0, Math.ceil(xCCorner / strideWidth));
var yCMax = Math.min(outWidth, (filterWidth + xCCorner) / strideWidth);
var dotProd = 0;
for (var yR = xRMin; yR < yRMax; ++yR) {
var wR = yR * strideHeight - xRCorner;
for (var yC = xCMin; yC < yCMax; ++yC) {
var wC = yC * strideWidth - xCCorner;
var dyOffset = dyS0 * b + dyS1 * yR + dyS2 * yC;
var fltOffset = fltS0 * (filterHeight - 1 - wR) + fltS1 * (filterWidth - 1 - wC) + fltS2 * d1;
// Sum over the chMul output channels generated from input channel d1.
for (var dm = 0; dm < chMul; ++dm) {
var d2 = d1 * chMul + dm;
var pixel = dyValues[dyOffset + d2];
var weight = fltValues[fltOffset + dm];
dotProd += pixel * weight;
}
}
}
dxValues[dxS0 * b + dxS1 * xR + dxS2 * xC + d1] = dotProd;
}
}
}
}
return backend.makeTensorInfo(dx.shape, dx.dtype, dx.values);
}
// Kernel config pairing tf.DepthwiseConv2dNativeBackpropInput with the
// function above for the "cpu" backend.
var depthwiseConv2dNativeBackpropInputConfig = {
kernelName: tf.DepthwiseConv2dNativeBackpropInput,
backendName: "cpu",
kernelFunc: depthwiseConv2dNativeBackpropInput
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var dilation2dConfig = {
  kernelName: tf.Dilation2D,
  backendName: "cpu",
  /**
   * CPU kernel for tf.Dilation2D (NHWC). For each output position and channel
   * it takes the max over the filter window of x[b, hIn, wIn, d] + filter[h, w, d],
   * skipping window taps that fall outside the input.
   *
   * NOTE(review): the running max starts at Number.MIN_SAFE_INTEGER (kept for
   * compatibility). So a window with no in-bounds tap, or one whose sums are all
   * below that value, produces Number.MIN_SAFE_INTEGER.
   *
   * @param {object} _a Kernel arguments: `inputs` {x, filter}, `backend` (CPU
   *     backend), `attrs` {strides, pad, dilations}.
   * @returns {{dataId: object, shape: number[], dtype: string}} the output info.
   */
  kernelFunc: function(_a) {
    const inputs = _a.inputs, backend = _a.backend, attrs = _a.attrs;
    const {x, filter} = inputs;
    const {strides, pad, dilations} = attrs;
    const cpuBackend = backend;
    const xVals = cpuBackend.data.get(x.dataId).values;
    const xRank = x.shape.length;
    const filterVals = cpuBackend.data.get(filter.dataId).values;
    const filterRank = filter.shape.length;
    const {
      batchSize, inHeight, inWidth, inChannels, outHeight, outWidth, padInfo,
      strideHeight, strideWidth, filterHeight, filterWidth, dilationHeight,
      dilationWidth, outShape
    } = tf.backend_util.computeDilation2DInfo(x.shape, filter.shape, strides, pad, "NHWC", dilations);
    const outSize = tf.util.sizeFromShape(outShape);
    const outRank = outShape.length;
    const outputVals = tf.util.getArrayFromDType(x.dtype, outSize);
    // Strides depend only on the shapes. Computing them once here, instead of
    // on every inner-loop iteration, removes redundant array allocations.
    const xStrides = tf.util.computeStrides(x.shape);
    const filterStrides = tf.util.computeStrides(filter.shape);
    const outStrides = tf.util.computeStrides(outShape);
    for (let b = 0; b < batchSize; ++b) {
      for (let hOut = 0; hOut < outHeight; ++hOut) {
        const hBeg = hOut * strideHeight - padInfo.top;
        for (let wOut = 0; wOut < outWidth; ++wOut) {
          const wBeg = wOut * strideWidth - padInfo.left;
          for (let d = 0; d < inChannels; ++d) {
            let curVal = Number.MIN_SAFE_INTEGER;
            for (let h = 0; h < filterHeight; ++h) {
              const hIn = hBeg + h * dilationHeight;
              if (hIn >= 0 && hIn < inHeight) {
                for (let w = 0; w < filterWidth; ++w) {
                  const wIn = wBeg + w * dilationWidth;
                  if (wIn >= 0 && wIn < inWidth) {
                    const xIndex = tf.util.locToIndex([b, hIn, wIn, d], xRank, xStrides);
                    const filterIndex = tf.util.locToIndex([h, w, d], filterRank, filterStrides);
                    const val = xVals[xIndex] + filterVals[filterIndex];
                    if (val > curVal) {
                      curVal = val;
                    }
                  }
                }
              }
            }
            const outputIndex = tf.util.locToIndex([b, hOut, wOut, d], outRank, outStrides);
            outputVals[outputIndex] = curVal;
          }
        }
      }
    }
    const dataId = cpuBackend.write(tf.util.toTypedArray(outputVals, x.dtype), outShape, x.dtype);
    return {dataId, shape: outShape, dtype: x.dtype};
  }
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var dilation2dBackpropFilterConfig = {
  kernelName: tf.Dilation2DBackpropFilter,
  backendName: "cpu",
  /**
   * CPU kernel: gradient of tf.Dilation2D with respect to the filter (NHWC).
   * For each output position and channel, it finds the filter tap (hMax, wMax)
   * that produced the forward max of x + filter. It then adds
   * dy[b, hOut, wOut, d] to gradients[hMax][wMax][d].
   * If no tap is in bounds, the gradient goes to tap (0, 0).
   *
   * @param {object} _a Kernel arguments: `inputs` {x, filter, dy}, `backend`
   *     (CPU backend), `attrs` {strides, pad, dilations}.
   * @returns {{dataId: object, shape: number[], dtype: string}} info with the
   *     filter's shape and dtype.
   */
  kernelFunc: function(_a) {
    const inputs = _a.inputs, backend = _a.backend, attrs = _a.attrs;
    const {x, filter, dy} = inputs;
    const {strides, pad, dilations} = attrs;
    const cpuBackend = backend;
    const $x = tf.util.toNestedArray(x.shape, cpuBackend.data.get(x.dataId).values);
    const $filter = tf.util.toNestedArray(filter.shape, cpuBackend.data.get(filter.dataId).values);
    const {
      batchSize, inHeight, inWidth, inChannels, outHeight, outWidth, padInfo,
      strideHeight, strideWidth, filterHeight, filterWidth, dilationHeight,
      dilationWidth, outShape
    } = tf.backend_util.computeDilation2DInfo(x.shape, filter.shape, strides, pad, "NHWC", dilations);
    tf.util.assert(dy.rank === outShape.length, function() {
      return "Error in " + tf.Dilation2DBackpropFilter + ", dy " + ("must have the same rank as output " + outShape.length + ", but got ") + ("" + dy.rank);
    });
    const $dy = tf.util.toNestedArray(outShape, cpuBackend.data.get(dy.dataId).values);
    const gradients = tf.util.makeZerosNestedTypedArray(filter.shape, filter.dtype);
    for (let b = 0; b < batchSize; ++b) {
      for (let hOut = 0; hOut < outHeight; ++hOut) {
        const hBeg = hOut * strideHeight - padInfo.top;
        for (let wOut = 0; wOut < outWidth; ++wOut) {
          const wBeg = wOut * strideWidth - padInfo.left;
          for (let d = 0; d < inChannels; ++d) {
            let curVal = Number.MIN_SAFE_INTEGER;
            let hMax = 0;
            let wMax = 0;
            for (let h = 0; h < filterHeight; ++h) {
              const hIn = hBeg + h * dilationHeight;
              if (hIn >= 0 && hIn < inHeight) {
                for (let w = 0; w < filterWidth; ++w) {
                  const wIn = wBeg + w * dilationWidth;
                  if (wIn >= 0 && wIn < inWidth) {
                    const val = $x[b][hIn][wIn][d] + $filter[h][w][d];
                    if (val > curVal) {
                      curVal = val;
                      hMax = h;
                      wMax = w;
                    }
                  }
                }
              }
            }
            gradients[hMax][wMax][d] += $dy[b][hOut][wOut][d];
          }
        }
      }
    }
    // The gradient buffer is allocated and written with the filter's dtype, so
    // flatten it with that same dtype. Previously x.dtype was used here, which
    // mismatches the declared dtype whenever x and filter dtypes differ.
    const dataId = cpuBackend.write(tf.util.toTypedArray(gradients, filter.dtype), filter.shape, filter.dtype);
    return {dataId, shape: filter.shape, dtype: filter.dtype};
  }
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var dilation2dBackpropInputConfig = {
  kernelName: tf.Dilation2DBackpropInput,
  backendName: "cpu",
  /**
   * CPU kernel: gradient of tf.Dilation2D with respect to its input (NHWC).
   * Each dy value is routed back to the input pixel that won the forward max of
   * x + filter within its window. When no tap is in bounds, the gradient goes
   * to the window's clamped top-left corner.
   *
   * @param {object} kernelArgs Kernel arguments: `inputs` {x, filter, dy},
   *     `backend` (CPU backend), `attrs` {strides, pad, dilations}.
   * @returns {{dataId: object, shape: number[], dtype: string}} info with x's
   *     shape and dtype.
   */
  kernelFunc: function({ inputs, backend, attrs }) {
    const { x, filter, dy } = inputs;
    const { strides, pad, dilations } = attrs;
    const store = backend.data;
    const xNested = tf.util.toNestedArray(x.shape, store.get(x.dataId).values);
    const filterNested = tf.util.toNestedArray(filter.shape, store.get(filter.dataId).values);
    const {
      batchSize, inHeight, inWidth, inChannels, outHeight, outWidth, padInfo,
      strideHeight, strideWidth, filterHeight, filterWidth, dilationHeight,
      dilationWidth, outShape
    } = tf.backend_util.computeDilation2DInfo(x.shape, filter.shape, strides, pad, "NHWC", dilations);
    tf.util.assert(
      dy.rank === outShape.length,
      () => `Error in ${tf.Dilation2DBackpropInput}, dy must have the same rank as output ${outShape.length}, but got ${dy.rank}`
    );
    const dyNested = tf.util.toNestedArray(outShape, store.get(dy.dataId).values);
    const grad = tf.util.makeZerosNestedTypedArray(x.shape, x.dtype);

    for (let b = 0; b < batchSize; ++b) {
      for (let hOut = 0; hOut < outHeight; ++hOut) {
        const rowStart = hOut * strideHeight - padInfo.top;
        for (let wOut = 0; wOut < outWidth; ++wOut) {
          const colStart = wOut * strideWidth - padInfo.left;
          for (let d = 0; d < inChannels; ++d) {
            let best = Number.MIN_SAFE_INTEGER;
            // Fallback target: the window corner clamped into the image.
            let argRow = Math.max(0, rowStart);
            let argCol = Math.max(0, colStart);
            for (let fh = 0; fh < filterHeight; ++fh) {
              const inRow = rowStart + fh * dilationHeight;
              if (inRow < 0 || inRow >= inHeight) continue;
              for (let fw = 0; fw < filterWidth; ++fw) {
                const inCol = colStart + fw * dilationWidth;
                if (inCol < 0 || inCol >= inWidth) continue;
                const candidate = xNested[b][inRow][inCol][d] + filterNested[fh][fw][d];
                if (candidate > best) {
                  best = candidate;
                  argRow = inRow;
                  argCol = inCol;
                }
              }
            }
            grad[b][argRow][argCol][d] += dyNested[b][hOut][wOut][d];
          }
        }
      }
    }

    const dataId = backend.write(tf.util.toTypedArray(grad, x.dtype), x.shape, x.dtype);
    return { dataId, shape: x.shape, dtype: x.dtype };
  }
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
// Element-wise a / b, lifted into a broadcasting binary kernel implementation.
var divImpl = createSimpleBinaryKernelImpl((a, b) => a / b);
// CPU kernel function for tf.Div built from divImpl.
var div = binaryKernelFunc(tf.Div, divImpl);
// Kernel config pairing tf.Div with `div` for the "cpu" backend.
var divConfig = {
  kernelName: tf.Div,
  backendName: "cpu",
  kernelFunc: div
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the License);
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an AS IS BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
// Coefficients of the rational approximation used below; the formula has the
// form of Abramowitz & Stegun 7.1.26:
//   erf(|x|) ~= 1 - (a1 t + a2 t^2 + a3 t^3 + a4 t^4 + a5 t^5) e^(-x^2), t = 1 / (1 + p|x|)
var p = tf.backend_util.ERF_P;
var a1 = tf.backend_util.ERF_A1;
var a2 = tf.backend_util.ERF_A2;
var a3 = tf.backend_util.ERF_A3;
var a4 = tf.backend_util.ERF_A4;
var a5 = tf.backend_util.ERF_A5;
// Element-wise erf; odd symmetry is restored by multiplying by sign(x).
var erf = unaryKernelFunc(tf.Erf, (xi) => {
  const absX = Math.abs(xi);
  const t = 1 / (1 + p * absX);
  // Horner evaluation of the degree-4 polynomial; the extra factor t follows.
  const poly = (((a5 * t + a4) * t + a3) * t + a2) * t + a1;
  const magnitude = 1 - poly * t * Math.exp(-absX * absX);
  return Math.sign(xi) * magnitude;
});
// Kernel config pairing tf.Erf with `erf` for the "cpu" backend.
var erfConfig = {
  kernelName: tf.Erf,
  backendName: "cpu",
  kernelFunc: erf
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Runs a 1-D (inverse) FFT independently on each row of a 2-D complex tensor.
 *
 * Each row is sliced out of the real and imaginary parts, recombined into a
 * [1, innerDim] complex tensor, transformed with fftImpl, and copied into
 * float32 result buffers. Per-row intermediates are disposed inside the loop.
 *
 * @param input complex TensorInfo of shape [batch, innerDim] (read from
 *     input.shape[0] and input.shape[1]).
 * @param {boolean} inverse true for the inverse transform.
 * @param cpuBackend the CPU backend holding the tensor data.
 * @returns complex TensorInfo of shape [batch, innerDim]; the caller owns it.
 */
function fftBatch(input, inverse, cpuBackend) {
var inputShape = input.shape;
var batch = inputShape[0];
var innerDim = inputShape[1];
var inputVals = cpuBackend.data.get(input.dataId);
var real2D = inputVals.complexTensorInfos.real;
var imag2D = inputVals.complexTensorInfos.imag;
var resultShape = [batch, innerDim];
var resultSize = tf.util.sizeFromShape(resultShape);
var resultReal = tf.util.getTypedArrayFromDType("float32", resultSize);
var resultImag = tf.util.getTypedArrayFromDType("float32", resultSize);
for (var b = 0; b < batch; b++) {
// Slice row b out of both the real and imaginary planes.
var r = slice({
inputs: {x: real2D},
backend: cpuBackend,
attrs: {begin: [b, 0], size: [1, innerDim]}
});
var i = slice({
inputs: {x: imag2D},
backend: cpuBackend,
attrs: {begin: [b, 0], size: [1, innerDim]}
});
var input_1 = complex({inputs: {real: r, imag: i}, backend: cpuBackend});
var _a = fftImpl(input_1, inverse, cpuBackend), real_1 = _a.real, imag_1 = _a.imag;
// Interleave and then read back element d to scatter into row b of the results.
var res = tf.backend_util.mergeRealAndImagArrays(real_1, imag_1);
for (var d = 0; d < innerDim; d++) {
var c = tf.backend_util.getComplexWithIndex(res, d);
resultReal[b * innerDim + d] = c.real;
resultImag[b * innerDim + d] = c.imag;
}
cpuBackend.disposeIntermediateTensorInfo(r);
cpuBackend.disposeIntermediateTensorInfo(i);
cpuBackend.disposeIntermediateTensorInfo(input_1);
}
var $realInfo = cpuBackend.makeTensorInfo(resultShape, "float32", resultReal);
var $imagInfo = cpuBackend.makeTensorInfo(resultShape, "float32", resultImag);
var result = complex({inputs: {real: $realInfo, imag: $imagInfo}, backend: cpuBackend});
// The component tensors are disposed once the complex result has been built.
cpuBackend.disposeIntermediateTensorInfo($realInfo);
cpuBackend.disposeIntermediateTensorInfo($imagInfo);
return result;
}
/**
 * Computes the 1-D (inverse) DFT of a complex tensor's flattened values.
 *
 * When the element count is a power of two, it uses the recursive radix-2 FFT.
 * The inverse result is then scaled by 1/size using the Div kernel, because
 * fftRadix2 itself does not scale. Otherwise it falls back to the O(n^2)
 * direct transform, which applies the 1/size scaling internally.
 *
 * @param input complex TensorInfo. fftBatch passes a [1, innerDim] tensor.
 *     The radix-2 path reads input.shape[0] and input.shape[1], so a 2-D input
 *     is assumed there.
 * @param {boolean} inverse true for the inverse transform.
 * @param cpuBackend the CPU backend holding the tensor data.
 * @returns {{real: ArrayLike<number>, imag: ArrayLike<number>}} raw value arrays.
 */
function fftImpl(input, inverse, cpuBackend) {
var inputSize = tf.util.sizeFromShape(input.shape);
var inputVals = cpuBackend.data.get(input.dataId);
var realVals = cpuBackend.data.get(inputVals.complexTensorInfos.real.dataId).values;
var imagVals = cpuBackend.data.get(inputVals.complexTensorInfos.imag.dataId).values;
if (isExponentOf2(inputSize)) {
var result = fftRadix2(realVals, imagVals, inputSize, inverse, cpuBackend);
var resultShape = [input.shape[0], input.shape[1]];
if (inverse) {
// Scale both parts by 1/inputSize via the Div kernel.
var realInfo = cpuBackend.makeTensorInfo(resultShape, "float32", result.real);
var imagInfo = cpuBackend.makeTensorInfo(resultShape, "float32", result.imag);
var sizeInfo = cpuBackend.makeTensorInfo([], "float32", tf.util.createScalarValue(inputSize, "float32"));
// A second handle to the size scalar (via identity) is the divisor for the
// imaginary part; both handles are disposed below.
var sizeInfoCopy = identity({inputs: {x: sizeInfo}, backend: cpuBackend});
var divRealInfo = divConfig.kernelFunc({inputs: {a: realInfo, b: sizeInfo}, backend: cpuBackend});
var divImagInfo = divConfig.kernelFunc({inputs: {a: imagInfo, b: sizeInfoCopy}, backend: cpuBackend});
// Values are read out before the owning tensor infos are disposed.
var divRealVals = cpuBackend.data.get(divRealInfo.dataId).values;
var divImagVals = cpuBackend.data.get(divImagInfo.dataId).values;
cpuBackend.disposeIntermediateTensorInfo(realInfo);
cpuBackend.disposeIntermediateTensorInfo(imagInfo);
cpuBackend.disposeIntermediateTensorInfo(sizeInfo);
cpuBackend.disposeIntermediateTensorInfo(sizeInfoCopy);
cpuBackend.disposeIntermediateTensorInfo(divRealInfo);
cpuBackend.disposeIntermediateTensorInfo(divImagInfo);
return {real: divRealVals, imag: divImagVals};
}
return result;
} else {
// Non power-of-two length: direct O(n^2) transform on interleaved data.
var data = tf.backend_util.mergeRealAndImagArrays(realVals, imagVals);
var rawOutput = fourierTransformByMatmul(data, inputSize, inverse);
return tf.backend_util.splitRealAndImagArrays(rawOutput);
}
}
/**
 * Returns true when `size` is a positive power of two (1, 2, 4, 8, ...).
 *
 * fftImpl uses this to pick the radix-2 FFT path. Zero must be rejected:
 * `0 & -1` is 0, and fftRadix2 only stops recursing when it halves down to
 * exactly 1, so size 0 would recurse without bound.
 *
 * @param {number} size element count.
 * @returns {boolean} whether size is a power of two.
 */
function isExponentOf2(size) {
  return size > 0 && (size & (size - 1)) === 0;
}
/**
 * Recursive radix-2 (Cooley-Tukey style) FFT over separate real/imag arrays.
 *
 * The input is split into even- and odd-indexed elements, each half is
 * transformed recursively, and the halves are combined as
 * [even + e * odd, even - e * odd], where e comes from
 * tf.backend_util.exponents(size, inverse). No 1/size scaling is applied
 * here; fftImpl does that for the inverse transform.
 *
 * @param realVals real parts, length `size`.
 * @param imagVals imaginary parts, length `size`.
 * @param {number} size element count; callers pass a power of two (recursion
 *     stops only at size === 1).
 * @param {boolean} inverse selects the sign of the twiddle exponents.
 * @param cpuBackend backend used for the complex arithmetic kernels.
 * @returns {{real: ArrayLike<number>, imag: ArrayLike<number>}} transformed values.
 */
function fftRadix2(realVals, imagVals, size, inverse, cpuBackend) {
if (size === 1) {
return {real: realVals, imag: imagVals};
}
var data = tf.backend_util.mergeRealAndImagArrays(realVals, imagVals);
var half = size / 2;
var evenComplex = tf.backend_util.complexWithEvenIndex(data);
var evenRealVals = evenComplex.real;
var evenImagVals = evenComplex.imag;
var evenShape = [evenRealVals.length];
var evenRealInfo = cpuBackend.makeTensorInfo(evenShape, "float32", evenRealVals);
var evenImagInfo = cpuBackend.makeTensorInfo(evenShape, "float32", evenImagVals);
// NOTE(review): evenTensorInfo (and oddTensorInfo below) is only created and
// later disposed; it is not read anywhere in this function.
var evenTensorInfo = complex({inputs: {real: evenRealInfo, imag: evenImagInfo}, backend: cpuBackend});
var oddComplex = tf.backend_util.complexWithOddIndex(data);
var oddRealVals = oddComplex.real;
var oddImagVals = oddComplex.imag;
var oddShape = [oddRealVals.length];
var oddRealInfo = cpuBackend.makeTensorInfo(oddShape, "float32", oddRealVals);
var oddImagInfo = cpuBackend.makeTensorInfo(oddShape, "float32", oddImagVals);
var oddTensorInfo = complex({inputs: {real: oddRealInfo, imag: oddImagInfo}, backend: cpuBackend});
// Recursively transform the even-indexed half.
var $evenComplex = fftRadix2(evenRealVals, evenImagVals, half, inverse, cpuBackend);
var $evenRealVals = $evenComplex.real;
var $evenImagVals = $evenComplex.imag;
var $evenShape = [$evenRealVals.length];
var $evenRealInfo = cpuBackend.makeTensorInfo($evenShape, "float32", $evenRealVals);
var $evenImagInfo = cpuBackend.makeTensorInfo($evenShape, "float32", $evenImagVals);
var $evenTensorInfo = complex({
inputs: {real: $evenRealInfo, imag: $evenImagInfo},
backend: cpuBackend
});
// Recursively transform the odd-indexed half.
var $oddComplex = fftRadix2(oddRealVals, oddImagVals, half, inverse, cpuBackend);
var $oddRealVals = $oddComplex.real;
var $oddImagVals = $oddComplex.imag;
var $oddShape = [$oddRealVals.length];
var $oddRealInfo = cpuBackend.makeTensorInfo($oddShape, "float32", $oddRealVals);
var $oddImagInfo = cpuBackend.makeTensorInfo($oddShape, "float32", $oddImagVals);
var $oddTensorInfo = complex({inputs: {real: $oddRealInfo, imag: $oddImagInfo}, backend: cpuBackend});
// Twiddle factors e, multiplied element-wise into the transformed odd half.
var e = tf.backend_util.exponents(size, inverse);
var eShape = [e.real.length];
var eRealInfo = cpuBackend.makeTensorInfo(eShape, "float32", e.real);
var eImagInfo = cpuBackend.makeTensorInfo(eShape, "float32", e.imag);
var complexInfo = complex({inputs: {real: eRealInfo, imag: eImagInfo}, backend: cpuBackend});
var exponentInfo = multiply({inputs: {a: complexInfo, b: $oddTensorInfo}, backend: cpuBackend});
// Butterfly: first half is even + e*odd, second half is even - e*odd.
var addPart = add({
inputs: {a: $evenTensorInfo, b: exponentInfo},
backend: cpuBackend
});
var subPart = sub({
inputs: {a: $evenTensorInfo, b: exponentInfo},
backend: cpuBackend
});
var addPartReal = real({inputs: {input: addPart}, backend: cpuBackend});
var subPartReal = real({inputs: {input: subPart}, backend: cpuBackend});
var addPartImag = imag({inputs: {input: addPart}, backend: cpuBackend});
var subPartImag = imag({inputs: {input: subPart}, backend: cpuBackend});
var $real = concat({
inputs: [addPartReal, subPartReal],
backend: cpuBackend,
attrs: {axis: 0}
});
var $imag = concat({
inputs: [addPartImag, subPartImag],
backend: cpuBackend,
attrs: {axis: 0}
});
// Raw values are read out before the owning tensor infos are disposed below.
var $realVals = cpuBackend.data.get($real.dataId).values;
var $imagVals = cpuBackend.data.get($imag.dataId).values;
cpuBackend.disposeIntermediateTensorInfo(evenRealInfo);
cpuBackend.disposeIntermediateTensorInfo(evenImagInfo);
cpuBackend.disposeIntermediateTensorInfo(evenTensorInfo);
cpuBackend.disposeIntermediateTensorInfo(oddRealInfo);
cpuBackend.disposeIntermediateTensorInfo(oddImagInfo);
cpuBackend.disposeIntermediateTensorInfo(oddTensorInfo);
cpuBackend.disposeIntermediateTensorInfo($evenRealInfo);
cpuBackend.disposeIntermediateTensorInfo($evenImagInfo);
cpuBackend.disposeIntermediateTensorInfo($evenTensorInfo);
cpuBackend.disposeIntermediateTensorInfo($oddRealInfo);
cpuBackend.disposeIntermediateTensorInfo($oddImagInfo);
cpuBackend.disposeIntermediateTensorInfo($oddTensorInfo);
cpuBackend.disposeIntermediateTensorInfo(eRealInfo);
cpuBackend.disposeIntermediateTensorInfo(eImagInfo);
cpuBackend.disposeIntermediateTensorInfo(complexInfo);
cpuBackend.disposeIntermediateTensorInfo(exponentInfo);
cpuBackend.disposeIntermediateTensorInfo(addPart);
cpuBackend.disposeIntermediateTensorInfo(subPart);
cpuBackend.disposeIntermediateTensorInfo(addPartReal);
cpuBackend.disposeIntermediateTensorInfo(addPartImag);
cpuBackend.disposeIntermediateTensorInfo(subPartReal);
cpuBackend.disposeIntermediateTensorInfo(subPartImag);
cpuBackend.disposeIntermediateTensorInfo($real);
cpuBackend.disposeIntermediateTensorInfo($imag);
return {real: $realVals, imag: $imagVals};
}
/**
 * Direct O(n^2) discrete Fourier transform on interleaved complex data.
 *
 * Output element k is sum_j data[j] * exponent(k * j, size, inverse). For the
 * inverse transform, each output is divided by `size`.
 *
 * @param data interleaved [re0, im0, re1, im1, ...] values of length 2 * size.
 * @param {number} size number of complex elements.
 * @param {boolean} inverse true for the inverse transform.
 * @returns {Float32Array} interleaved output of length 2 * size.
 */
function fourierTransformByMatmul(data, size, inverse) {
  const out = new Float32Array(size * 2);
  for (let row = 0; row < size; row++) {
    let accRe = 0;
    let accIm = 0;
    for (let col = 0; col < size; col++) {
      const w = tf.backend_util.exponent(row * col, size, inverse);
      const { real: xr, imag: xi } = tf.backend_util.getComplexWithIndex(data, col);
      // Complex multiply-accumulate: (xr + i*xi) * (w.real + i*w.imag).
      accRe += xr * w.real - xi * w.imag;
      accIm += xr * w.imag + xi * w.real;
    }
    if (inverse) {
      accRe /= size;
      accIm /= size;
    }
    tf.backend_util.assignToTypedArray(out, accRe, accIm, row);
  }
  return out;
}
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * CPU kernel for tf.FFT: a 1-D FFT over the innermost axis of a complex tensor.
 *
 * The input is flattened to [batch, innerDim], transformed row by row with
 * fftBatch, and reshaped back to the original shape. Temporary tensors are
 * disposed before returning.
 *
 * @param {object} args Kernel arguments: `inputs` {input}, `backend`.
 * @returns complex TensorInfo with the same shape as the input.
 */
function fft(args) {
  const { inputs: { input }, backend } = args;
  const { shape } = input;
  const innerDim = shape[shape.length - 1];
  const batch = tf.util.sizeFromShape(shape) / innerDim;
  const flattened = reshape({
    inputs: { x: input },
    backend,
    attrs: { shape: [batch, innerDim] }
  });
  const transformed = fftBatch(flattened, false, backend);
  const output = reshape({ inputs: { x: transformed }, backend, attrs: { shape } });
  backend.disposeIntermediateTensorInfo(flattened);
  backend.disposeIntermediateTensorInfo(transformed);
  return output;
}
// Kernel config pairing tf.FFT with `fft` for the "cpu" backend.
var fftConfig = {
  kernelName: tf.FFT,
  backendName: "cpu",
  kernelFunc: fft
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * CPU kernel for tf.Fill: builds a tensor of `shape` with every element set to
 * `value`. When no dtype is given, it is inferred from the value.
 *
 * @param {object} args Kernel arguments: `backend`, `attrs` {shape, value, dtype}.
 * @returns TensorInfo with the requested shape and resolved dtype.
 */
function fill({ backend, attrs }) {
  const { shape, value, dtype } = attrs;
  const resolvedDtype = dtype || tf.util.inferDtype(value);
  const elementCount = tf.util.sizeFromShape(shape);
  const values = tf.util.getArrayFromDType(resolvedDtype, elementCount);
  fillValues(values, value, resolvedDtype);
  return backend.makeTensorInfo(shape, resolvedDtype, values);
}
// Kernel config pairing tf.Fill with `fill` for the "cpu" backend.
var fillConfig = {
  kernelName: tf.Fill,
  backendName: "cpu",
  kernelFunc: fill
};
/**
 * Sets every element of `values` to `value`, in place.
 *
 * Plain arrays (string dtype) and typed arrays (numeric dtypes) both expose
 * `.fill`, so one call covers every dtype. The previous if/else had two
 * identical branches. `dtype` is kept so existing call sites stay valid.
 *
 * @param {Array|TypedArray} values destination buffer.
 * @param {*} value fill value.
 * @param {string} dtype element dtype (unused; kept for signature compatibility).
 */
function fillValues(values, value, dtype) {
  values.fill(value);
}
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var flipLeftRightConfig = {
  kernelName: tf.FlipLeftRight,
  backendName: "cpu",
  /**
   * CPU kernel for tf.image.flipLeftRight: mirrors each image of a
   * [batch, height, width, channels] tensor horizontally, so column `col`
   * receives the pixel from column `imageWidth - 1 - col`.
   *
   * Bug fixed: the mirror column used to be computed as `imageWidth - col`.
   * That is one past the end for col 0 (the first column was left unflipped)
   * and off by one for every other column.
   *
   * @param {object} _a Kernel arguments: `inputs` {image}, `backend` (CPU backend).
   * @returns {{dataId: object, shape: number[], dtype: string}} info with the
   *     image's shape and dtype.
   */
  kernelFunc: function(_a) {
    const image = _a.inputs.image;
    const cpuBackend = _a.backend;
    const output = tf.util.getTypedArrayFromDType(image.dtype, tf.util.sizeFromShape(image.shape));
    const [batch, imageHeight, imageWidth, numChannels] = image.shape;
    const imageVals = cpuBackend.data.get(image.dataId).values;
    for (let batchIdx = 0; batchIdx < batch; batchIdx++) {
      const batchOffset = batchIdx * imageWidth * imageHeight * numChannels;
      for (let row = 0; row < imageHeight; row++) {
        const rowOffset = batchOffset + row * (imageWidth * numChannels);
        for (let col = 0; col < imageWidth; col++) {
          const colOffset = col * numChannels;
          const mirroredColOffset = (imageWidth - 1 - col) * numChannels;
          for (let channel = 0; channel < numChannels; channel++) {
            output[rowOffset + colOffset + channel] = imageVals[rowOffset + mirroredColOffset + channel];
          }
        }
      }
    }
    const dataId = cpuBackend.write(output, image.shape, image.dtype);
    return {dataId, shape: image.shape, dtype: image.dtype};
  }
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * CPU kernel for tf.FusedConv2D: conv2D, then an optional bias add, then an
 * optional activation. Each superseded intermediate is disposed as soon as
 * its replacement exists.
 *
 * @param {object} args Kernel arguments: `inputs` {x, filter, bias,
 *     preluActivationWeights}, `backend`, `attrs` {strides, pad, dataFormat,
 *     dilations, dimRoundingMode, activation}.
 * @returns TensorInfo of the final stage's output.
 */
function fusedConv2D(args) {
  const { inputs, backend, attrs } = args;
  const { x, filter, bias, preluActivationWeights } = inputs;
  const { strides, pad, dataFormat, dilations, dimRoundingMode, activation } = attrs;
  let out = conv2D({
    inputs: { x, filter },
    backend,
    attrs: { strides, pad, dataFormat, dilations, dimRoundingMode }
  });
  if (bias) {
    const withBias = add({ inputs: { a: out, b: bias }, backend });
    backend.disposeIntermediateTensorInfo(out);
    out = withBias;
  }
  if (activation) {
    const activated = applyActivation(backend, out, activation, preluActivationWeights);
    backend.disposeIntermediateTensorInfo(out);
    out = activated;
  }
  return out;
}
// Kernel config pairing tf.FusedConv2D with `fusedConv2D` for the "cpu" backend.
var fusedConv2DConfig = {
  kernelName: tf.FusedConv2D,
  backendName: "cpu",
  kernelFunc: fusedConv2D
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * CPU kernel for tf.FusedDepthwiseConv2D: depthwiseConv2dNative, then an
 * optional bias add, then an optional activation. Each superseded
 * intermediate is disposed once the next stage has produced its output.
 *
 * @param {object} args Kernel arguments: `inputs` {x, filter, bias,
 *     preluActivationWeights}, `backend`, `attrs` {strides, pad, dataFormat,
 *     dilations, dimRoundingMode, activation}.
 * @returns TensorInfo of the final stage's output.
 */
function fusedDepthwiseConv2D(args) {
  const backend = args.backend;
  const { x, filter, bias, preluActivationWeights } = args.inputs;
  const { strides, pad, dataFormat, dilations, dimRoundingMode, activation } = args.attrs;
  let current = depthwiseConv2dNative({
    inputs: { x, filter },
    backend,
    attrs: { strides, pad, dataFormat, dilations, dimRoundingMode }
  });
  // Swap in the next stage's output and release the one it replaces.
  const advance = (next) => {
    backend.disposeIntermediateTensorInfo(current);
    current = next;
  };
  if (bias) {
    advance(add({ inputs: { a: current, b: bias }, backend }));
  }
  if (activation) {
    advance(applyActivation(backend, current, activation, preluActivationWeights));
  }
  return current;
}
// Kernel config pairing tf.FusedDepthwiseConv2D with `fusedDepthwiseConv2D`
// for the "cpu" backend.
var fusedDepthwiseConv2DConfig = {
  kernelName: tf.FusedDepthwiseConv2D,
  backendName: "cpu",
  kernelFunc: fusedDepthwiseConv2D
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * CPU kernel for IFFT over the innermost dimension of `input`.
 *
 * The input is viewed as a [batch, innerDimensionSize] matrix and handed to
 * fftBatch (the `true` flag presumably selects the inverse transform — confirm
 * in fftBatch). The result is reshaped back to the input's shape.
 *
 * @param {{inputs: {input: object}, backend: object}} args
 * @returns {object} tensor info with the same shape as `input`.
 */
function ifft(args) {
  const {inputs, backend} = args;
  const input = inputs.input;
  const dims = input.shape;
  const innerDimensionSize = dims[dims.length - 1];
  const batch = tf.util.sizeFromShape(dims) / innerDimensionSize;
  const asMatrix = reshape({
    inputs: {x: input},
    backend,
    attrs: {shape: [batch, innerDimensionSize]}
  });
  const transformed = fftBatch(asMatrix, true, backend);
  const restored = reshape({inputs: {x: transformed}, backend, attrs: {shape: input.shape}});
  // Intermediates are released only after the final reshape has consumed them.
  [asMatrix, transformed].forEach((info) => backend.disposeIntermediateTensorInfo(info));
  return restored;
}
var ifftConfig = {
  kernelName: tf.IFFT,
  backendName: "cpu",
  kernelFunc: ifft
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the License);
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an AS IS BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
// IsFinite: 1 for finite elements, 0 for NaN and +/-Infinity; output dtype "bool".
var isFinite2 = unaryKernelFunc(tf.IsFinite, (xi) => Number(Number.isFinite(xi)), "bool");
var isFiniteConfig = {
  kernelName: tf.IsFinite,
  backendName: "cpu",
  kernelFunc: isFinite2
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the License);
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an AS IS BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
// IsInf: 1 where |x| is Infinity (either sign), 0 otherwise; output dtype "bool".
var isInf = unaryKernelFunc(tf.IsInf, (xi) => Number(Math.abs(xi) === Infinity), "bool");
var isInfConfig = {
  kernelName: tf.IsInf,
  backendName: "cpu",
  kernelFunc: isInf
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the License);
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an AS IS BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
// IsNan: 1 where the element is NaN, 0 otherwise; output dtype "bool".
var isNaN$1 = unaryKernelFunc(tf.IsNan, (xi) => Number(Number.isNaN(xi)), "bool");
var isNaNConfig = {
  kernelName: tf.IsNan,
  backendName: "cpu",
  kernelFunc: isNaN$1
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the License);
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an AS IS BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
// Log1p: elementwise ln(1 + x), via Math.log1p for accuracy near zero.
var log1p = unaryKernelFunc(tf.Log1p, (xi) => Math.log1p(xi));
var log1pConfig = {
  kernelName: tf.Log1p,
  backendName: "cpu",
  kernelFunc: log1p
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the License);
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an AS IS BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
// LogicalNot: 1 for falsy elements (0, NaN), 0 for everything else; output dtype "bool".
var logicalNot = unaryKernelFunc(tf.LogicalNot, (xi) => Number(!xi), "bool");
var logicalNotConfig = {
  kernelName: tf.LogicalNot,
  backendName: "cpu",
  kernelFunc: logicalNot
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var maxConfig = {
  kernelName: tf.Max,
  backendName: "cpu",
  /**
   * Max-reduces `x` over `reductionIndices`.
   *
   * maxImpl reduces innermost axes, so when getAxesPermutation reports a
   * permutation (non-null) the values are transposed first and the axes are
   * re-expressed as the innermost ones. With `keepDims`, the reported shape keeps
   * the reduced axes as size-1 dimensions.
   */
  kernelFunc: function({inputs, attrs, backend}) {
    const x = inputs.x;
    const {reductionIndices, keepDims} = attrs;
    const xRank = x.shape.length;
    const origAxes = tf.util.parseAxisParam(reductionIndices, x.shape);
    const permutation = tf.backend_util.getAxesPermutation(origAxes, xRank);
    let values = backend.data.get(x.dataId).values;
    let shape = x.shape;
    let axes = origAxes;
    if (permutation != null) {
      const permutedShape = Array.from({length: xRank}, (_, i) => x.shape[permutation[i]]);
      values = transposeImpl(values, x.shape, x.dtype, permutation, permutedShape);
      axes = tf.backend_util.getInnerMostAxes(axes.length, xRank);
      shape = permutedShape;
    }
    assertNotComplex(x, "max");
    tf.backend_util.assertAxesAreInnerMostDims("max", axes, xRank);
    const [maxOutShape, reduceShape] = tf.backend_util.computeOutAndReduceShapes(shape, axes);
    const reduced = maxImpl(values, tf.util.sizeFromShape(reduceShape), maxOutShape, x.dtype);
    const dataId = backend.write(reduced, maxOutShape, x.dtype);
    const outShape = keepDims ? tf.backend_util.expandShapeToKeepDim(maxOutShape, origAxes) : maxOutShape;
    return {dataId, shape: outShape, dtype: x.dtype};
  }
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * CPU kernel for MaxPool.
 *
 * Dilation is fixed at 1. A 1x1 window whose output shape equals the input shape
 * is forwarded through the identity kernel. Otherwise the "max" pool helper
 * computes the result.
 *
 * @param {{inputs: {x: object}, backend: object, attrs: object}} args
 *   attrs: filterSize, strides, pad, dimRoundingMode.
 * @returns {object} tensor info of the pooled output.
 */
function maxPool(args) {
  const {inputs, backend, attrs} = args;
  const x = inputs.x;
  assertNotComplex(x, "maxPool");
  const {filterSize, strides, pad, dimRoundingMode} = attrs;
  const dilations = 1;
  tf.util.assert(
    tf.backend_util.eitherStridesOrDilationsAreOne(strides, dilations),
    () => "Error in maxPool: Either strides or dilations must be 1. Got strides " + strides + " and dilations '" + dilations + "'"
  );
  const convInfo = tf.backend_util.computePool2DInfo(x.shape, filterSize, strides, dilations, pad, dimRoundingMode);
  const isNoOp = convInfo.filterWidth === 1 && convInfo.filterHeight === 1 && tf.util.arraysEqual(convInfo.inShape, convInfo.outShape);
  if (isNoOp) {
    return identity({inputs: {x}, backend});
  }
  const xValues = backend.data.get(x.dataId).values;
  const pooled = pool(xValues, x.shape, x.dtype, tf.util.computeStrides(x.shape), convInfo, "max");
  return backend.makeTensorInfo(convInfo.outShape, x.dtype, pooled.values);
}
var maxPoolConfig = {
  kernelName: tf.MaxPool,
  backendName: "cpu",
  kernelFunc: maxPool
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * CPU gradient kernel for MaxPool.
 *
 * For every input position (b, dxR, dxC, d), sums the upstream gradient `dy` of
 * each pooling window whose recorded maximum is at that position. All other input
 * positions receive 0.
 *
 * @param {{inputs: {dy: object, input: object, output: object}, backend: object, attrs: object}} args
 *   `input` is the forward-pass input. `output` is only checked by assertNotComplex here.
 *   attrs: filterSize, strides, pad, dimRoundingMode (dilation is fixed at 1).
 * @returns {object} tensor info for dx: x's shape, dtype "float32".
 */
function maxPoolBackprop(args) {
var inputs = args.inputs, backend = args.backend, attrs = args.attrs;
var dy = inputs.dy, input = inputs.input, output = inputs.output;
var x = input;
assertNotComplex([input, output], "maxPoolBackprop");
var filterSize = attrs.filterSize, strides = attrs.strides, pad = attrs.pad, dimRoundingMode = attrs.dimRoundingMode;
var convInfo = tf.backend_util.computePool2DInfo(x.shape, filterSize, strides, 1, pad, dimRoundingMode);
var xValues = backend.data.get(x.dataId).values;
// For each output window, the position of its maximum as computed by
// maxPoolPositions (presumably a flattened offset within the effective window — confirm there).
var maxPosBuf = tf.buffer(convInfo.outShape, x.dtype, maxPoolPositions(xValues, x.shape, x.dtype, convInfo).values);
var strideHeight = convInfo.strideHeight;
var strideWidth = convInfo.strideWidth;
var dilationHeight = convInfo.dilationHeight;
var dilationWidth = convInfo.dilationWidth;
var effectiveFilterHeight = convInfo.effectiveFilterHeight;
var effectiveFilterWidth = convInfo.effectiveFilterWidth;
// Offsets that map an input coordinate back to the dy coordinates of the windows covering it.
var padLeft = effectiveFilterWidth - 1 - convInfo.padInfo.left;
var padTop = effectiveFilterHeight - 1 - convInfo.padInfo.top;
var dx = tf.buffer(x.shape, "float32");
var dyData = backend.data.get(dy.dataId).values;
var dyBuf = tf.buffer(dy.shape, "float32", dyData);
// Buffers are indexed as (batch, row, col, channel).
for (var b = 0; b < convInfo.batchSize; ++b) {
for (var d = 0; d < convInfo.inChannels; ++d) {
for (var dxR = 0; dxR < convInfo.inHeight; ++dxR) {
for (var dxC = 0; dxC < convInfo.inWidth; ++dxC) {
var dyRCorner = dxR - padTop;
var dyCCorner = dxC - padLeft;
var dotProd = 0;
for (var wR = 0; wR < effectiveFilterHeight; wR += dilationHeight) {
var dyR = (dyRCorner + wR) / strideHeight;
// Skip rows outside dy, or offsets that do not land on a stride-aligned window.
if (dyR < 0 || dyR >= convInfo.outHeight || Math.floor(dyR) !== dyR) {
continue;
}
for (var wC = 0; wC < effectiveFilterWidth; wC += dilationWidth) {
var dyC = (dyCCorner + wC) / strideWidth;
if (dyC < 0 || dyC >= convInfo.outWidth || Math.floor(dyC) !== dyC) {
continue;
}
// The wR/wC walk runs opposite to the stored window index, so the stored
// position is mirrored (size - 1 - pos) before comparing.
var maxPos = effectiveFilterHeight * effectiveFilterWidth - 1 - maxPosBuf.get(b, dyR, dyC, d);
var curPos = wR * effectiveFilterWidth + wC;
// Only the window's argmax position receives that window's gradient.
var mask = maxPos === curPos ? 1 : 0;
if (mask === 0) {
continue;
}
var pixel = dyBuf.get(b, dyR, dyC, d);
dotProd += pixel * mask;
}
}
dx.set(dotProd, b, dxR, dxC, d);
}
}
}
}
return backend.makeTensorInfo(dx.shape, dx.dtype, dx.values);
}
var maxPoolBackpropConfig = {
kernelName: tf.MaxPoolBackprop,
backendName: "cpu",
kernelFunc: maxPoolBackprop
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Shared implementation for MaxPoolWithArgmax.
 *
 * @returns {Array} `[pooledValues, argmaxPositions]`: the "max" pool result and the
 *   maxPoolPositions result (computed with includeBatchInIndex), both as flat values.
 */
function maxPoolWithArgmaxImpl(xValues, xShape, dtype, includeBatchInIndex, convInfo) {
  const pooled = pool(xValues, xShape, dtype, tf.util.computeStrides(xShape), convInfo, "max");
  const positions = maxPoolPositions(xValues, xShape, dtype, convInfo, true, includeBatchInIndex);
  return [pooled.values, positions.values];
}
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var maxPoolWithArgmaxConfig = {
  kernelName: tf.MaxPoolWithArgmax,
  backendName: "cpu",
  /**
   * Max-pools `x` (dilation fixed at [1, 1]) and also returns the argmax positions
   * produced by maxPoolWithArgmaxImpl.
   *
   * @returns {Array<object>} `[pooled, indexes]`: pooled values in x's dtype and
   *   the argmax indexes with dtype "int32".
   */
  kernelFunc: function({inputs, attrs, backend}) {
    const x = inputs.x;
    const {filterSize, strides, pad, includeBatchInIndex} = attrs;
    assertNotComplex(x, "MaxPoolWithArgmax");
    const values = backend.data.get(x.dataId).values;
    const convInfo = tf.backend_util.computePool2DInfo(x.shape, filterSize, strides, [1, 1], pad);
    const [pooled, indexes] = maxPoolWithArgmaxImpl(values, x.shape, x.dtype, includeBatchInIndex, convInfo);
    const pooledDataId = backend.write(pooled, convInfo.outShape, x.dtype);
    // The indexes are stored as "int32" so the data-store entry agrees with the
    // dtype of the TensorInfo returned for them below (not x's dtype).
    const indexesDataId = backend.write(indexes, convInfo.outShape, "int32");
    return [
      {dataId: pooledDataId, shape: convInfo.outShape, dtype: x.dtype},
      {dataId: indexesDataId, shape: convInfo.outShape, dtype: "int32"}
    ];
  }
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * CPU kernel for MirrorPad: pads `x` with mirrored copies of its own values.
 *
 * @param {{inputs: {x: object}, backend: object, attrs: {paddings: number[][], mode: string}}} args
 *   `paddings[d]` is `[before, after]` for dimension d. With `mode === "reflect"`
 *   the edge element is not repeated in the mirror; any other mode repeats it.
 * @returns {{dataId: object, shape: number[], dtype: string}} the padded tensor info.
 */
function mirrorPad(args) {
  const {inputs, backend, attrs} = args;
  const x = inputs.x;
  const {paddings, mode} = attrs;
  assertNotComplex(x, "mirrorPad");
  const outShape = paddings.map((padding, dim) => padding[0] + x.shape[dim] + padding[1]);
  const start = paddings.map((padding) => padding[0]);
  const end = paddings.map((padding, dim) => padding[0] + x.shape[dim]);
  // "reflect" mirrors about the edge element itself; otherwise the edge is duplicated.
  const offset = mode === "reflect" ? 0 : 1;
  const xVals = backend.data.get(x.dataId).values;
  const xRank = x.shape.length;
  const xStrides = tf.util.computeStrides(x.shape);
  const resultSize = tf.util.sizeFromShape(outShape);
  const resultRank = outShape.length;
  const resultStrides = tf.util.computeStrides(outShape);
  const resVals = tf.util.getTypedArrayFromDType(x.dtype, resultSize);
  for (let outIndex = 0; outIndex < resultSize; outIndex++) {
    const outLoc = tf.util.indexToLoc(outIndex, resultRank, resultStrides);
    // Fold each coordinate in the padded region back into [start, end), then
    // shift it into x's coordinate space.
    const srcLoc = outLoc.map((c, dim) => {
      let folded = c;
      if (c < start[dim]) {
        folded = start[dim] * 2 - c - offset;
      } else if (c >= end[dim]) {
        folded = (end[dim] - 1) * 2 - c + offset;
      }
      return folded - start[dim];
    });
    resVals[outIndex] = xVals[tf.util.locToIndex(srcLoc, xRank, xStrides)];
  }
  const outId = backend.write(resVals, outShape, x.dtype);
  return {dataId: outId, shape: outShape, dtype: x.dtype};
}
// Kernel registration: mirrorPad serves the MirrorPad kernel on the "cpu" backend.
var mirrorPadConfig = {
kernelName: tf.MirrorPad,
backendName: "cpu",
kernelFunc: mirrorPad
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var nonMaxSuppressionV4Impl = tf.kernel_impls.nonMaxSuppressionV4Impl;
var nonMaxSuppressionV4Config = {
  kernelName: tf.NonMaxSuppressionV4,
  backendName: "cpu",
  /**
   * Reads the raw box and score values and delegates to the shared
   * nonMaxSuppressionV4Impl.
   *
   * @returns {Array} `[selectedIndices, validOutputs]` exactly as returned by the impl.
   */
  kernelFunc: function({inputs, backend, attrs}) {
    const {boxes, scores} = inputs;
    const {maxOutputSize, iouThreshold, scoreThreshold, padToMaxOutputSize} = attrs;
    assertNotComplex(boxes, "NonMaxSuppressionPadded");
    const boxesVals = backend.data.get(boxes.dataId).values;
    const scoresVals = backend.data.get(scores.dataId).values;
    const {selectedIndices, validOutputs} = nonMaxSuppressionV4Impl(
      boxesVals, scoresVals, maxOutputSize, iouThreshold, scoreThreshold, padToMaxOutputSize
    );
    return [selectedIndices, validOutputs];
  }
};
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var nonMaxSuppressionV5Impl = tf.kernel_impls.nonMaxSuppressionV5Impl;
var nonMaxSuppressionV5Config = {
  kernelName: tf.NonMaxSuppressionV5,
  backendName: "cpu",
  /**
   * Reads the raw box and score values and delegates to the shared
   * nonMaxSuppressionV5Impl, which also takes `softNmsSigma`.
   *
   * @returns {Array} `[selectedIndices, selectedScores]` as returned by the impl.
   */
  kernelFunc: function({inputs, backend, attrs}) {
    const {boxes, scores} = inputs;
    const {maxOutputSize, iouThreshold, scoreThreshold, softNmsSigma} = attrs;
    assertNotComplex(boxes, "NonMaxSuppressionWithScore");
    const boxesVals = backend.data.get(boxes.dataId).values;
    const scoresVals = backend.data.get(scores.dataId).values;
    const selection = nonMaxSuppressionV5Impl(
      boxesVals, scoresVals, maxOutputSize, iouThreshold, scoreThreshold, softNmsSigma
    );
    return [selection.selectedIndices, selection.selectedScores];
  }
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * CPU kernel for PadV2: constant-pads `x`.
 *
 * @param {{inputs: {x: object}, backend: object, attrs: {paddings: number[][], constantValue: number}}} args
 *   `paddings[d]` is `[before, after]` for dimension d.
 * @returns {{dataId: object, shape: number[], dtype: string}} the padded tensor info.
 */
function padV2(args) {
  const {inputs, backend, attrs} = args;
  const x = inputs.x;
  const {paddings, constantValue} = attrs;
  assertNotComplex(x, "pad");
  const outShape = paddings.map((padding, dim) => padding[0] + x.shape[dim] + padding[1]);
  const start = paddings.map((padding) => padding[0]);
  const xVals = backend.data.get(x.dataId).values;
  const xSize = tf.util.sizeFromShape(x.shape);
  const xRank = x.shape.length;
  const xStrides = tf.util.computeStrides(x.shape);
  const resultSize = tf.util.sizeFromShape(outShape);
  const resultRank = outShape.length;
  const resultStrides = tf.util.computeStrides(outShape);
  const resVals = tf.util.getTypedArrayFromDType(x.dtype, resultSize);
  // The fill is skipped for 0, presumably because getTypedArrayFromDType returns a
  // zero-initialised array (confirm for non-numeric dtypes).
  if (constantValue !== 0) {
    resVals.fill(constantValue);
  }
  // Scatter every input element to its location shifted by the leading padding.
  for (let inIndex = 0; inIndex < xSize; inIndex++) {
    const inLoc = tf.util.indexToLoc(inIndex, xRank, xStrides);
    const shiftedLoc = inLoc.map((c, dim) => c + start[dim]);
    resVals[tf.util.locToIndex(shiftedLoc, resultRank, resultStrides)] = xVals[inIndex];
  }
  const outId = backend.write(resVals, outShape, x.dtype);
  return {dataId: outId, shape: outShape, dtype: x.dtype};
}
// Kernel registration: padV2 serves the PadV2 kernel on the "cpu" backend.
var padV2Config = {
kernelName: tf.PadV2,
backendName: "cpu",
kernelFunc: padV2
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the License);
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an AS IS BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
// Reciprocal: elementwise 1 / x (0 maps to Infinity per IEEE division).
var reciprocal = unaryKernelFunc(tf.Reciprocal, (xi) => 1 / xi);
var reciprocalConfig = {
  kernelName: tf.Reciprocal,
  backendName: "cpu",
  kernelFunc: reciprocal
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var rotateWithOffsetConfig = {
  kernelName: tf.RotateWithOffset,
  backendName: "cpu",
  /**
   * Rotates every image of a [batch, height, width, channels] tensor by `radians`
   * around `center` (resolved by getImageCenter).
   *
   * Each output pixel (col, row) reads the input pixel found by rotating its
   * coordinate about the center and rounding to the nearest integer. When that
   * source lies outside the image, the fill is used:
   * - a numeric `fillValue` applies to every channel;
   * - otherwise `fillValue[channel]` is used, except channel 3, which is set to 255.
   *
   * The rotation, bounds test and source offset depend only on (row, col), so they
   * are computed once per pixel instead of once per channel.
   */
  kernelFunc: function(_a) {
    const {inputs, attrs, backend} = _a;
    const image = inputs.image;
    const {radians, fillValue, center} = attrs;
    const cpuBackend = backend;
    const output = tf.util.getTypedArrayFromDType(image.dtype, tf.util.sizeFromShape(image.shape));
    const [batch, imageHeight, imageWidth, numChannels] = image.shape;
    const [centerX, centerY] = tf.backend_util.getImageCenter(center, imageHeight, imageWidth);
    const fullOpacityValue = 255;
    const sinFactor = Math.sin(radians);
    const cosFactor = Math.cos(radians);
    const imageVals = cpuBackend.data.get(image.dataId).values;
    const fillIsScalar = typeof fillValue === "number";
    const rowStride = imageWidth * numChannels;
    for (let batchIdx = 0; batchIdx < batch; batchIdx++) {
      const batchOffset = batchIdx * imageWidth * imageHeight * numChannels;
      for (let row = 0; row < imageHeight; row++) {
        const rowOffset = row * rowStride;
        const offsetY = row - centerY;
        for (let col = 0; col < imageWidth; col++) {
          const colOffset = col * numChannels;
          const offsetX = col - centerX;
          // Source coordinate for this output pixel, nearest-pixel rounding.
          const coordX = Math.round(offsetX * cosFactor - offsetY * sinFactor + centerX);
          const coordY = Math.round(offsetX * sinFactor + offsetY * cosFactor + centerY);
          const inBounds = coordX >= 0 && coordX < imageWidth && coordY >= 0 && coordY < imageHeight;
          const rotatedRowOffset = coordY * rowStride;
          const rotatedColOffset = coordX * numChannels;
          for (let channel = 0; channel < numChannels; channel++) {
            let outputValue;
            if (inBounds) {
              outputValue = imageVals[batchOffset + rotatedRowOffset + rotatedColOffset + channel];
            } else if (fillIsScalar) {
              outputValue = fillValue;
            } else {
              outputValue = channel === 3 ? fullOpacityValue : fillValue[channel];
            }
            output[batchOffset + rowOffset + colOffset + channel] = outputValue;
          }
        }
      }
    }
    const dataId = cpuBackend.write(output, image.shape, image.dtype);
    return {dataId, shape: image.shape, dtype: image.dtype};
  }
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the License);
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an AS IS BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
// Round: round half to even. Fractions below .5 round down and above .5 round up;
// exact .5 ties go to the even neighbour.
var round = unaryKernelFunc(tf.Round, function(xi) {
  const lower = Math.floor(xi);
  const fraction = xi - lower;
  if (fraction < 0.5) {
    return Math.floor(xi);
  }
  if (fraction > 0.5) {
    return Math.ceil(xi);
  }
  return lower % 2 === 0 ? lower : lower + 1;
});
var roundConfig = {
  kernelName: tf.Round,
  backendName: "cpu",
  kernelFunc: round
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the License);
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an AS IS BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var scaleAlpha = tf.backend_util.SELU_SCALEALPHA;
var scale = tf.backend_util.SELU_SCALE;
// Selu: scale * x for x >= 0, otherwise scaleAlpha * (e^x - 1).
var selu = unaryKernelFunc(tf.Selu, (xi) => (xi >= 0 ? scale * xi : scaleAlpha * (Math.exp(xi) - 1)));
var seluConfig = {
  kernelName: tf.Selu,
  backendName: "cpu",
  kernelFunc: selu
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the License);
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an AS IS BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
// Sigmoid: logistic function 1 / (1 + e^-x).
var sigmoid = unaryKernelFunc(tf.Sigmoid, (xi) => 1 / (1 + Math.exp(-xi)));
var sigmoidConfig = {
  kernelName: tf.Sigmoid,
  backendName: "cpu",
  kernelFunc: sigmoid
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the License);
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an AS IS BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
// Sign: 1 for positive, -1 for negative, 0 otherwise. Unlike Math.sign, NaN and -0
// both map to (positive) 0.
var sign = unaryKernelFunc(tf.Sign, (xi) => {
  if (xi > 0) {
    return 1;
  }
  return xi < 0 ? -1 : 0;
});
var signConfig = {
  kernelName: tf.Sign,
  backendName: "cpu",
  kernelFunc: sign
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the License);
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an AS IS BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
// Sin: elementwise sine (radians).
var sin = unaryKernelFunc(tf.Sin, (xi) => Math.sin(xi));
var sinConfig = {
  kernelName: tf.Sin,
  backendName: "cpu",
  kernelFunc: sin
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the License);
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an AS IS BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/** CPU kernel for tf.Sinh: elementwise Math.sinh, wrapped by unaryKernelFunc. */
var sinh = unaryKernelFunc(tf.Sinh, (xi) => Math.sinh(xi));
var sinhConfig = {
  kernelName: tf.Sinh,
  backendName: "cpu",
  kernelFunc: sinh,
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the License);
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an AS IS BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
// Float32 machine epsilon (2^-23).
var epsilon = 11920928955078125e-23;
// Tail cutoff (about -13.94). Below `threshold` softplus is replaced by e^x,
// above `-threshold` by x.
var threshold = Math.log(epsilon) + 2;
/**
 * CPU kernel for tf.Softplus: log(1 + e^x). At either extreme the exact
 * formula is swapped for its asymptote.
 */
var softplus = unaryKernelFunc(tf.Softplus, (xi) => {
  const expX = Math.exp(xi);
  if (xi < threshold) {
    // Very negative input: log(1 + e^x) is approximately e^x.
    return expX;
  }
  if (xi > -threshold) {
    // Very positive input: log(1 + e^x) is approximately x.
    return xi;
  }
  return Math.log(1 + expX);
});
var softplusConfig = {
  kernelName: tf.Softplus,
  backendName: "cpu",
  kernelFunc: softplus,
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * CPU transpose kernel.
 * @param {{inputs: {x: Object}, attrs: {perm: number[]}, backend: Object}} args
 * @returns {{dataId: *, shape: number[], dtype: string}} tensor info for the permuted result.
 */
function transpose(args) {
  const { inputs, attrs, backend } = args;
  const x = inputs.x;
  const perm = attrs.perm;
  assertNotComplex(x, "transpose");
  // Output axis i takes its extent from input axis perm[i].
  const newShape = x.shape.map((_, axis) => x.shape[perm[axis]]);
  const { values } = backend.data.get(x.dataId);
  const permuted = transposeImpl(values, x.shape, x.dtype, perm, newShape);
  const dataId = backend.write(permuted, newShape, x.dtype);
  return { dataId, shape: newShape, dtype: x.dtype };
}
var transposeConfig = {
  kernelName: tf.Transpose,
  backendName: "cpu",
  kernelFunc: transpose,
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * CPU SpaceToBatchND kernel: pad -> reshape -> transpose -> reshape.
 * The three intermediate tensors are disposed before returning.
 * @param {{inputs: {x: Object}, backend: Object, attrs: {blockShape: number[], paddings: number[][]}}} args
 * @returns {Object} tensor info produced by the final reshape.
 */
function spaceToBatchND(args) {
  const { inputs, backend, attrs } = args;
  const x = inputs.x;
  const { blockShape, paddings } = attrs;
  assertNotComplex([x], "spaceToBatchND");
  const prod = tf.util.sizeFromShape(blockShape);
  // Axis 0 and every axis after the block dims get zero padding; only the
  // block (spatial) dims use the caller-supplied paddings.
  const completePaddings = [[0, 0], ...paddings];
  for (let axis = 1 + blockShape.length; axis < x.shape.length; ++axis) {
    completePaddings.push([0, 0]);
  }
  const paddedX = padV2Config.kernelFunc({
    inputs: { x },
    backend,
    attrs: { paddings: completePaddings, constantValue: 0 },
  });
  const reshapedPaddedShape = tf.backend_util.getReshaped(paddedX.shape, blockShape, prod, false);
  const permutation = tf.backend_util.getPermuted(reshapedPaddedShape.length, blockShape.length, false);
  const flattenShape = tf.backend_util.getReshapedPermuted(paddedX.shape, blockShape, prod, false);
  const paddedXReshaped = reshape({ inputs: { x: paddedX }, backend, attrs: { shape: reshapedPaddedShape } });
  const paddedXT = transpose({ inputs: { x: paddedXReshaped }, backend, attrs: { perm: permutation } });
  const result = reshape({ inputs: { x: paddedXT }, backend, attrs: { shape: flattenShape } });
  [paddedX, paddedXReshaped, paddedXT].forEach((intermediate) => {
    backend.disposeIntermediateTensorInfo(intermediate);
  });
  return result;
}
var spaceToBatchNDConfig = {
  kernelName: tf.SpaceToBatchND,
  backendName: "cpu",
  kernelFunc: spaceToBatchND,
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the License);
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an AS IS BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/** CPU kernel for tf.Sqrt: elementwise Math.sqrt (negative inputs give NaN). */
var sqrt = unaryKernelFunc(tf.Sqrt, (xi) => Math.sqrt(xi));
var sqrtConfig = {
  kernelName: tf.Sqrt,
  backendName: "cpu",
  kernelFunc: sqrt,
};
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * CPU kernel for tf.Square: x * x elementwise. The squared values are always
 * stored in a Float32Array; the output keeps the input's shape and dtype.
 */
var squareConfig = {
  kernelName: tf.Square,
  backendName: "cpu",
  kernelFunc: function(_a) {
    const x = _a.inputs.x;
    const cpuBackend = _a.backend;
    assertNotComplex(x, "square");
    const { values } = cpuBackend.data.get(x.dataId);
    const squared = Float32Array.from(values, (v) => v * v);
    const dataId = cpuBackend.write(squared, x.shape, x.dtype);
    return { dataId, shape: x.shape, dtype: x.dtype };
  },
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the License);
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an AS IS BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * CPU kernel for tf.Step: NaN stays NaN, positive inputs become 1, and
 * everything else becomes attrs.alpha.
 */
var step = unaryKernelFunc(tf.Step, (xi, attrs) => {
  if (isNaN(xi)) {
    return NaN;
  }
  return xi > 0 ? 1 : attrs.alpha;
});
var stepConfig = {
  kernelName: tf.Step,
  backendName: "cpu",
  kernelFunc: step,
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the License);
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an AS IS BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/** CPU kernel for tf.Tan: elementwise Math.tan, wrapped by unaryKernelFunc. */
var tan = unaryKernelFunc(tf.Tan, (xi) => Math.tan(xi));
var tanConfig = {
  kernelName: tf.Tan,
  backendName: "cpu",
  kernelFunc: tan,
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the License);
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an AS IS BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/** CPU kernel for tf.Tanh: elementwise Math.tanh, wrapped by unaryKernelFunc. */
var tanh = unaryKernelFunc(tf.Tanh, (xi) => Math.tanh(xi));
var tanhConfig = {
  kernelName: tf.Tanh,
  backendName: "cpu",
  kernelFunc: tanh,
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the License);
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an AS IS BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * CPU Unique kernel.
 * @param {{inputs: {x: Object}, attrs: {axis: number}, backend: Object}} args
 * @returns {Object[]} two tensor infos: the unique values (input dtype), then
 *   the int32 index tensor returned by uniqueImpl.
 */
function unique(args) {
  const { inputs, attrs, backend } = args;
  const x = inputs.x;
  assertNotComplex(x, "unique");
  const { values } = backend.data.get(x.dataId);
  const { outputValues, outputShape, indices } = uniqueImpl(values, attrs.axis, x.shape, x.dtype);
  const uniqueTensor = backend.makeTensorInfo(outputShape, x.dtype, outputValues);
  const indexTensor = backend.makeTensorInfo([indices.length], "int32", indices);
  return [uniqueTensor, indexTensor];
}
var uniqueConfig = {
  kernelName: tf.Unique,
  backendName: "cpu",
  kernelFunc: unique,
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
// Every CPU kernel config defined in this module, registered with tfjs-core below.
var kernelConfigs = [
  _fusedMatMulConfig,
  absConfig,
  acosConfig,
  acoshConfig,
  addConfig,
  asinConfig,
  asinhConfig,
  atanConfig,
  atanhConfig,
  avgPoolConfig,
  avgPoolBackpropConfig,
  batchMatMulConfig,
  batchNormConfig,
  castConfig,
  ceilConfig,
  clipConfig,
  complexConfig,
  concatConfig,
  conv2DBackpropFilterConfig,
  conv2DBackpropInputConfig,
  conv2DConfig,
  conv3DBackpropFilterV2Config,
  conv3DBackpropInputV2Config,
  conv3DConfig,
  cosConfig,
  coshConfig,
  depthwiseConv2dNativeConfig,
  depthwiseConv2dNativeBackpropFilterConfig,
  depthwiseConv2dNativeBackpropInputConfig,
  dilation2dConfig,
  dilation2dBackpropInputConfig,
  dilation2dBackpropFilterConfig,
  divConfig,
  eluConfig,
  erfConfig,
  expConfig,
  expm1Config,
  fftConfig,
  fillConfig,
  flipLeftRightConfig,
  floorConfig,
  fusedConv2DConfig,
  fusedDepthwiseConv2DConfig,
  identityConfig,
  ifftConfig,
  imagConfig,
  isFiniteConfig,
  isInfConfig,
  isNaNConfig,
  logConfig,
  log1pConfig,
  logicalNotConfig,
  maxPoolConfig,
  maxPoolBackpropConfig,
  maxPoolWithArgmaxConfig,
  maxConfig,
  mirrorPadConfig,
  multiplyConfig,
  nonMaxSuppressionV4Config,
  nonMaxSuppressionV5Config,
  notEqualConfig,
  padV2Config,
  preluConfig,
  realConfig,
  reciprocalConfig,
  reluConfig,
  relu6Config,
  reshapeConfig,
  rotateWithOffsetConfig,
  roundConfig,
  rsqrtConfig,
  seluConfig,
  sigmoidConfig,
  signConfig,
  sinConfig,
  sinhConfig,
  sliceConfig,
  softplusConfig,
  spaceToBatchNDConfig,
  sqrtConfig,
  squareConfig,
  squaredDifferenceConfig,
  stepConfig,
  subConfig,
  tanConfig,
  tanhConfig,
  transposeConfig,
  uniqueConfig
];
// Register each config, in array order, with tfjs-core's kernel registry.
kernelConfigs.forEach(function(kernelConfig) {
  tf.registerKernel(kernelConfig);
});
// Public surface of the CPU backend module. Object.assign performs plain
// property assignment, in this order.
Object.assign(exports, {
  MathBackendCPU,
  shared,
  version_cpu: version,
});
});
// node_modules/@tensorflow/tfjs-backend-webgl/dist/tf-backend-webgl.node.js
var require_tf_backend_webgl_node = __commonJS((exports) => {
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
"use strict";
// Set the conventional `__esModule` interop flag on this CommonJS module's exports.
Object.defineProperty(exports, "__esModule", {value: true});
// tfjs-core's exports, obtained through the lazy __commonJS wrapper require_tf_core_node.
var tf = require_tf_core_node();
/*! *****************************************************************************
Copyright (c) Microsoft Corporation. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use
this file except in compliance with the License. You may obtain a copy of the
License at http://www.apache.org/licenses/LICENSE-2.0
THIS CODE IS PROVIDED ON AN *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
MERCHANTABLITY OR NON-INFRINGEMENT.
See the Apache Version 2.0 License for specific language governing permissions
and limitations under the License.
***************************************************************************** */
// TypeScript class-inheritance helper (tslib): makes static members of base `b`
// available on derived constructor `d`. On its first call it overwrites itself
// with the best strategy this runtime supports, then delegates to it:
//   1. Object.setPrototypeOf, when it exists;
//   2. direct `__proto__` assignment, when `{__proto__: []} instanceof Array`
//      shows that `__proto__` in object literals is honored;
//   3. otherwise, copying b's own enumerable properties onto d.
var extendStatics = function(d, b) {
extendStatics = Object.setPrototypeOf || {__proto__: []} instanceof Array && function(d2, b2) {
d2.__proto__ = b2;
} || function(d2, b2) {
// Fallback: statics are copied (a snapshot), not inherited.
for (var p in b2)
if (b2.hasOwnProperty(p))
d2[p] = b2[p];
};
return extendStatics(d, b);
};
/**
 * TypeScript `class D extends B` helper (tslib __extends).
 * Links statics through extendStatics, then gives `d` a prototype whose
 * [[Prototype]] is b.prototype (or a null-prototype object when b is null).
 * @param {Function} d derived constructor
 * @param {Function|null} b base constructor, or null
 */
function __extends(d, b) {
  extendStatics(d, b);
  if (b === null) {
    d.prototype = Object.create(b);
    return;
  }
  // The bridge constructor runs no base-class code. It only sets an own
  // `constructor` property pointing back at d.
  const Bridge = function() {
    this.constructor = d;
  };
  Bridge.prototype = b.prototype;
  d.prototype = new Bridge();
}
// TypeScript async/await downlevel helper (tslib __awaiter).
// Calls `generator` with `thisArg`/`_arguments` and drives the resulting iterator
// inside a promise built with constructor P (the global Promise when P is falsy).
// Each yielded value is adopted as a P; fulfillment resumes the iterator with next(),
// rejection with throw(). The outer promise resolves with the iterator's final value,
// or rejects if next()/throw() itself throws.
function __awaiter(thisArg, _arguments, P, generator) {
// Wrap non-P values in an already-resolved P so every yield can be awaited uniformly.
function adopt(value) {
return value instanceof P ? value : new P(function(resolve) {
resolve(value);
});
}
return new (P || (P = Promise))(function(resolve, reject) {
function fulfilled(value) {
try {
step(generator.next(value));
} catch (e) {
reject(e);
}
}
function rejected(value) {
try {
step(generator["throw"](value));
} catch (e) {
reject(e);
}
}
// Finish on `done`; otherwise wait on the yielded value and continue.
function step(result) {
result.done ? resolve(result.value) : adopt(result.value).then(fulfilled, rejected);
}
step((generator = generator.apply(thisArg, _arguments || [])).next());
});
}
// TypeScript generator downlevel helper (tslib __generator): a state machine that
// emulates an ES2015 generator for code compiled to ES5. `body` is called
// repeatedly with the state object `_` and returns an instruction array
// [opcode, operand]. From the switch below, opcode meanings follow the tslib
// convention: 0 next, 1 throw, 2 return, 3 jump to label, 4 yield, 5 delegate
// (yield*), 6 error raised in body, 7 end of a finally block.
function __generator(thisArg, body) {
// _.label: resume point for body. _.sent(): value passed into next(), or rethrows
// when the last op was a throw (t[0] & 1). _.trys: stack of active try regions,
// presumably [tryLoc, catchLoc, finallyLoc, endLoc]. _.ops: ops deferred until a
// finally block finishes. f: "executing" flag; y: iterator being delegated to;
// t: scratch; g: the public generator object.
var _ = {label: 0, sent: function() {
if (t[0] & 1)
throw t[1];
return t[1];
}, trys: [], ops: []}, f, y, t, g;
// Public iterator: next/throw/return map to opcodes 0/1/2. It is also iterable
// when Symbol exists.
return g = {next: verb(0), throw: verb(1), return: verb(2)}, typeof Symbol === "function" && (g[Symbol.iterator] = function() {
return this;
}), g;
function verb(n) {
return function(v) {
return step([n, v]);
};
}
// Run body (or the delegated iterator y) until it yields or completes.
// Reentrant calls while running raise TypeError.
function step(op) {
if (f)
throw new TypeError("Generator is already executing.");
while (_)
try {
if (f = 1, y && (t = op[0] & 2 ? y["return"] : op[0] ? y["throw"] || ((t = y["return"]) && t.call(y), 0) : y.next) && !(t = t.call(y, op[1])).done)
return t;
if (y = 0, t)
op = [op[0] & 2, t.value];
switch (op[0]) {
// next/throw: stash the op so body's _.sent() can return or rethrow it.
case 0:
case 1:
t = op;
break;
// yield: advance past the yield point and hand the value to the caller.
case 4:
_.label++;
return {value: op[1], done: false};
// yield*: delegate subsequent steps to iterator op[1].
case 5:
_.label++;
y = op[1];
op = [0];
continue;
// end of finally: resume the op that was deferred when the finally began.
case 7:
op = _.ops.pop();
_.trys.pop();
continue;
// return (2), jump (3) and error (6): unwind through the enclosing try regions.
default:
if (!(t = _.trys, t = t.length > 0 && t[t.length - 1]) && (op[0] === 6 || op[0] === 2)) {
_ = 0;
continue;
}
if (op[0] === 3 && (!t || op[1] > t[0] && op[1] < t[3])) {
_.label = op[1];
break;
}
if (op[0] === 6 && _.label < t[1]) {
_.label = t[1];
t = op;
break;
}
if (t && _.label < t[2]) {
_.label = t[2];
_.ops.push(op);
break;
}
if (t[2])
_.ops.pop();
_.trys.pop();
continue;
}
op = body.call(thisArg, _);
} catch (e) {
op = [6, e];
y = 0;
} finally {
f = t = 0;
}
// Completed: rethrow a pending throw/error (opcode bits 1|4), else report done.
if (op[0] & 5)
throw op[1];
return {value: op[0] ? op[1] : void 0, done: true};
}
}
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
// Cache of WebGL rendering contexts keyed by WebGL version (1 or 2). Filled by
// setWebGLContext/getWebGLContext. Entries are removed when a context is lost.
var contexts = {};
// Context-creation attributes passed to canvas.getContext() in getWebGLRenderingContext.
// Alpha, antialiasing, depth/stencil buffers and drawing-buffer preservation are all
// off. Per the WebGL spec, failIfMajorPerformanceCaveat makes creation fail rather
// than fall back to a much slower (e.g. software) implementation.
var WEBGL_ATTRIBUTES = {
alpha: false,
antialias: false,
premultipliedAlpha: false,
preserveDrawingBuffer: false,
depth: false,
stencil: false,
failIfMajorPerformanceCaveat: true
};
/**
 * Stores (or replaces) the cached rendering context for a WebGL version.
 * @param {number} webGLVersion 1 or 2
 * @param {WebGLRenderingContext} gl context to cache
 */
function setWebGLContext(webGLVersion, gl) {
  const cache = contexts;
  cache[webGLVersion] = gl;
}
/**
 * Returns the cached WebGL context for `webGLVersion`, creating it on first use.
 * A lost context is evicted and replaced. Before returning, the context is reset:
 * depth/stencil/blend/dither/polygon-offset/sample-coverage are disabled, scissor
 * test and back-face culling are enabled.
 * @param {number} webGLVersion 1 or 2
 * @returns {WebGLRenderingContext|null} null if no context could be created.
 */
function getWebGLContext(webGLVersion) {
  if (!(webGLVersion in contexts)) {
    const created = getWebGLRenderingContext(webGLVersion);
    if (created === null) {
      console.log("Could not get context for WebGL version", webGLVersion);
      return null;
    }
    contexts[webGLVersion] = created;
  }
  const gl = contexts[webGLVersion];
  if (gl.isContextLost()) {
    // A lost context is unusable: evict it and build a fresh one.
    delete contexts[webGLVersion];
    return getWebGLContext(webGLVersion);
  }
  ["DEPTH_TEST", "STENCIL_TEST", "BLEND", "DITHER", "POLYGON_OFFSET_FILL", "SAMPLE_COVERAGE"].forEach((capability) => {
    gl.disable(gl[capability]);
  });
  ["SCISSOR_TEST", "CULL_FACE"].forEach((capability) => {
    gl.enable(gl[capability]);
  });
  gl.cullFace(gl.BACK);
  return contexts[webGLVersion];
}
/**
 * Creates a canvas to host a WebGL context. For WebGL 2 an OffscreenCanvas
 * (300x150) is preferred when available; otherwise a DOM <canvas> is used.
 * @param {number} webGLVersion
 * @throws {Error} when neither OffscreenCanvas (for v2) nor `document` exists.
 */
function createCanvas(webGLVersion) {
  const offscreenAvailable = typeof OffscreenCanvas !== "undefined";
  if (offscreenAvailable && webGLVersion === 2) {
    return new OffscreenCanvas(300, 150);
  }
  if (typeof document === "undefined") {
    throw new Error("Cannot create a canvas in this context");
  }
  return document.createElement("canvas");
}
/**
 * Creates a fresh canvas and asks it for a WebGL context of the given version.
 * A "webglcontextlost" listener is attached that calls preventDefault() and
 * evicts the version's entry from `contexts`.
 * @param {number} webGLVersion 1 or 2
 * @returns {WebGLRenderingContext|null} whatever canvas.getContext returned.
 * @throws {Error} for any version other than 1 or 2.
 */
function getWebGLRenderingContext(webGLVersion) {
  const supported = webGLVersion === 1 || webGLVersion === 2;
  if (!supported) {
    throw new Error("Cannot get WebGL rendering context, WebGL is disabled.");
  }
  const canvas = createCanvas(webGLVersion);
  const onContextLost = (ev) => {
    ev.preventDefault();
    delete contexts[webGLVersion];
  };
  canvas.addEventListener("webglcontextlost", onContextLost, false);
  if (webGLVersion === 2) {
    return canvas.getContext("webgl2", WEBGL_ATTRIBUTES);
  }
  // WebGL 1: fall back to the legacy "experimental-webgl" name.
  return canvas.getContext("webgl", WEBGL_ATTRIBUTES) || canvas.getContext("experimental-webgl", WEBGL_ATTRIBUTES);
}
/**
* @license
* Copyright 2017 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
// Texture packing scheme, as a TypeScript-style two-way enum: name -> ordinal
// and ordinal -> name.
var PackingScheme;
PackingScheme = PackingScheme || {};
["DENSE", "SHARED_BATCH"].forEach(function(label, ordinal) {
  PackingScheme[label] = ordinal;
  PackingScheme[ordinal] = label;
});
// Texture usage kinds, as a TypeScript-style two-way enum (name <-> ordinal).
var TextureUsage;
TextureUsage = TextureUsage || {};
["RENDER", "UPLOAD", "PIXELS", "DOWNLOAD"].forEach(function(label, ordinal) {
  TextureUsage[label] = ordinal;
  TextureUsage[ordinal] = label;
});
// Physical texture layouts, as a TypeScript-style two-way enum (name <-> ordinal).
var PhysicalTextureType;
PhysicalTextureType = PhysicalTextureType || {};
[
  "UNPACKED_FLOAT16",
  "UNPACKED_FLOAT32",
  "PACKED_4X1_UNSIGNED_BYTE",
  "PACKED_2X2_FLOAT32",
  "PACKED_2X2_FLOAT16",
].forEach(function(label, ordinal) {
  PhysicalTextureType[label] = ordinal;
  PhysicalTextureType[ordinal] = label;
});
/**
 * Texture [width, height] for an unpacked rows x columns matrix: columns map
 * to width and rows to height.
 */
function getUnpackedMatrixTextureShapeWidthHeight(rows, columns) {
  const width = columns;
  const height = rows;
  return [width, height];
}
/** Flat array length needed to hold `matrixSize` texels of `channelsPerTexture` channels each. */
function getUnpackedArraySizeFromMatrixSize(matrixSize, channelsPerTexture) {
  const arraySize = matrixSize * channelsPerTexture;
  return arraySize;
}
/**
 * Squarish texture shape for densely packing a tensor of `shape`, with 4
 * values per texel (element count rounded up to whole texels).
 */
function getDenseTexShape(shape) {
  const texelCount = Math.ceil(tf.util.sizeFromShape(shape) / 4);
  return tf.util.sizeToSquarishShape(texelCount);
}
/**
 * Texture [width, height] for a rows x columns matrix packed in 2x2 blocks.
 * Each side is halved, rounded up, and clamped to at least 1.
 */
function getPackedMatrixTextureShapeWidthHeight(rows, columns) {
  const halfRoundedUp = (extent) => Math.max(1, Math.ceil(extent / 2));
  return [halfRoundedUp(columns), halfRoundedUp(rows)];
}
/** Length of the RGBA (4 floats per texel) array backing a packed rows x columns matrix. */
function getPackedRGBAArraySizeFromMatrixShape(rows, columns) {
  const [w, h] = getPackedMatrixTextureShapeWidthHeight(rows, columns);
  return w * h * 4;
}
/**
 * Texture format/type constants for the current WebGL version.
 * WebGL 2 uses sized single-channel formats (R32F/R16F, RED) for unpacked data.
 * WebGL 1 uses RGBA for everything and takes the half-float type from the
 * optional extension object.
 * @param {WebGLRenderingContext} gl
 * @param {?Object} textureHalfFloatExtension extension exposing HALF_FLOAT_OES, or null
 * @returns {Object} texture configuration record
 */
function getTextureConfig(gl, textureHalfFloatExtension) {
  if (tf.env().getNumber("WEBGL_VERSION") === 2) {
    return {
      internalFormatFloat: gl.R32F,
      internalFormatHalfFloat: gl.R16F,
      internalFormatPackedHalfFloat: gl.RGBA16F,
      internalFormatPackedFloat: gl.RGBA32F,
      textureFormatFloat: gl.RED,
      downloadTextureFormat: gl.RGBA,
      downloadUnpackNumChannels: 4,
      defaultNumChannels: 1,
      textureTypeHalfFloat: gl.HALF_FLOAT,
      textureTypeFloat: gl.FLOAT,
    };
  }
  const halfFloatType = textureHalfFloatExtension != null ? textureHalfFloatExtension.HALF_FLOAT_OES : null;
  return {
    internalFormatFloat: gl.RGBA,
    internalFormatHalfFloat: gl.RGBA,
    internalFormatPackedHalfFloat: gl.RGBA,
    internalFormatPackedFloat: gl.RGBA,
    textureFormatFloat: gl.RGBA,
    downloadTextureFormat: gl.RGBA,
    downloadUnpackNumChannels: 4,
    defaultNumChannels: 4,
    textureTypeHalfFloat: halfFloatType,
    textureTypeFloat: gl.FLOAT,
  };
}
/**
* @license
* Copyright 2017 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Invokes `func` and returns its result. When the DEBUG env flag is on, it
 * then checks the GL error state, which throws on any error.
 */
function callAndCheck(gl, func) {
  const result = func();
  const debugging = tf.env().getBool("DEBUG");
  if (debugging) {
    checkWebGLError(gl);
  }
  return result;
}
/**
 * Reads gl.getError() once and throws a descriptive Error unless it is NO_ERROR.
 * @throws {Error} "WebGL Error: <name>"
 */
function checkWebGLError(gl) {
  const code = gl.getError();
  if (code === gl.NO_ERROR) {
    return;
  }
  throw new Error("WebGL Error: " + getWebGLErrorMessage(gl, code));
}
// Magnitude bounds used for the float16 check (about 2^-24 and the float16 max).
var MIN_FLOAT16 = 596e-10;
var MAX_FLOAT16 = 65504;
/**
 * Whether `num` can be stored in a render texture. Always true when
 * float32 rendering is enabled. Otherwise 0 is accepted, and any other value
 * must satisfy MIN_FLOAT16 < |num| < MAX_FLOAT16 (both bounds exclusive).
 */
function canBeRepresented(num) {
  if (tf.env().getBool("WEBGL_RENDER_FLOAT32_ENABLED")) {
    return true;
  }
  if (num === 0) {
    return true;
  }
  const magnitude = Math.abs(num);
  return MIN_FLOAT16 < magnitude && magnitude < MAX_FLOAT16;
}
/**
 * Maps a gl.getError() code to its constant name, trying names in a fixed
 * order with strict equality. Unknown codes give "Unknown error code <n>".
 */
function getWebGLErrorMessage(gl, status) {
  const errorNames = [
    "NO_ERROR",
    "INVALID_ENUM",
    "INVALID_VALUE",
    "INVALID_OPERATION",
    "INVALID_FRAMEBUFFER_OPERATION",
    "OUT_OF_MEMORY",
    "CONTEXT_LOST_WEBGL",
  ];
  const match = errorNames.find((name) => gl[name] === status);
  return match !== undefined ? match : "Unknown error code " + status;
}
/**
 * Returns the named WebGL extension.
 * @throws {Error} if gl.getExtension returns null/undefined.
 */
function getExtensionOrThrow(gl, extensionName) {
  const failureMessage = 'Extension "' + extensionName + '" not supported on this browser.';
  return throwIfNull(gl, () => gl.getExtension(extensionName), failureMessage);
}
/**
 * Creates and compiles a vertex shader from source.
 * @throws {Error} if the shader cannot be created, or if compilation fails
 *   (the info log is printed to the console first).
 */
function createVertexShader(gl, vertexShaderSource) {
  const shader = throwIfNull(gl, () => gl.createShader(gl.VERTEX_SHADER), "Unable to create vertex WebGLShader.");
  callAndCheck(gl, () => gl.shaderSource(shader, vertexShaderSource));
  callAndCheck(gl, () => gl.compileShader(shader));
  const compiled = gl.getShaderParameter(shader, gl.COMPILE_STATUS);
  if (compiled !== false) {
    return shader;
  }
  console.log(gl.getShaderInfoLog(shader));
  throw new Error("Failed to compile vertex shader.");
}
/**
 * Creates and compiles a fragment shader from source.
 * @throws {Error} if the shader cannot be created, or if compilation fails
 *   (the source is logged with the failing line highlighted first).
 */
function createFragmentShader(gl, fragmentShaderSource) {
  const shader = throwIfNull(gl, () => gl.createShader(gl.FRAGMENT_SHADER), "Unable to create fragment WebGLShader.");
  callAndCheck(gl, () => gl.shaderSource(shader, fragmentShaderSource));
  callAndCheck(gl, () => gl.compileShader(shader));
  const compiled = gl.getShaderParameter(shader, gl.COMPILE_STATUS);
  if (compiled !== false) {
    return shader;
  }
  logShaderSourceAndInfoLog(fragmentShaderSource, gl.getShaderInfoLog(shader));
  throw new Error("Failed to compile fragment shader.");
}
// Extracts the line number from a GLSL compiler message such as "ERROR: 0:12: ...".
var lineNumberRegex = /ERROR: [0-9]+:([0-9]+):/g;
/**
 * Logs a shader's source with line numbers, the first line of the compiler
 * info log, and the failing line highlighted (via a "%c" console style).
 * If no line number can be parsed, the raw info log and source are logged instead.
 * @param {string} shaderSource GLSL source that failed to compile
 * @param {string} shaderInfoLog compiler output from gl.getShaderInfoLog
 */
function logShaderSourceAndInfoLog(shaderSource, shaderInfoLog) {
  // The regex is module-level and global (/g), so exec() resumes from
  // lastIndex. Without this reset, a match in one call made the next call
  // start mid-string and fail to find the line number.
  lineNumberRegex.lastIndex = 0;
  const lineNumberRegexResult = lineNumberRegex.exec(shaderInfoLog);
  if (lineNumberRegexResult == null) {
    console.log("Couldn't parse line number in error: " + shaderInfoLog);
    console.log(shaderSource);
    return;
  }
  const lineNumber = +lineNumberRegexResult[1];
  const shaderLines = shaderSource.split("\n");
  // Width of the line-number gutter: digits of the largest number plus 2 spaces.
  const pad = shaderLines.length.toString().length + 2;
  const linesWithLineNumbers = shaderLines.map((line, index) => tf.util.rightPad((index + 1).toString(), pad) + line);
  const maxLineLength = linesWithLineNumbers.reduce((longest, line) => Math.max(longest, line.length), 0);
  const beforeErrorLines = linesWithLineNumbers.slice(0, lineNumber - 1);
  const errorLine = linesWithLineNumbers.slice(lineNumber - 1, lineNumber);
  const afterErrorLines = linesWithLineNumbers.slice(lineNumber);
  // A reported line past the end of the source leaves nothing to highlight.
  // Fall back to "" instead of passing undefined to rightPad.
  const errorLineText = errorLine.length > 0 ? errorLine[0] : "";
  console.log(beforeErrorLines.join("\n"));
  console.log(shaderInfoLog.split("\n")[0]);
  console.log("%c " + tf.util.rightPad(errorLineText, maxLineLength), "border:1px solid red; background-color:#e3d2d2; color:#a61717");
  console.log(afterErrorLines.join("\n"));
}
/**
 * Creates an empty WebGLProgram.
 * @throws {Error} if gl.createProgram returns null/undefined.
 */
function createProgram(gl) {
  const makeProgram = () => gl.createProgram();
  return throwIfNull(gl, makeProgram, "Unable to create WebGLProgram.");
}
/**
 * Links `program`. Throws only when LINK_STATUS is exactly `false`, after
 * printing the program info log.
 */
function linkProgram(gl, program) {
  callAndCheck(gl, () => gl.linkProgram(program));
  const linked = gl.getProgramParameter(program, gl.LINK_STATUS);
  if (linked !== false) {
    return;
  }
  console.log(gl.getProgramInfoLog(program));
  throw new Error("Failed to link vertex and fragment shaders.");
}
/**
 * Runs gl.validateProgram. Throws only when VALIDATE_STATUS is exactly
 * `false`, after printing the program info log.
 */
function validateProgram(gl, program) {
  callAndCheck(gl, () => gl.validateProgram(program));
  const valid = gl.getProgramParameter(program, gl.VALIDATE_STATUS);
  if (valid !== false) {
    return;
  }
  console.log(gl.getProgramInfoLog(program));
  throw new Error("Shader program validation failed.");
}
/**
 * Creates an ARRAY_BUFFER, binds it, uploads `data` with STATIC_DRAW, and
 * returns it (left bound).
 */
function createStaticVertexBuffer(gl, data) {
  const buffer = throwIfNull(gl, () => gl.createBuffer(), "Unable to create WebGLBuffer");
  [
    () => gl.bindBuffer(gl.ARRAY_BUFFER, buffer),
    () => gl.bufferData(gl.ARRAY_BUFFER, data, gl.STATIC_DRAW),
  ].forEach((op) => callAndCheck(gl, op));
  return buffer;
}
/**
 * Creates an ELEMENT_ARRAY_BUFFER, binds it, uploads `data` with STATIC_DRAW,
 * and returns it (left bound).
 */
function createStaticIndexBuffer(gl, data) {
  const buffer = throwIfNull(gl, () => gl.createBuffer(), "Unable to create WebGLBuffer");
  [
    () => gl.bindBuffer(gl.ELEMENT_ARRAY_BUFFER, buffer),
    () => gl.bufferData(gl.ELEMENT_ARRAY_BUFFER, data, gl.STATIC_DRAW),
  ].forEach((op) => callAndCheck(gl, op));
  return buffer;
}
/** Channels per texel for unpacked textures: 1 on WebGL 2 (RED), 4 otherwise (RGBA). */
function getNumChannels() {
  const isWebGL2 = tf.env().getNumber("WEBGL_VERSION") === 2;
  return isWebGL2 ? 1 : 4;
}
/**
 * Creates a WebGLTexture.
 * @throws {Error} if gl.createTexture returns null/undefined.
 */
function createTexture(gl) {
  const makeTexture = () => gl.createTexture();
  return throwIfNull(gl, makeTexture, "Unable to create WebGLTexture.");
}
/**
 * Checks a requested texture size against basic sanity rules and the
 * WEBGL_MAX_TEXTURE_SIZE environment flag.
 * @param {number} width
 * @param {number} height
 * @throws {Error} "... is invalid." when either side is not > 0 (including NaN
 *   and undefined); "... greater than WebGL maximum ..." when either side
 *   exceeds the maximum.
 */
function validateTextureSize(width, height) {
  const maxTextureSize = tf.env().getNumber("WEBGL_MAX_TEXTURE_SIZE");
  const describeRequested = () => "[" + width + "x" + height + "]";
  // Written as !(x > 0) rather than x <= 0: every comparison with NaN is false,
  // so the old `x <= 0` check silently accepted NaN/undefined sizes.
  if (!(width > 0) || !(height > 0)) {
    throw new Error("Requested texture size " + describeRequested() + " is invalid.");
  }
  if (width > maxTextureSize || height > maxTextureSize) {
    const max = "[" + maxTextureSize + "x" + maxTextureSize + "]";
    throw new Error("Requested texture size " + describeRequested() + " greater than WebGL maximum on this browser / GPU " + max + ".");
  }
}
/**
 * Creates a WebGLFramebuffer.
 * @throws {Error} if gl.createFramebuffer returns null/undefined.
 */
function createFramebuffer(gl) {
  const makeFramebuffer = () => gl.createFramebuffer();
  return throwIfNull(gl, makeFramebuffer, "Unable to create WebGLFramebuffer.");
}
/**
 * Points vertex attribute `attribute` of `program` at `buffer` (FLOAT
 * components, not normalized) and enables it.
 * @returns {boolean} false (and no GL calls) when getAttribLocation reports -1,
 *   i.e. the program has no such active attribute; true otherwise.
 */
function bindVertexBufferToProgramAttribute(gl, program, attribute, buffer, arrayEntriesPerItem, itemStrideInBytes, itemOffsetInBytes) {
  const location = gl.getAttribLocation(program, attribute);
  if (location === -1) {
    return false;
  }
  [
    () => gl.bindBuffer(gl.ARRAY_BUFFER, buffer),
    () => gl.vertexAttribPointer(location, arrayEntriesPerItem, gl.FLOAT, false, itemStrideInBytes, itemOffsetInBytes),
    () => gl.enableVertexAttribArray(location),
  ].forEach((op) => callAndCheck(gl, op));
  return true;
}
/** Validates `textureUnit`, makes it active, and binds `texture` to its TEXTURE_2D target. */
function bindTextureUnit(gl, texture, textureUnit) {
  validateTextureUnit(gl, textureUnit);
  [
    () => gl.activeTexture(gl.TEXTURE0 + textureUnit),
    () => gl.bindTexture(gl.TEXTURE_2D, texture),
  ].forEach((op) => callAndCheck(gl, op));
}
/** Validates `textureUnit`, makes it active, and clears its TEXTURE_2D binding. */
function unbindTextureUnit(gl, textureUnit) {
  validateTextureUnit(gl, textureUnit);
  [
    () => gl.activeTexture(gl.TEXTURE0 + textureUnit),
    () => gl.bindTexture(gl.TEXTURE_2D, null),
  ].forEach((op) => callAndCheck(gl, op));
}
/**
 * Looks up a uniform location in `program`.
 * @throws {Error} if the uniform is not present (location is null/undefined).
 */
function getProgramUniformLocationOrThrow(gl, program, uniformName) {
  const failureMessage = 'uniform "' + uniformName + '" not present in program.';
  return throwIfNull(gl, () => gl.getUniformLocation(program, uniformName), failureMessage);
}
/** Looks up a uniform location, passing through null when the uniform is absent. */
function getProgramUniformLocation(gl, program, uniformName) {
  const location = gl.getUniformLocation(program, uniformName);
  return location;
}
/**
 * Binds `texture` to unit `textureUnit` and sets the sampler uniform at
 * `uniformSamplerLocation` to that unit index.
 *
 * bindTextureUnit already wraps each of its own GL calls in callAndCheck;
 * the outer callAndCheck around it is kept exactly as before.
 */
function bindTextureToProgramUniformSampler(gl, texture, uniformSamplerLocation, textureUnit) {
  const bindUnit = () => bindTextureUnit(gl, texture, textureUnit);
  const pointSamplerAtUnit = () => gl.uniform1i(uniformSamplerLocation, textureUnit);
  callAndCheck(gl, bindUnit);
  callAndCheck(gl, pointSamplerAtUnit);
}
/**
 * Makes the default framebuffer (null) current and sets both viewport and
 * scissor to the full canvas size.
 */
function bindCanvasToFramebuffer(gl) {
  // Canvas dimensions are read inside each closure, at call time.
  const operations = [
    () => gl.bindFramebuffer(gl.FRAMEBUFFER, null),
    () => gl.viewport(0, 0, gl.canvas.width, gl.canvas.height),
    () => gl.scissor(0, 0, gl.canvas.width, gl.canvas.height),
  ];
  operations.forEach((operation) => callAndCheck(gl, operation));
}
/**
 * Binds `framebuffer` and attaches mip level 0 of `texture` (TEXTURE_2D) as its
 * COLOR_ATTACHMENT0.
 */
function bindColorTextureToFramebuffer(gl, texture, framebuffer) {
  const bindTarget = () => gl.bindFramebuffer(gl.FRAMEBUFFER, framebuffer);
  const attachTexture = () => gl.framebufferTexture2D(gl.FRAMEBUFFER, gl.COLOR_ATTACHMENT0, gl.TEXTURE_2D, texture, 0);
  callAndCheck(gl, bindTarget);
  callAndCheck(gl, attachTexture);
}
/**
 * Binds `framebuffer` and detaches whatever texture is on its COLOR_ATTACHMENT0
 * (by attaching null). The framebuffer stays bound afterwards.
 */
function unbindColorTextureFromFramebuffer(gl, framebuffer) {
  const bindTarget = () => gl.bindFramebuffer(gl.FRAMEBUFFER, framebuffer);
  const detachTexture = () => gl.framebufferTexture2D(gl.FRAMEBUFFER, gl.COLOR_ATTACHMENT0, gl.TEXTURE_2D, null, 0);
  callAndCheck(gl, bindTarget);
  callAndCheck(gl, detachTexture);
}
/**
 * Checks the currently bound FRAMEBUFFER for completeness.
 *
 * @throws {Error} "Error binding framebuffer: <reason>" when the status is not
 *   FRAMEBUFFER_COMPLETE; the reason comes from getFramebufferErrorMessage.
 */
function validateFramebuffer(gl) {
  const status = gl.checkFramebufferStatus(gl.FRAMEBUFFER);
  if (status === gl.FRAMEBUFFER_COMPLETE) {
    return;
  }
  throw new Error("Error binding framebuffer: " + getFramebufferErrorMessage(gl, status));
}
/**
 * Names a framebuffer status code using the context's constants.
 *
 * @param {WebGLRenderingContext} gl Source of the FRAMEBUFFER_* constants.
 * @param {number} status Value returned by checkFramebufferStatus.
 * @returns {string} The matching constant name, or "unknown error <status>".
 */
function getFramebufferErrorMessage(gl, status) {
  // Checked in this order with strict equality, mirroring a switch statement.
  const knownStatusNames = [
    "FRAMEBUFFER_INCOMPLETE_ATTACHMENT",
    "FRAMEBUFFER_INCOMPLETE_MISSING_ATTACHMENT",
    "FRAMEBUFFER_INCOMPLETE_DIMENSIONS",
    "FRAMEBUFFER_UNSUPPORTED",
  ];
  const matchedName = knownStatusNames.find((name) => gl[name] === status);
  return matchedName !== undefined ? matchedName : "unknown error " + status;
}
/**
 * Runs `returnTOrNull` through callAndCheck and returns its result.
 *
 * @throws {Error} with `failureMessage` when the result is null or undefined.
 */
function throwIfNull(gl, returnTOrNull, failureMessage) {
  const value = callAndCheck(gl, () => returnTOrNull());
  if (value != null) {
    return value;
  }
  throw new Error(failureMessage);
}
// Per-context cache of gl.getParameter(gl.MAX_COMBINED_TEXTURE_IMAGE_UNITS).
// getParameter can stall the calling thread, and this check runs on every
// texture bind, so the limit is queried once per context.
const maxCombinedTextureUnitsCache = new WeakMap();
/**
 * Validates that `textureUnit` (0 maps to gl.TEXTURE0) is within the context's
 * combined texture-unit limit.
 *
 * `gl.MAX_COMBINED_TEXTURE_IMAGE_UNITS` is the enum used to *query* the limit,
 * not the limit itself. The previous check compared
 * `gl.TEXTURE0 + textureUnit` against that enum value minus one, so the upper
 * bound was effectively never enforced. The limit is now read via getParameter.
 *
 * @param {WebGLRenderingContext} gl Context whose limit applies.
 * @param {number} textureUnit Zero-based texture unit index.
 * @throws {Error} "textureUnit must be in [gl.TEXTURE0, gl.TEXTURE<max>]." when
 *   the index is negative or above the limit. If the limit cannot be read (e.g.
 *   getParameter returns null on a lost context), only the lower bound is
 *   enforced, as before.
 */
function validateTextureUnit(gl, textureUnit) {
  let maxCombinedUnits = maxCombinedTextureUnitsCache.get(gl);
  if (maxCombinedUnits === undefined) {
    maxCombinedUnits = gl.getParameter(gl.MAX_COMBINED_TEXTURE_IMAGE_UNITS);
    // Only cache a real number so a transiently lost context is re-queried later.
    if (Number.isFinite(maxCombinedUnits)) {
      maxCombinedTextureUnitsCache.set(gl, maxCombinedUnits);
    }
  }
  const hasLimit = Number.isFinite(maxCombinedUnits);
  const maxTextureUnit = hasLimit ? maxCombinedUnits - 1 : null;
  if (textureUnit < 0 || (hasLimit && textureUnit > maxTextureUnit)) {
    const textureUnitRange = "[gl.TEXTURE0, gl.TEXTURE" + (hasLimit ? maxTextureUnit : "?") + "]";
    throw new Error("textureUnit must be in " + textureUnitRange + ".");
  }
}
/**
 * Product of the leading dimensions of `shape`, i.e. all but the last
 * `dimsToSkip` dims (2 when omitted/undefined).
 */
function getBatchDim(shape, dimsToSkip) {
  const skipped = dimsToSkip === undefined ? 2 : dimsToSkip;
  const batchShape = shape.slice(0, shape.length - skipped);
  return tf.util.sizeFromShape(batchShape);
}
/**
 * Returns [rows, cols] taken from the last two dims of `shape`; rows is 1 for
 * a rank-1 shape.
 *
 * @throws {Error} when `shape` is empty.
 */
function getRowsCols(shape) {
  const rank = shape.length;
  if (rank === 0) {
    throw new Error("Cannot get rows and columns of an empty shape array.");
  }
  const rows = rank > 1 ? shape[rank - 2] : 1;
  const cols = shape[rank - 1];
  return [rows, cols];
}
/**
 * Collapses `shape` into [batch, rows, cols]. The empty shape and [1] map to
 * [1, 1, 1]; anything else uses getBatchDim and getRowsCols.
 */
function getShapeAs3D(shape) {
  const rank = shape.length;
  const isScalarLike = rank === 0 || (rank === 1 && shape[0] === 1);
  if (isScalarLike) {
    return [1, 1, 1];
  }
  const batch = getBatchDim(shape);
  const [rows, cols] = getRowsCols(shape);
  return [batch, rows, cols];
}
/**
 * Chooses a 2D texture shape [rows, cols] able to hold a tensor of logical
 * shape `logShape`.
 *
 * The per-side limit is the WEBGL_MAX_TEXTURE_SIZE flag (doubled when packed).
 * After optional squeezing, dims are collapsed in the fixed order tried by the
 * if/else chain below. If no collapse fits, the total size is laid out with
 * tf.util.sizeToSquarishShape.
 *
 * @param {number[]} logShape Logical tensor shape.
 * @param {boolean} [isPacked=false] When true, the last two dims are rounded up
 *   to even values first. The squarish fallback is then computed in units of
 *   (rows/2)*(cols/2) and scaled back by 2 per side.
 * @returns {number[]} Two-element texture shape (presumably [rows, cols] in
 *   every branch, assuming sizeToSquarishShape returns a pair; confirm).
 */
function getTextureShapeFromLogicalShape(logShape, isPacked) {
  var _a;
  if (isPacked === void 0) {
    isPacked = false;
  }
  var maxTexSize = tf.env().getNumber("WEBGL_MAX_TEXTURE_SIZE");
  if (isPacked) {
    // Packed layouts allow twice the per-side limit in value units; the
    // fallback below also divides rows/cols by 2, presumably because each
    // texel holds a 2x2 block of values.
    maxTexSize = maxTexSize * 2;
    // Round the last two dims up to even so partially filled blocks are counted.
    logShape = logShape.map(function(d, i) {
      return i >= logShape.length - 2 ? tf.util.nearestLargerEven(logShape[i]) : logShape[i];
    });
    // A packed 1D shape is treated as 2 rows tall.
    if (logShape.length === 1) {
      logShape = [2, logShape[0]];
    }
  }
  // Shapes that are already 2D are kept as-is; others go through tf.util.squeezeShape.
  if (logShape.length !== 2) {
    var squeezeResult = tf.util.squeezeShape(logShape);
    logShape = squeezeResult.newShape;
  }
  var size = tf.util.sizeFromShape(logShape);
  if (logShape.length <= 1 && size <= maxTexSize) {
    // Rank 0/1: a single row.
    return [1, size];
  } else if (logShape.length === 2 && logShape[0] <= maxTexSize && logShape[1] <= maxTexSize) {
    return logShape;
  } else if (logShape.length === 3 && logShape[0] * logShape[1] <= maxTexSize && logShape[2] <= maxTexSize) {
    // Rank 3: fold the first two dims into rows.
    return [logShape[0] * logShape[1], logShape[2]];
  } else if (logShape.length === 3 && logShape[0] <= maxTexSize && logShape[1] * logShape[2] <= maxTexSize) {
    // Rank 3: fold the last two dims into columns.
    return [logShape[0], logShape[1] * logShape[2]];
  } else if (logShape.length === 4 && logShape[0] * logShape[1] * logShape[2] <= maxTexSize && logShape[3] <= maxTexSize) {
    // Rank 4: fold the first three dims into rows.
    return [logShape[0] * logShape[1] * logShape[2], logShape[3]];
  } else if (logShape.length === 4 && logShape[0] <= maxTexSize && logShape[1] * logShape[2] * logShape[3] <= maxTexSize) {
    // Rank 4: fold the last three dims into columns.
    return [logShape[0], logShape[1] * logShape[2] * logShape[3]];
  } else {
    if (isPacked) {
      // Squarify in block units (batch * rows/2 * cols/2), then scale each
      // side back by 2 so both returned dims are even.
      var batchDim = getBatchDim(logShape);
      var rows = 2, cols = 2;
      if (logShape.length) {
        _a = getRowsCols(logShape), rows = _a[0], cols = _a[1];
      }
      size = batchDim * (rows / 2) * (cols / 2);
      return tf.util.sizeToSquarishShape(size).map(function(d) {
        return d * 2;
      });
    }
    return tf.util.sizeToSquarishShape(size);
  }
}
/** True when `n` leaves no remainder modulo 2 (negative evens included; non-integers are never even). */
function isEven(n) {
  const remainder = n % 2;
  return remainder === 0;
}
/**
 * Decides from the last two dims of each shape whether going from `shape1` to
 * `shape2` counts as "reshape free" (presumably: the existing packed texture
 * layout can be reused; confirm at call sites).
 *
 * True when any of the following holds:
 *   - the trailing dims are equal;
 *   - either side is a scalar;
 *   - any trailing dim is 0;
 *   - one side is a vector and the column counts match, or both column counts
 *     are even with a leading trailing-dim of 1 on either side;
 *   - the column counts match and both row counts are even.
 */
function isReshapeFree(shape1, shape2) {
  const inner1 = shape1.slice(-2);
  const inner2 = shape2.slice(-2);
  if (tf.util.arraysEqual(inner1, inner2)) {
    return true;
  }
  if (inner1.length === 0 || inner2.length === 0) {
    return true;
  }
  const hasZeroDim = inner1[0] === 0 || inner1[1] === 0 || inner2[0] === 0 || inner2[1] === 0;
  if (hasZeroDim) {
    return true;
  }
  if (inner1.length !== inner2.length) {
    // Lengths are 1 and 2 here: one side is a vector.
    const cols1 = inner1[inner1.length - 1];
    const cols2 = inner2[inner2.length - 1];
    if (cols1 === cols2) {
      return true;
    }
    const bothColsEven = isEven(cols1) && isEven(cols2);
    if (bothColsEven && (inner1[0] === 1 || inner2[0] === 1)) {
      return true;
    }
  }
  return inner1[1] === inner2[1] && isEven(inner1[0]) && isEven(inner2[0]);
}
// Lazily filled caches of GL limits shared by the accessors below. They stay
// `var` so the bindings exist (as undefined) regardless of when they are first read.
var MAX_TEXTURE_SIZE;
var MAX_TEXTURES_IN_SHADER;
/** gl.MAX_TEXTURE_SIZE for a context of `webGLVersion`; queried on first use, then cached. */
function getWebGLMaxTextureSize(webGLVersion) {
  if (MAX_TEXTURE_SIZE != null) {
    return MAX_TEXTURE_SIZE;
  }
  const gl = getWebGLContext(webGLVersion);
  MAX_TEXTURE_SIZE = gl.getParameter(gl.MAX_TEXTURE_SIZE);
  return MAX_TEXTURE_SIZE;
}
/** Clears the cached max texture size so the next read re-queries it. */
function resetMaxTextureSize() {
  MAX_TEXTURE_SIZE = null;
}
/** Clears the cached texture-unit count so the next read re-queries it. */
function resetMaxTexturesInShader() {
  MAX_TEXTURES_IN_SHADER = null;
}
/**
 * gl.MAX_TEXTURE_IMAGE_UNITS for a context of `webGLVersion` (queried once,
 * then cached), capped at 16.
 */
function getMaxTexturesInShader(webGLVersion) {
  const shaderTextureCap = 16;
  if (MAX_TEXTURES_IN_SHADER == null) {
    const gl = getWebGLContext(webGLVersion);
    MAX_TEXTURES_IN_SHADER = gl.getParameter(gl.MAX_TEXTURE_IMAGE_UNITS);
  }
  return Math.min(shaderTextureCap, MAX_TEXTURES_IN_SHADER);
}
/**
 * Which disjoint timer query extension is usable.
 *
 * @returns {number} 2 for EXT_disjoint_timer_query_webgl2 on WebGL 2,
 *   1 for EXT_disjoint_timer_query, 0 otherwise (always 0 for version 0).
 */
function getWebGLDisjointQueryTimerVersion(webGLVersion) {
  if (webGLVersion === 0) {
    return 0;
  }
  const gl = getWebGLContext(webGLVersion);
  // The webgl2 extension is queried before the version test; the query order is kept as-is.
  if (hasExtension(gl, "EXT_disjoint_timer_query_webgl2") && webGLVersion === 2) {
    return 2;
  }
  return hasExtension(gl, "EXT_disjoint_timer_query") ? 1 : 0;
}
/** True when gl.getExtension(extensionName) returns something other than null/undefined. */
function hasExtension(gl, extensionName) {
  return gl.getExtension(extensionName) != null;
}
/**
 * True when getWebGLContext(webGLVersion) returns a non-null context. An error
 * thrown while getting the context is logged to the console and reported as false.
 */
function isWebGLVersionEnabled(webGLVersion) {
  let gl;
  try {
    gl = getWebGLContext(webGLVersion);
  } catch (e) {
    console.log("Error when getting WebGL context: ", e);
    return false;
  }
  return gl != null;
}
/**
 * Whether a float texture can be attached as a complete framebuffer.
 * Requires OES_texture_float on WebGL 1 or EXT_color_buffer_float otherwise;
 * always false for version 0.
 */
function isCapableOfRenderingToFloatTexture(webGLVersion) {
  if (webGLVersion === 0) {
    return false;
  }
  const gl = getWebGLContext(webGLVersion);
  const requiredExtension = webGLVersion === 1 ? "OES_texture_float" : "EXT_color_buffer_float";
  if (!hasExtension(gl, requiredExtension)) {
    return false;
  }
  return createFloatTextureAndBindToFramebuffer(gl);
}
/**
 * Whether float data can be read back through a float (or, for versions other
 * than 1, half-float) color buffer. Always false for version 0.
 */
function isDownloadFloatTextureEnabled(webGLVersion) {
  if (webGLVersion === 0) {
    return false;
  }
  const gl = getWebGLContext(webGLVersion);
  if (webGLVersion !== 1) {
    // Prefer full-float color buffers; fall back to half-float ones.
    if (hasExtension(gl, "EXT_color_buffer_float")) {
      return createFloatTextureAndBindToFramebuffer(gl);
    }
    const halfFloatName = "EXT_color_buffer_half_float";
    if (!hasExtension(gl, halfFloatName)) {
      return false;
    }
    return createHalfFloatTextureAndBindToFramebuffer(gl, gl.getExtension(halfFloatName));
  }
  // WebGL 1 needs both float textures and float color buffers.
  const hasFloatSupport = hasExtension(gl, "OES_texture_float") && hasExtension(gl, "WEBGL_color_buffer_float");
  if (!hasFloatSupport) {
    return false;
  }
  return createFloatTextureAndBindToFramebuffer(gl);
}
/**
 * Probes float render-target support: allocates a 1x1 texture using the float
 * formats from getTextureConfig(gl), attaches it to a fresh framebuffer, and
 * checks completeness. Bindings are reset to null and both objects are deleted
 * before returning.
 *
 * @returns {boolean} true when the framebuffer status is FRAMEBUFFER_COMPLETE.
 */
function createFloatTextureAndBindToFramebuffer(gl) {
  const config = getTextureConfig(gl);
  const probeTexture = gl.createTexture();
  gl.bindTexture(gl.TEXTURE_2D, probeTexture);
  const side = 1;
  gl.texImage2D(gl.TEXTURE_2D, 0, config.internalFormatFloat, side, side, 0, config.textureFormatFloat, config.textureTypeFloat, null);
  const probeFramebuffer = gl.createFramebuffer();
  gl.bindFramebuffer(gl.FRAMEBUFFER, probeFramebuffer);
  gl.framebufferTexture2D(gl.FRAMEBUFFER, gl.COLOR_ATTACHMENT0, gl.TEXTURE_2D, probeTexture, 0);
  const complete = gl.checkFramebufferStatus(gl.FRAMEBUFFER) === gl.FRAMEBUFFER_COMPLETE;
  // Unbind, then free the probe objects.
  gl.bindTexture(gl.TEXTURE_2D, null);
  gl.bindFramebuffer(gl.FRAMEBUFFER, null);
  gl.deleteTexture(probeTexture);
  gl.deleteFramebuffer(probeFramebuffer);
  return complete;
}
/**
 * Half-float counterpart of createFloatTextureAndBindToFramebuffer. Formats come
 * from getTextureConfig(gl, textureHalfFloatExtension): internalFormatHalfFloat,
 * textureFormatFloat and textureTypeHalfFloat.
 *
 * @returns {boolean} true when the 1x1 probe framebuffer is FRAMEBUFFER_COMPLETE.
 */
function createHalfFloatTextureAndBindToFramebuffer(gl, textureHalfFloatExtension) {
  const config = getTextureConfig(gl, textureHalfFloatExtension);
  const probeTexture = gl.createTexture();
  gl.bindTexture(gl.TEXTURE_2D, probeTexture);
  const side = 1;
  gl.texImage2D(gl.TEXTURE_2D, 0, config.internalFormatHalfFloat, side, side, 0, config.textureFormatFloat, config.textureTypeHalfFloat, null);
  const probeFramebuffer = gl.createFramebuffer();
  gl.bindFramebuffer(gl.FRAMEBUFFER, probeFramebuffer);
  gl.framebufferTexture2D(gl.FRAMEBUFFER, gl.COLOR_ATTACHMENT0, gl.TEXTURE_2D, probeTexture, 0);
  const complete = gl.checkFramebufferStatus(gl.FRAMEBUFFER) === gl.FRAMEBUFFER_COMPLETE;
  // Unbind, then free the probe objects.
  gl.bindTexture(gl.TEXTURE_2D, null);
  gl.bindFramebuffer(gl.FRAMEBUFFER, null);
  gl.deleteTexture(probeTexture);
  gl.deleteFramebuffer(probeFramebuffer);
  return complete;
}
/** True only for version 2 contexts that expose a non-null `fenceSync`. */
function isWebGLFenceEnabled(webGLVersion) {
  if (webGLVersion !== 2) {
    return false;
  }
  return getWebGLContext(webGLVersion).fenceSync != null;
}
/**
 * Asserts (via tf.util.assert) that none of the given tensors has dtype
 * "complex64". Accepts a single tensor or an array; null/undefined entries are skipped.
 */
function assertNotComplex(tensor, opName) {
  const tensors = Array.isArray(tensor) ? tensor : [tensor];
  for (const t of tensors) {
    if (t == null) {
      continue;
    }
    tf.util.assert(t.dtype !== "complex64", () => opName + " does not support complex64 tensors in the WebGL backend.");
  }
}
// Namespace object bundling this module's WebGL helper functions.
// `__proto__: null` in an object literal gives the object a null prototype, so
// only the helpers listed here are visible as properties (nothing is inherited
// from Object.prototype).
var webgl_util = {
  __proto__: null,
  callAndCheck,
  canBeRepresented,
  getWebGLErrorMessage,
  getExtensionOrThrow,
  createVertexShader,
  createFragmentShader,
  createProgram,
  linkProgram,
  validateProgram,
  createStaticVertexBuffer,
  createStaticIndexBuffer,
  getNumChannels,
  createTexture,
  validateTextureSize,
  createFramebuffer,
  bindVertexBufferToProgramAttribute,
  bindTextureUnit,
  unbindTextureUnit,
  getProgramUniformLocationOrThrow,
  getProgramUniformLocation,
  bindTextureToProgramUniformSampler,
  bindCanvasToFramebuffer,
  bindColorTextureToFramebuffer,
  unbindColorTextureFromFramebuffer,
  validateFramebuffer,
  getFramebufferErrorMessage,
  getBatchDim,
  getRowsCols,
  getShapeAs3D,
  getTextureShapeFromLogicalShape,
  isReshapeFree,
  getWebGLMaxTextureSize,
  resetMaxTextureSize,
  resetMaxTexturesInShader,
  getMaxTexturesInShader,
  getWebGLDisjointQueryTimerVersion,
  hasExtension,
  isWebGLVersionEnabled,
  isCapableOfRenderingToFloatTexture,
  isDownloadFloatTextureEnabled,
  isWebGLFenceEnabled,
  assertNotComplex
};
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
// WebGL backend environment flags. Each flag is registered on the shared tfjs
// environment with a function that computes its default value (presumably
// evaluated on first read and cached by tf.env(); confirm in tfjs-core).
var ENV = tf.env();
// True when WEBGL_VERSION is 1 or 2.
ENV.registerFlag("HAS_WEBGL", function() {
  return ENV.getNumber("WEBGL_VERSION") > 0;
});
// Highest version for which isWebGLVersionEnabled succeeds: 2, then 1, else 0.
ENV.registerFlag("WEBGL_VERSION", function() {
  if (isWebGLVersionEnabled(2)) {
    return 2;
  } else if (isWebGLVersionEnabled(1)) {
    return 1;
  }
  return 0;
});
// Off by default.
ENV.registerFlag("WEBGL_CHECK_NUMERICAL_PROBLEMS", function() {
  return false;
});
// Enabled only when WEBGL_VERSION is 2.
ENV.registerFlag("WEBGL_BUFFER_SUPPORTED", function() {
  return ENV.get("WEBGL_VERSION") === 2;
});
// On by default.
ENV.registerFlag("WEBGL_CPU_FORWARD", function() {
  return true;
});
// Off by default; when true, WEBGL_RENDER_FLOAT32_ENABLED below evaluates to false.
ENV.registerFlag("WEBGL_FORCE_F16_TEXTURES", function() {
  return false;
});
// Packing defaults to on whenever WebGL is available. The WEBGL_PACK_* flags,
// WEBGL_LAZILY_UNPACK and WEBGL_CONV_IM2COL default to WEBGL_PACK, except
// WEBGL_PACK_DEPTHWISECONV, which defaults to false.
ENV.registerFlag("WEBGL_PACK", function() {
  return ENV.getBool("HAS_WEBGL");
});
ENV.registerFlag("WEBGL_PACK_NORMALIZATION", function() {
  return ENV.getBool("WEBGL_PACK");
});
ENV.registerFlag("WEBGL_PACK_CLIP", function() {
  return ENV.getBool("WEBGL_PACK");
});
ENV.registerFlag("WEBGL_PACK_DEPTHWISECONV", function() {
  return false;
});
ENV.registerFlag("WEBGL_PACK_BINARY_OPERATIONS", function() {
  return ENV.getBool("WEBGL_PACK");
});
ENV.registerFlag("WEBGL_PACK_UNARY_OPERATIONS", function() {
  return ENV.getBool("WEBGL_PACK");
});
ENV.registerFlag("WEBGL_PACK_ARRAY_OPERATIONS", function() {
  return ENV.getBool("WEBGL_PACK");
});
ENV.registerFlag("WEBGL_PACK_IMAGE_OPERATIONS", function() {
  return ENV.getBool("WEBGL_PACK");
});
ENV.registerFlag("WEBGL_PACK_REDUCE", function() {
  return ENV.getBool("WEBGL_PACK");
});
ENV.registerFlag("WEBGL_LAZILY_UNPACK", function() {
  return ENV.getBool("WEBGL_PACK");
});
ENV.registerFlag("WEBGL_CONV_IM2COL", function() {
  return ENV.getBool("WEBGL_PACK");
});
// GL limits read from a context of the detected version via the cached helpers.
ENV.registerFlag("WEBGL_MAX_TEXTURE_SIZE", function() {
  return getWebGLMaxTextureSize(ENV.getNumber("WEBGL_VERSION"));
});
ENV.registerFlag("WEBGL_MAX_TEXTURES_IN_SHADER", function() {
  return getMaxTexturesInShader(ENV.getNumber("WEBGL_VERSION"));
});
// 0 without WebGL; otherwise whatever getWebGLDisjointQueryTimerVersion reports (2, 1 or 0).
ENV.registerFlag("WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION", function() {
  var webGLVersion = ENV.getNumber("WEBGL_VERSION");
  if (webGLVersion === 0) {
    return 0;
  }
  return getWebGLDisjointQueryTimerVersion(webGLVersion);
});
// Timer queries count as reliable only if an extension exists and tf.device_util.isMobile() is false.
ENV.registerFlag("WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_RELIABLE", function() {
  return ENV.getNumber("WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION") > 0 && !tf.device_util.isMobile();
});
// Whether a float texture can be attached as a complete framebuffer.
ENV.registerFlag("WEBGL_RENDER_FLOAT32_CAPABLE", function() {
  return isCapableOfRenderingToFloatTexture(ENV.getNumber("WEBGL_VERSION"));
});
// The capability above, unless WEBGL_FORCE_F16_TEXTURES forces it off.
ENV.registerFlag("WEBGL_RENDER_FLOAT32_ENABLED", function() {
  return ENV.getBool("WEBGL_FORCE_F16_TEXTURES") ? false : ENV.getBool("WEBGL_RENDER_FLOAT32_CAPABLE");
});
ENV.registerFlag("WEBGL_DOWNLOAD_FLOAT_ENABLED", function() {
  return isDownloadFloatTextureEnabled(ENV.getNumber("WEBGL_VERSION"));
});
// True only for WebGL 2 contexts exposing fenceSync.
ENV.registerFlag("WEBGL_FENCE_API_ENABLED", function() {
  return isWebGLFenceEnabled(ENV.getNumber("WEBGL_VERSION"));
});
// 4 when float32 rendering is enabled, otherwise 0 (interpretation of the
// number is up to its consumers; not visible here).
ENV.registerFlag("WEBGL_SIZE_UPLOAD_UNIFORM", function() {
  var useUniforms = ENV.getBool("WEBGL_RENDER_FLOAT32_ENABLED");
  return useUniforms ? 4 : 0;
});
// Defaults to -1, which the error message describes as "never delete". The
// second callback rejects negative values other than -1 (presumably run when
// the flag is set).
ENV.registerFlag("WEBGL_DELETE_TEXTURE_THRESHOLD", function() {
  return -1;
}, function(threshold) {
  if (threshold < 0 && threshold !== -1) {
    throw new Error("WEBGL_DELETE_TEXTURE_THRESHOLD must be -1 (indicating never " + ("delete) or at least 0, but got " + threshold + "."));
  }
});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the License);
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an AS IS BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Element-wise absolute value.
 *
 * @param vals Array-like of numbers.
 * @returns {Float32Array} |vals[i]| for each index (always float32).
 */
function simpleAbsImpl(vals) {
  const count = vals.length;
  const absolute = new Float32Array(count);
  let index = 0;
  while (index < count) {
    absolute[index] = Math.abs(vals[index]);
    index += 1;
  }
  return absolute;
}
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Builds a CPU kernel applying the scalar `op` element-wise to two operands,
 * broadcasting them to the shape given by
 * tf.backend_util.assertAndGetBroadcastShape.
 *
 * @param {function(*, *): *} op Scalar operation on one value from each operand.
 * @returns {function} (aShape, bShape, aVals, bVals, dtype) => [resultValues, resultShape].
 */
function createSimpleBinaryKernelImpl(op) {
  return (aShape, bShape, aVals, bVals, dtype) => {
    const newShape = tf.backend_util.assertAndGetBroadcastShape(aShape, bShape);
    const resultRank = newShape.length;
    const resultStrides = tf.util.computeStrides(newShape);
    const resultSize = tf.util.sizeFromShape(newShape);
    const result = tf.util.getTypedArrayFromDType(dtype, resultSize);
    const aRank = aShape.length;
    const bRank = bShape.length;
    const aStrides = tf.util.computeStrides(aShape);
    const bStrides = tf.util.computeStrides(bShape);
    const aBroadcastDims = tf.backend_util.getBroadcastDims(aShape, newShape);
    const bBroadcastDims = tf.backend_util.getBroadcastDims(bShape, newShape);

    // Maps an output coordinate to a flat operand index: keep the operand's
    // trailing `rank` coordinates and pin its broadcast dims to 0.
    const toOperandIndex = (outLoc, rank, broadcastDims, strides) => {
      const operandLoc = outLoc.slice(-rank);
      for (const dim of broadcastDims) {
        operandLoc[dim] = 0;
      }
      return tf.util.locToIndex(operandLoc, rank, strides);
    };

    const needsBroadcast = aBroadcastDims.length > 0 || bBroadcastDims.length > 0;
    if (!needsBroadcast) {
      // No broadcast dims: read both operands by flat index, wrapping by length.
      for (let i = 0; i < result.length; ++i) {
        result[i] = op(aVals[i % aVals.length], bVals[i % bVals.length]);
      }
      return [result, newShape];
    }
    for (let i = 0; i < result.length; ++i) {
      const outLoc = tf.util.indexToLoc(i, resultRank, resultStrides);
      const aIndex = toOperandIndex(outLoc, aRank, aBroadcastDims, aStrides);
      const bIndex = toOperandIndex(outLoc, bRank, bBroadcastDims, bStrides);
      result[i] = op(aVals[aIndex], bVals[bIndex]);
    }
    return [result, newShape];
  };
}
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/** Element-wise a + b with broadcasting. */
const addImpl = createSimpleBinaryKernelImpl(function add(aValue, bValue) {
  return aValue + bValue;
});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Builds a CPU kernel that maps `op` over every value.
 *
 * @param {function(*, *): *} op Called as op(value, attrs) for each element.
 * @returns {function} (values, dtype, attrs) => output array allocated with
 *   tf.util.getTypedArrayFromDType(dtype, values.length).
 */
function createSimpleUnaryImpl(op) {
  return (values, dtype, attrs) => {
    const count = values.length;
    const output = tf.util.getTypedArrayFromDType(dtype, count);
    for (let idx = 0; idx < count; idx++) {
      output[idx] = op(values[idx], attrs);
    }
    return output;
  };
}
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the License);
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an AS IS BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
// Element-wise ceiling; Math.ceil ignores the extra `attrs` argument the unary helper passes.
const ceilImpl = createSimpleUnaryImpl(Math.ceil);
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the License);
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an AS IS BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
// Element-wise e^x; Math.exp ignores the extra `attrs` argument the unary helper passes.
const expImpl = createSimpleUnaryImpl(Math.exp);
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the License);
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an AS IS BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
// Element-wise e^x - 1; Math.expm1 ignores the extra `attrs` argument the unary helper passes.
const expm1Impl = createSimpleUnaryImpl(Math.expm1);
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the License);
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an AS IS BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
// Element-wise floor; Math.floor ignores the extra `attrs` argument the unary helper passes.
const floorImpl = createSimpleUnaryImpl(Math.floor);
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the License);
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an AS IS BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
// Element-wise natural log; Math.log ignores the extra `attrs` argument the unary helper passes.
const logImpl = createSimpleUnaryImpl(Math.log);
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Max-reduces consecutive runs of `reduceSize` input values.
 *
 * @param aVals Flat input, one run of `reduceSize` values per output element.
 * @param {number} reduceSize Length of each run.
 * @param {number[]} outShape Output shape; its size sets the number of runs.
 * @param {string} dtype Output dtype for tf.util.getTypedArrayFromDType.
 * @returns Output array where element i is the max of run i. The first element
 *   of each run seeds the max and only strictly greater values replace it.
 */
function maxImpl(aVals, reduceSize, outShape, dtype) {
  const outSize = tf.util.sizeFromShape(outShape);
  const vals = tf.util.getTypedArrayFromDType(dtype, outSize);
  for (let outIndex = 0; outIndex < vals.length; ++outIndex) {
    const start = outIndex * reduceSize;
    let best = aVals[start];
    // aVals[start] already seeds `best`, so the scan begins at the next element.
    for (let k = 1; k < reduceSize; ++k) {
      const candidate = aVals[start + k];
      if (candidate > best) {
        best = candidate;
      }
    }
    vals[outIndex] = best;
  }
  return vals;
}
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/** Element-wise a * b with broadcasting. */
const multiplyImpl = createSimpleBinaryKernelImpl(function multiply(aValue, bValue) {
  return aValue * bValue;
});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the License);
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an AS IS BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/** Element-wise reciprocal square root, 1 / sqrt(x). */
const rsqrtImpl = createSimpleUnaryImpl(function rsqrt(value) {
  return 1 / Math.sqrt(value);
});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * CPU slice over a flat value buffer.
 *
 * @param vals Flat input values: a typed array, or possibly a plain Array for
 *   string dtype (NOTE(review): confirm which callers pass plain arrays).
 * @param {number[]} begin Start coordinate per dimension.
 * @param {number[]} size Extent of the slice per dimension.
 * @param {number[]} shape Shape of the input.
 * @param {string} dtype Output dtype.
 * @returns A view (typed arrays) or a copy (plain arrays) when the slice is
 *   contiguous per tf.slice_util.isSliceContinous; otherwise a new buffer.
 */
function sliceImpl(vals, begin, size, shape, dtype) {
  const isContinous = tf.slice_util.isSliceContinous(shape, begin, size);
  const length = tf.util.sizeFromShape(size);
  const xStrides = tf.util.computeStrides(shape);
  if (isContinous) {
    const flatOffset = tf.slice_util.computeFlatOffset(begin, xStrides);
    // Only typed arrays have `subarray` (a zero-copy view); plain arrays would
    // throw a TypeError there, so they are copied with `slice` instead.
    return typeof vals.subarray === "function" ? vals.subarray(flatOffset, flatOffset + length) : vals.slice(flatOffset, flatOffset + length);
  }
  // A typed array cannot hold string values, so string output uses a plain Array.
  const outVals = dtype === "string" ? new Array(length) : tf.util.getTypedArrayFromDType(dtype, length);
  // Output rank and strides do not depend on the element index: compute them
  // once instead of once per element.
  const rank = size.length;
  const strides = tf.util.computeStrides(size);
  for (let i = 0; i < length; ++i) {
    const loc = tf.util.indexToLoc(i, rank, strides);
    const xLoc = loc.map((idx, j) => idx + begin[j]);
    const xIndex = tf.util.locToIndex(xLoc, shape.length, xStrides);
    outVals[i] = vals[xIndex];
  }
  return outVals;
}
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/** Element-wise a - b with broadcasting. */
const subImpl = createSimpleBinaryKernelImpl(function subtract(aValue, bValue) {
  return aValue - bValue;
});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * CPU transpose: element at coordinate `loc` of the input moves to coordinate
 * `[loc[perm[0]], loc[perm[1]], ...]` of an output shaped `newShape`.
 *
 * @returns Output array allocated with tf.util.getTypedArrayFromDType(dtype, size(newShape)).
 */
function transposeImpl(xVals, xShape, dtype, perm, newShape) {
  const xRank = xShape.length;
  const xSize = tf.util.sizeFromShape(xShape);
  const xStrides = tf.util.computeStrides(xShape);
  const newStrides = tf.util.computeStrides(newShape);
  const result = tf.util.getTypedArrayFromDType(dtype, tf.util.sizeFromShape(newShape));
  for (let flatIndex = 0; flatIndex < xSize; ++flatIndex) {
    const srcLoc = tf.util.indexToLoc(flatIndex, xRank, xStrides);
    const dstLoc = srcLoc.map((_, axis) => srcLoc[perm[axis]]);
    result[tf.util.locToIndex(dstLoc, xRank, newStrides)] = xVals[flatIndex];
  }
  return result;
}
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * CPU `unique` along one axis.
 *
 * The input is viewed as [outer, axisLength, inner]. Each slice at position i
 * along the axis is turned into a string key; the first occurrence of each key
 * is kept, in order of appearance.
 *
 * @param values Flat input values.
 * @param axis Axis to deduplicate along (normalized by tf.util.parseAxisParam).
 * @param {number[]} shape Input shape.
 * @param {string} dtype Dtype of the input and output buffers.
 * @returns {{outputValues, outputShape: number[], indices: Int32Array}} The
 *   unique slices, their shape (axis length = number of unique slices), and
 *   for each input position along the axis the index of its unique slice.
 */
function uniqueImpl(values, axis, shape, dtype) {
  const $axis = tf.util.parseAxisParam(axis, shape)[0];
  const newShape = [1, shape[$axis], 1];
  for (let i = 0; i < $axis; i++) {
    newShape[0] *= shape[i];
  }
  for (let i = $axis + 1; i < shape.length; i++) {
    newShape[2] *= shape[i];
  }
  // A Map instead of a plain object has two benefits. `.size` replaces the
  // per-new-element Object.keys(...).length call, which made this loop
  // quadratic in the number of unique slices. Keys such as "constructor"
  // cannot collide with Object.prototype members.
  const uniqueElements = new Map();
  const indices = new Int32Array(shape[$axis]);
  const inputBuffer = new tf.TensorBuffer(newShape, dtype, values);
  const uniqueIndices = [];
  const is1DTensor = newShape[0] === 1 && newShape[2] === 1;
  for (let i = 0; i < shape[$axis]; i++) {
    let element;
    if (is1DTensor) {
      element = values[i].toString();
    } else {
      const axisValues = [];
      for (let m = 0; m < newShape[0]; m++) {
        for (let n = 0; n < newShape[2]; n++) {
          axisValues.push(inputBuffer.get(m, i, n));
        }
      }
      element = axisValues.join(",");
    }
    const existingIndex = uniqueElements.get(element);
    if (existingIndex !== undefined) {
      indices[i] = existingIndex;
    } else {
      const uniqueIndex = uniqueElements.size;
      uniqueElements.set(element, uniqueIndex);
      indices[i] = uniqueIndex;
      uniqueIndices.push(i);
    }
  }
  const outputTmpShape = newShape.slice();
  outputTmpShape[1] = uniqueElements.size;
  const outputBuffer = new tf.TensorBuffer(outputTmpShape, dtype);
  uniqueIndices.forEach((uniqueElementIndex, i) => {
    for (let m = 0; m < newShape[0]; m++) {
      for (let n = 0; n < newShape[2]; n++) {
        outputBuffer.set(inputBuffer.get(m, uniqueElementIndex, n), m, i, n);
      }
    }
  });
  const outputShape = shape.slice();
  outputShape[$axis] = outputTmpShape[1];
  return {
    outputValues: outputBuffer.values,
    outputShape,
    indices
  };
}
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
// Aliases that expose the shared kernel implementations (simpleAbsImpl, addImpl,
// ..., uniqueImpl — defined earlier in this bundle) under *ImplCPU names.
// NOTE(review): presumably consumed by the WebGL kernels when they compute on the
// CPU instead of the GPU; confirm at the call sites before renaming or removing.
var simpleAbsImplCPU = simpleAbsImpl;
var addImplCPU = addImpl;
var ceilImplCPU = ceilImpl;
var expImplCPU = expImpl;
var expm1ImplCPU = expm1Impl;
var floorImplCPU = floorImpl;
var logImplCPU = logImpl;
var maxImplCPU = maxImpl;
var multiplyImplCPU = multiplyImpl;
var rsqrtImplCPU = rsqrtImpl;
var sliceImplCPU = sliceImpl;
var subImplCPU = subImpl;
var transposeImplCPU = transposeImpl;
var uniqueImplCPU = uniqueImpl;
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var AddNProgram = function() {
  /**
   * Unpacked WebGL program that adds N inputs element-wise.
   *
   * One shader input is declared per entry of `shapes`, named T0..T(N-1).
   * Only the number and positions of the entries matter; their contents are
   * not inspected.
   *
   * @param {number[]} outputShape - logical shape of the result.
   * @param {Array} shapes - one element per summand.
   */
  function AddNProgram2(outputShape, shapes) {
    const names = shapes.map((_, position) => `T${position}`);
    // reduce (like forEach) visits only present elements.
    const reads = names.reduce((lines, name) => {
      lines.push(`float v${name} = get${name}AtOutCoords();`);
      return lines;
    }, []);
    const sum = names.map((name) => `v${name}`).join(" + ");
    this.outputShape = outputShape;
    this.variableNames = names;
    this.userCode = "\n void main() {\n " + reads.join("\n ") + "\n\n float result = " + sum + ";\n setOutput(result);\n }\n ";
  }
  return AddNProgram2;
}();
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var AddNPackedProgram = function() {
  /**
   * Packed (vec4 per texel) WebGL program that adds N inputs element-wise.
   *
   * Both inputs and output use packed textures. One shader input is declared
   * per entry of `shapes`, named T0..T(N-1).
   *
   * @param {number[]} outputShape - logical shape of the result.
   * @param {Array} shapes - one element per summand.
   */
  function AddNPackedProgram2(outputShape, shapes) {
    this.outputShape = outputShape;
    this.packedInputs = true;
    this.packedOutput = true;
    const names = shapes.map((_, position) => `T${position}`);
    const reads = names.reduce((lines, name) => {
      lines.push(`vec4 v${name} = get${name}AtOutCoords();`);
      return lines;
    }, []);
    const sum = names.map((name) => `v${name}`).join(" + ");
    this.variableNames = names;
    this.userCode = "\n void main() {\n " + reads.join("\n ") + "\n\n vec4 result = " + sum + ";\n setOutput(result);\n }\n ";
  }
  return AddNPackedProgram2;
}();
/**
* @license
* Copyright 2017 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var ArgMinMaxProgram = function() {
  /**
   * One pass of an arg-min / arg-max reduction over the inner dimension of the
   * 2-D input "A" ([batch, inSize]).
   *
   * Output element [batch, outIdx] scans `windowSize` candidates starting at
   * outIdx * windowSize and emits the index (into A) of the best one. Ties keep
   * the earliest candidate because the comparison is strict.
   *  - first pass: candidate i is A's own position inOffset + i;
   *  - later passes: candidate i is the A-index stored in bestIndicesA at
   *    position inOffset + i (the previous pass's output).
   *
   * @param {{windowSize:number, batchSize:number, outSize:number}} reduceInfo
   * @param {string} op - "max" selects the largest value; anything else the smallest.
   * @param {boolean} firstPass - whether A's positions are the candidates directly.
   */
  function ArgMinMaxProgram2(reduceInfo, op, firstPass) {
    this.variableNames = ["A"];
    const windowSize = reduceInfo.windowSize;
    const batchSize = reduceInfo.batchSize;
    const outSize = reduceInfo.outSize;
    if (!firstPass) {
      this.variableNames.push("bestIndicesA");
    }
    this.outputShape = [batchSize, outSize];
    const compOp = op === "max" ? ">" : "<";
    const indexSnippet = firstPass ? "inOffset + i;" : "round(getBestIndicesA(batch, inOffset + i));";
    // Seed the running best with the window's first candidate, expressed in A's
    // index space. On later passes `inOffset` is a position in bestIndicesA, so it
    // must be translated; seeding with A[inOffset] directly would compare against
    // an unrelated value and could emit a bestIndicesA position as the result.
    const seedIndexSnippet = firstPass ? "inOffset" : "round(getBestIndicesA(batch, inOffset))";
    this.userCode = "\n void main() {\n ivec2 coords = getOutputCoords();\n int batch = coords[0];\n int outIdx = coords[1];\n int inOffset = outIdx * " + windowSize + ";\n\n int bestIndex = " + seedIndexSnippet + ";\n float bestValue = getA(batch, bestIndex);\n\n for (int i = 0; i < " + windowSize + "; i++) {\n int inIdx = " + indexSnippet + ";\n float candidate = getA(batch, inIdx);\n if (candidate " + compOp + " bestValue) {\n bestValue = candidate;\n bestIndex = inIdx;\n }\n }\n setOutput(float(bestIndex));\n }\n ";
  }
  return ArgMinMaxProgram2;
}();
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Lists the first `rank` component accessors of a GLSL vector variable,
 * e.g. ("rc", 3) -> ["rc.x", "rc.y", "rc.z"]. Components u and v belong to the
 * custom ivec5/ivec6 structs.
 */
function getVecChannels(name, rank) {
  const components = ["x", "y", "z", "w", "u", "v"];
  const channels = [];
  for (const component of components.slice(0, rank)) {
    channels.push(`${name}.${component}`);
  }
  return channels;
}
/**
 * Like getVecChannels, but a rank-1 value is a plain scalar and is returned
 * without a component accessor.
 */
function getChannels(name, rank) {
  return rank === 1 ? [name] : getVecChannels(name, rank);
}
/**
 * Builds the comma-separated GLSL argument list from the first `rank` entries
 * of `dims`. Rank 1 always yields the scalar coordinate name "rc".
 */
function getSourceCoords(rank, dims) {
  if (rank === 1) {
    return "rc";
  }
  const parts = [];
  for (let i = 0; i < rank; i++) {
    parts.push(`${dims[i]}`);
  }
  return parts.join(",");
}
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Returns the GLSL dialect fragments that differ between WebGL 2 (GLSL ES 3.00)
 * and WebGL 1 (GLSL ES 1.00), chosen by tf.env()'s WEBGL_VERSION flag.
 *
 * @returns {{version:string, attribute:string, varyingVs:string, varyingFs:string,
 *   texture2D:string, output:string, defineOutput:string, defineSpecialNaN:string,
 *   defineSpecialInf:string, defineRound:string}}
 */
function getGlslDifferences() {
  if (tf.env().getNumber("WEBGL_VERSION") === 2) {
    return {
      version: "#version 300 es",
      attribute: "in",
      varyingVs: "out",
      varyingFs: "in",
      texture2D: "texture",
      output: "outputColor",
      defineOutput: "out vec4 outputColor;",
      // isnan() is routed through a comparison-based replacement.
      defineSpecialNaN: "\n bool isnan_custom(float val) {\n return (val > 0.0 || val < 0.0) ? false : val != 0.0;\n }\n\n bvec4 isnan_custom(vec4 val) {\n return bvec4(isnan_custom(val.x),\n isnan_custom(val.y), isnan_custom(val.z), isnan_custom(val.w));\n }\n\n #define isnan(value) isnan_custom(value)\n ",
      defineSpecialInf: "",
      // round() is redefined to floor(x + 0.5) returning int.
      defineRound: "\n #define round(value) newRound(value)\n int newRound(float value) {\n return int(floor(value + 0.5));\n }\n\n ivec4 newRound(vec4 value) {\n return ivec4(floor(value + vec4(0.5)));\n }\n "
    };
  }
  return {
    version: "",
    attribute: "attribute",
    varyingVs: "varying",
    varyingFs: "varying",
    texture2D: "texture2D",
    output: "gl_FragColor",
    defineOutput: "",
    // Every ordinary float satisfies at least one of the comparisons; only NaN fails all.
    defineSpecialNaN: "\n #define isnan(value) isnan_custom(value)\n bool isnan_custom(float val) {\n return (val > 0. || val < 1. || val == 0.) ? false : true;\n }\n bvec4 isnan_custom(vec4 val) {\n return bvec4(isnan(val.x), isnan(val.y), isnan(val.z), isnan(val.w));\n }\n ",
    // isinf() compares against an INFINITY uniform supplied by the host.
    defineSpecialInf: "\n uniform float INFINITY;\n\n bool isinf(float val) {\n return abs(val) == INFINITY;\n }\n bvec4 isinf(vec4 val) {\n return equal(abs(val), vec4(INFINITY));\n }\n ",
    defineRound: "\n int round(float value) {\n return int(floor(value + 0.5));\n }\n\n ivec4 round(vec4 value) {\n return ivec4(floor(value + vec4(0.5)));\n }\n "
  };
}
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Emits GLSL statements that decompose a row-major flat index into one int
 * variable per logical dimension.
 *
 * For every stride except the last, the code declares `coords[i] = index / stride`
 * and subtracts `coords[i] * stride` from the index variable in place. The last
 * stride also declares `coords[i + 1]` as the remainder.
 *
 * @param {string[]} coords2 - GLSL variable names, one per logical dimension.
 * @param {number[]} shape - logical shape; strides come from tf.util.computeStrides.
 * @param {string} [index="index"] - name of the GLSL int variable holding the flat index.
 * @returns {string} the concatenated GLSL statements.
 */
function getLogicalCoordinatesFromFlatIndex(coords2, shape, index = "index") {
  const strides = tf.util.computeStrides(shape);
  const lastStride = strides.length - 1;
  return strides.map((stride, i) => {
    const quotient = "int " + coords2[i] + " = " + index + " / " + stride;
    // Decrement the caller-supplied index variable. A hard-coded "index" here
    // would make the GLSL read one variable but update another whenever a custom
    // name is passed.
    const remainder = i === lastStride
      ? "int " + coords2[i + 1] + " = " + index + " - " + coords2[i] + " * " + stride
      : index + " -= " + coords2[i] + " * " + stride;
    return quotient + "; " + remainder + ";";
  }).join("");
}
/**
 * Emits a GLSL `getFlatIndex(ivec3)` that converts (x, y, z) coordinates of a
 * rank-3 `shape` to a row-major flat index, using the strides from
 * tf.util.computeStrides.
 */
function getFlatIndexFrom3D(shape) {
  const [rowStride, colStride] = tf.util.computeStrides(shape).map((stride) => stride.toString());
  return "\n int getFlatIndex(ivec3 coords) {\n return coords.x * " + rowStride + " + coords.y * " + colStride + " + coords.z;\n }\n";
}
// GLSL helper `encode_float(v)`: packs a highp float into a lowp vec4 whose four
// channels (times 255) are the bytes of v's IEEE-754 single-precision encoding,
// least-significant byte first. c[0], c[1] and the low 7 bits of c[2] hold the
// mantissa; the exponent's lowest bit is c[2]'s top bit; c[3] holds the sign bit
// plus the upper 7 exponent bits.
// Special cases: |v| < FLOAT_MIN (including denormals) encodes as all zeros.
// v > FLOAT_MAX (2^127) or v < -FLOAT_MAX encode as the bytes of +/-Infinity, so
// finite values above 2^127 in magnitude saturate.
// NOTE(review): the NaN branch returns vec4(255, ...) without dividing by 255,
// presumably relying on output clamping to 1.0 (all bytes 0xFF) — confirm.
// Not referenced within this part of the file.
var ENCODE_FLOAT_SNIPPET = "\n const float FLOAT_MAX = 1.70141184e38;\n const float FLOAT_MIN = 1.17549435e-38;\n\n lowp vec4 encode_float(highp float v) {\n if (isnan(v)) {\n return vec4(255, 255, 255, 255);\n }\n\n highp float av = abs(v);\n\n if(av < FLOAT_MIN) {\n return vec4(0.0, 0.0, 0.0, 0.0);\n } else if(v > FLOAT_MAX) {\n return vec4(0.0, 0.0, 128.0, 127.0) / 255.0;\n } else if(v < -FLOAT_MAX) {\n return vec4(0.0, 0.0, 128.0, 255.0) / 255.0;\n }\n\n highp vec4 c = vec4(0,0,0,0);\n\n highp float e = floor(log2(av));\n highp float m = exp2(fract(log2(av))) - 1.0;\n\n c[2] = floor(128.0 * m);\n m -= c[2] / 128.0;\n c[1] = floor(32768.0 * m);\n m -= c[1] / 32768.0;\n c[0] = floor(8388608.0 * m);\n\n highp float ebias = e + 127.0;\n c[3] = floor(ebias / 2.0);\n ebias -= c[3] * 2.0;\n c[2] += floor(ebias) * 128.0;\n\n c[3] += 128.0 * step(0.0, -v);\n\n return c / 255.0;\n }\n";
/**
* @license
* Copyright 2017 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
// Local alias of tf.backend_util.getBroadcastDims. It is not referenced within
// this part of the file.
// NOTE(review): presumably used by the *AtOutputCoords sampler helpers defined
// further down — confirm before removing.
var getBroadcastDims = tf.backend_util.getBroadcastDims;
/**
 * Assembles the complete fragment-shader source for a GPGPU program.
 *
 * Sections are joined with "\n" in this order:
 *  1. the shader prefix, plus the packed-texture helpers when `usesPackedTextures`;
 *  2. the sampleTexture helper;
 *  3. setOutput (vec4 for packed outputs, float otherwise);
 *  4. input uniform declarations;
 *  5. output-coordinate helpers;
 *  6. per-input sampler functions;
 *  7. `userCode`.
 *
 * @param {Array} inputsInfo - per-input records with `name` and `shapeInfo`.
 * @param {Object} outputShape - output shape info (`logicalShape`, `texShape`, `isPacked`).
 * @param {string} userCode - program-specific GLSL containing main().
 * @param {boolean} usesPackedTextures - whether inputs are sampled as packed texels.
 * @returns {string} the shader source.
 */
function makeShader(inputsInfo, outputShape, userCode, usesPackedTextures) {
  // Uniform inputs become float uniforms (an array when they hold more than one
  // value). All other inputs are bound as a sampler plus an int flat offset.
  const declarationLines = [];
  inputsInfo.forEach((info) => {
    const size = tf.util.sizeFromShape(info.shapeInfo.logicalShape);
    if (!info.shapeInfo.isUniform) {
      declarationLines.push("uniform sampler2D " + info.name + ";", "uniform int offset" + info.name + ";");
      return;
    }
    const arraySuffix = size > 1 ? "[" + size + "]" : "";
    declarationLines.push("uniform float " + info.name + arraySuffix + ";");
  });
  const samplerSnippets = inputsInfo.map((info) => getInputSamplingSnippet(info, outputShape, usesPackedTextures));
  const outTexShape = outputShape.texShape;
  const glsl = getGlslDifferences();
  const sampleSnippet = getFloatTextureSampleSnippet(glsl);
  const prefix = getShaderPrefix(glsl);
  const packedOutput = outputShape.isPacked;
  const outputCoordsSnippet = packedOutput
    ? getPackedOutputSamplingSnippet(outputShape.logicalShape, outTexShape)
    : getOutputSamplingSnippet(outputShape.logicalShape, outTexShape);
  const setOutputSnippet = packedOutput ? getFloatTextureSetRGBASnippet(glsl) : getFloatTextureSetRSnippet(glsl);
  const sections = [
    usesPackedTextures ? prefix + SHADER_PACKED_PREFIX : prefix,
    sampleSnippet,
    setOutputSnippet,
    declarationLines.join("\n"),
    outputCoordsSnippet,
    samplerSnippets.join("\n"),
    userCode
  ];
  return sections.join("\n");
}
/**
 * Returns the unpacked GLSL sampler source for an input, dispatching on the
 * rank of its logical shape (0 through 6).
 *
 * @throws {Error} for ranks above 6.
 */
function getSamplerFromInInfo(inInfo) {
  const rank = inInfo.shapeInfo.logicalShape.length;
  if (rank === 0) return getSamplerScalar(inInfo);
  if (rank === 1) return getSampler1D(inInfo);
  if (rank === 2) return getSampler2D(inInfo);
  if (rank === 3) return getSampler3D(inInfo);
  if (rank === 4) return getSampler4D(inInfo);
  if (rank === 5) return getSampler5D(inInfo);
  if (rank === 6) return getSampler6D(inInfo);
  throw new Error(rank + "-D input sampling is not yet supported");
}
/**
 * Returns the packed GLSL sampler source for an input, dispatching on the rank
 * of its logical shape. Ranks 4 and above use the generic N-D sampler.
 */
function getPackedSamplerFromInInfo(inInfo) {
  const rank = inInfo.shapeInfo.logicalShape.length;
  if (rank === 0) return getPackedSamplerScalar(inInfo);
  if (rank === 1) return getPackedSampler1D(inInfo);
  if (rank === 2) return getPackedSampler2D(inInfo);
  if (rank === 3) return getPackedSampler3D(inInfo);
  return getPackedSamplerND(inInfo);
}
/**
 * Returns the sampler source for one input. When the input's rank does not
 * exceed the output's, it also includes the matching *AtOutputCoords helper.
 *
 * @param {Object} inInfo - input record (`shapeInfo.logicalShape` is read).
 * @param {Object} outShapeInfo - output shape info (`logicalShape` is read).
 * @param {boolean} [usesPackedTextures=false] - choose packed samplers.
 * @returns {string} GLSL source.
 */
function getInputSamplingSnippet(inInfo, outShapeInfo, usesPackedTextures) {
  const packed = usesPackedTextures === void 0 ? false : usesPackedTextures;
  const sampler = "" + (packed ? getPackedSamplerFromInInfo(inInfo) : getSamplerFromInInfo(inInfo));
  const inRank = inInfo.shapeInfo.logicalShape.length;
  const outRank = outShapeInfo.logicalShape.length;
  if (inRank > outRank) {
    return sampler;
  }
  const atOutputCoords = packed
    ? getPackedSamplerAtOutputCoords(inInfo, outShapeInfo)
    : getSamplerAtOutputCoords(inInfo, outShapeInfo);
  return sampler + atOutputCoords;
}
/**
 * Returns the packed-output `getOutputCoords()` source, dispatching on output
 * rank. Ranks 4 and above use the generic N-D variant.
 */
function getPackedOutputSamplingSnippet(outShape, outTexShape) {
  const rank = outShape.length;
  if (rank === 0) return getOutputScalarCoords();
  if (rank === 1) return getOutputPacked1DCoords(outShape, outTexShape);
  if (rank === 2) return getOutputPacked2DCoords(outShape, outTexShape);
  if (rank === 3) return getOutputPacked3DCoords(outShape, outTexShape);
  return getOutputPackedNDCoords(outShape, outTexShape);
}
/**
 * Returns the unpacked-output `getOutputCoords()` source, dispatching on output
 * rank (0 through 6).
 *
 * @throws {Error} for ranks above 6.
 */
function getOutputSamplingSnippet(outShape, outTexShape) {
  // Thunks keep each generator lazy: only the one for this rank is invoked.
  const builders = [
    () => getOutputScalarCoords(),
    () => getOutput1DCoords(outShape, outTexShape),
    () => getOutput2DCoords(outShape, outTexShape),
    () => getOutput3DCoords(outShape, outTexShape),
    () => getOutput4DCoords(outShape, outTexShape),
    () => getOutput5DCoords(outShape, outTexShape),
    () => getOutput6DCoords(outShape, outTexShape)
  ];
  const build = builders[outShape.length];
  if (build === void 0) {
    throw new Error(outShape.length + "-D output sampling is not yet supported");
  }
  return build();
}
/** GLSL `sampleTexture()`: reads the red channel, where unpacked textures keep their value. */
function getFloatTextureSampleSnippet(glsl) {
  const { texture2D } = glsl;
  return "\n float sampleTexture(sampler2D textureSampler, vec2 uv) {\n return " + texture2D + "(textureSampler, uv).r;\n }\n ";
}
/** GLSL `setOutput(float)`: writes the scalar into the red channel of the dialect's output. */
function getFloatTextureSetRSnippet(glsl) {
  const target = glsl.output;
  return "\n void setOutput(float val) {\n " + target + " = vec4(val, 0, 0, 0);\n }\n ";
}
/** GLSL `setOutput(vec4)`: writes all four channels (packed outputs). */
function getFloatTextureSetRGBASnippet(glsl) {
  const target = glsl.output;
  return "\n void setOutput(vec4 val) {\n " + target + " = val;\n }\n ";
}
/**
 * Builds the common fragment-shader preamble. In order, it contains:
 *  - the GLSL version line and highp precision qualifiers;
 *  - the `resultUV` varying, the output declaration and the `halfCR` constant;
 *  - the ivec5/ivec6 helper structs and the NAN uniform;
 *  - the NaN/Inf/round shims from `glsl`;
 *  - integer helpers imod and idiv (idiv subtracts 1 when `sign` is negative
 *    and the division is inexact);
 *  - a hash-based `random(seed)`;
 *  - the 1-D/2-D/3-D UV helper snippets.
 *
 * @param glsl - dialect descriptor as returned by getGlslDifferences().
 * @returns the preamble source string.
 */
function getShaderPrefix(glsl) {
// SAMPLE_*_SNIPPET are `var`s declared further down. They are hoisted, so this
// reads their values correctly as long as it runs after those declarations execute.
var SHADER_PREFIX = glsl.version + "\n precision highp float;\n precision highp int;\n precision highp sampler2D;\n " + glsl.varyingFs + " vec2 resultUV;\n " + glsl.defineOutput + "\n const vec2 halfCR = vec2(0.5, 0.5);\n\n struct ivec5\n {\n int x;\n int y;\n int z;\n int w;\n int u;\n };\n\n struct ivec6\n {\n int x;\n int y;\n int z;\n int w;\n int u;\n int v;\n };\n\n uniform float NAN;\n " + glsl.defineSpecialNaN + "\n " + glsl.defineSpecialInf + "\n " + glsl.defineRound + "\n\n int imod(int x, int y) {\n return x - y * (x / y);\n }\n\n int idiv(int a, int b, float sign) {\n int res = a / b;\n int mod = imod(a, b);\n if (sign < 0. && mod != 0) {\n res -= 1;\n }\n return res;\n }\n\n //Based on the work of Dave Hoskins\n //https://www.shadertoy.com/view/4djSRW\n #define HASHSCALE1 443.8975\n float random(float seed){\n vec2 p = resultUV * seed;\n vec3 p3 = fract(vec3(p.xyx) * HASHSCALE1);\n p3 += dot(p3, p3.yzx + 19.19);\n return fract((p3.x + p3.y) * p3.z);\n }\n\n " + SAMPLE_1D_SNIPPET + "\n " + SAMPLE_2D_SNIPPET + "\n " + SAMPLE_3D_SNIPPET + "\n ";
return SHADER_PREFIX;
}
// GLSL helpers that map an element index to the UV of the containing texel's
// center (+ halfCR). getShaderPrefix splices SAMPLE_1D/2D/3D into every shader.
// uvFromFlat: unpacked layout, one value per texel.
// packedUVfrom1D: packed layout, two consecutive 1-D values per texel (index / 2).
var SAMPLE_1D_SNIPPET = "\nvec2 uvFromFlat(int texNumR, int texNumC, int index) {\n int texR = index / texNumC;\n int texC = index - texR * texNumC;\n return (vec2(texC, texR) + halfCR) / vec2(texNumC, texNumR);\n}\nvec2 packedUVfrom1D(int texNumR, int texNumC, int index) {\n int texelIndex = index / 2;\n int texR = texelIndex / texNumC;\n int texC = texelIndex - texR * texNumC;\n return (vec2(texC, texR) + halfCR) / vec2(texNumC, texNumR);\n}\n";
// packedUVfrom2D: packed layout where each texel holds a 2x2 (row, col) patch.
var SAMPLE_2D_SNIPPET = "\nvec2 packedUVfrom2D(int texelsInLogicalRow, int texNumR,\n int texNumC, int row, int col) {\n int texelIndex = (row / 2) * texelsInLogicalRow + (col / 2);\n int texR = texelIndex / texNumC;\n int texC = texelIndex - texR * texNumC;\n return (vec2(texC, texR) + halfCR) / vec2(texNumC, texNumR);\n}\n";
// packedUVfrom3D: as packedUVfrom2D, offset by b * texelsInBatch for the batch index.
var SAMPLE_3D_SNIPPET = "\nvec2 packedUVfrom3D(int texNumR, int texNumC,\n int texelsInBatch, int texelsInLogicalRow, int b,\n int row, int col) {\n int index = b * texelsInBatch + (row / 2) * texelsInLogicalRow + (col / 2);\n int texR = index / texNumC;\n int texC = index - texR * texNumC;\n return (vec2(texC, texR) + halfCR) / vec2(texNumC, texNumR);\n}\n";
// Appended to the prefix by makeShader when inputs use packed textures.
// getChannel picks the RGBA component for an element from the parity of its
// (row, col) position, or of a single dimension for the int overload.
var SHADER_PACKED_PREFIX = "\n float getChannel(vec4 frag, vec2 innerDims) {\n vec2 modCoord = mod(innerDims, 2.);\n return modCoord.x == 0. ?\n (modCoord.y == 0. ? frag.r : frag.g) :\n (modCoord.y == 0. ? frag.b : frag.a);\n }\n float getChannel(vec4 frag, int dim) {\n float modCoord = mod(float(dim), 2.);\n return modCoord == 0. ? frag.r : frag.g;\n }\n";
/** GLSL `getOutputCoords()` for rank-0 outputs: there is one element, at coordinate 0. */
function getOutputScalarCoords() {
  const snippet = "\n int getOutputCoords() {\n return 0;\n }\n ";
  return snippet;
}
/**
 * GLSL `getOutputCoords()` for packed rank-1 outputs. It returns the even
 * logical index of the first value held by the current texel. The texel grid is
 * ceil(texShape / 2) in each direction, with shortcuts for single-row or
 * single-column grids.
 */
function getOutputPacked1DCoords(shape, texShape) {
  const packedRows = Math.ceil(texShape[0] / 2);
  const packedCols = Math.ceil(texShape[1] / 2);
  let body;
  if (packedRows === 1) {
    body = "return 2 * int(resultUV.x * " + packedCols + ".0);";
  } else if (packedCols === 1) {
    body = "return 2 * int(resultUV.y * " + packedRows + ".0);";
  } else {
    body = "ivec2 resTexRC = ivec2(resultUV.yx *\n vec2(" + packedRows + ", " + packedCols + "));\n return 2 * (resTexRC.x * " + packedCols + " + resTexRC.y);";
  }
  return "\n int getOutputCoords() {\n " + body + "\n }\n ";
}
/**
 * GLSL `getOutputCoords()` for unpacked rank-1 outputs: the flat index of the
 * current texel, with shortcuts for single-row or single-column textures.
 */
function getOutput1DCoords(shape, texShape) {
  const texRows = texShape[0];
  const texCols = texShape[1];
  let body;
  if (texRows === 1) {
    body = "return int(resultUV.x * " + texCols + ".0);";
  } else if (texCols === 1) {
    body = "return int(resultUV.y * " + texRows + ".0);";
  } else {
    body = "ivec2 resTexRC = ivec2(resultUV.yx *\n vec2(" + texRows + ", " + texCols + "));\n return resTexRC.x * " + texCols + " + resTexRC.y;";
  }
  return "\n int getOutputCoords() {\n " + body + "\n }\n ";
}
/**
 * GLSL `getOutputCoords()` for packed rank-3 outputs. It returns (b, r, c) of
 * the top-left element held by the current texel; r and c are always even.
 */
function getOutputPacked3DCoords(shape, texShape) {
  const packedRows = Math.ceil(texShape[0] / 2);
  const packedCols = Math.ceil(texShape[1] / 2);
  // Texels per logical row, and per batch (a batch spans ceil(rows / 2) texel rows).
  const rowTexels = Math.ceil(shape[2] / 2);
  const batchTexels = rowTexels * Math.ceil(shape[1] / 2);
  const texelIndex = "\n ivec3 getOutputCoords() {\n ivec2 resTexRC = ivec2(resultUV.yx *\n vec2(" + packedRows + ", " + packedCols + "));\n int index = resTexRC.x * " + packedCols + " + resTexRC.y;\n\n ";
  const decode = "int b = index / " + batchTexels + ";\n index -= b * " + batchTexels + ";\n\n int r = 2 * (index / " + rowTexels + ");\n int c = imod(index, " + rowTexels + ") * 2;\n\n return ivec3(b, r, c);\n }\n ";
  return texelIndex + decode;
}
/**
 * GLSL `getOutputCoords()` for unpacked rank-3 outputs. It flattens the texel
 * position, then unflattens it into (r, c, d) using `shape`'s strides.
 */
function getOutput3DCoords(shape, texShape) {
  const unflatten = getLogicalCoordinatesFromFlatIndex(["r", "c", "d"], shape);
  const texRows = texShape[0];
  const texCols = texShape[1];
  const flatten = "\n ivec3 getOutputCoords() {\n ivec2 resTexRC = ivec2(resultUV.yx *\n vec2(" + texRows + ", " + texCols + "));\n int index = resTexRC.x * " + texCols + " + resTexRC.y;\n ";
  return flatten + unflatten + "\n return ivec3(r, c, d);\n }\n ";
}
/**
 * GLSL `getOutputCoords()` for packed outputs of rank >= 4. It decodes the
 * outer batch dimensions (b<N>..b2), then b, then the even (r, c) of the texel's
 * top-left element.
 */
function getOutputPackedNDCoords(shape, texShape) {
  const rank = shape.length;
  const packedRows = Math.ceil(texShape[0] / 2);
  const packedCols = Math.ceil(texShape[1] / 2);
  const rowTexels = Math.ceil(shape[rank - 1] / 2);
  const batchTexels = rowTexels * Math.ceil(shape[rank - 2] / 2);
  // Walk outward from the innermost batch dimension. Each decode step is prepended,
  // so the generated GLSL decodes the outermost dimension first.
  let outerTexels = batchTexels;
  let batchDecodes = "";
  let coordList = "b, r, c";
  for (let dim = 2; dim < rank - 1; dim++) {
    outerTexels *= shape[rank - dim - 1];
    batchDecodes = "\n int b" + dim + " = index / " + outerTexels + ";\n index -= b" + dim + " * " + outerTexels + ";\n " + batchDecodes;
    coordList = "b" + dim + ", " + coordList;
  }
  return "\n ivec" + rank + " getOutputCoords() {\n ivec2 resTexRC = ivec2(resultUV.yx *\n vec2(" + packedRows + ", " + packedCols + "));\n int index = resTexRC.x * " + packedCols + " + resTexRC.y;\n\n " + batchDecodes + "\n\n int b = index / " + batchTexels + ";\n index -= b * " + batchTexels + ";\n\n int r = 2 * (index / " + rowTexels + ");\n int c = imod(index, " + rowTexels + ") * 2;\n\n return ivec" + rank + "(" + coordList + ");\n }\n ";
}
/** GLSL `getOutputCoords()` for unpacked rank-4 outputs: flatten the texel position, then unflatten to (r, c, d, d2). */
function getOutput4DCoords(shape, texShape) {
  const unflatten = getLogicalCoordinatesFromFlatIndex(["r", "c", "d", "d2"], shape);
  const texRows = texShape[0];
  const texCols = texShape[1];
  const flatten = "\n ivec4 getOutputCoords() {\n ivec2 resTexRC = ivec2(resultUV.yx *\n vec2(" + texRows + ", " + texCols + "));\n int index = resTexRC.x * " + texCols + " + resTexRC.y;\n ";
  return flatten + unflatten + "\n return ivec4(r, c, d, d2);\n }\n ";
}
/** GLSL `getOutputCoords()` for unpacked rank-5 outputs, returning the custom ivec5 struct. */
function getOutput5DCoords(shape, texShape) {
  const unflatten = getLogicalCoordinatesFromFlatIndex(["r", "c", "d", "d2", "d3"], shape);
  const texRows = texShape[0];
  const texCols = texShape[1];
  const flatten = "\n ivec5 getOutputCoords() {\n ivec2 resTexRC = ivec2(resultUV.yx * vec2(" + texRows + ",\n " + texCols + "));\n\n int index = resTexRC.x * " + texCols + " + resTexRC.y;\n\n ";
  return flatten + unflatten + "\n\n ivec5 outShape = ivec5(r, c, d, d2, d3);\n return outShape;\n }\n ";
}
/** GLSL `getOutputCoords()` for unpacked rank-6 outputs, returning the custom ivec6 struct. */
function getOutput6DCoords(shape, texShape) {
  const unflatten = getLogicalCoordinatesFromFlatIndex(["r", "c", "d", "d2", "d3", "d4"], shape);
  const texRows = texShape[0];
  const texCols = texShape[1];
  const flatten = "\n ivec6 getOutputCoords() {\n ivec2 resTexRC = ivec2(resultUV.yx *\n vec2(" + texRows + ", " + texCols + "));\n int index = resTexRC.x * " + texCols + " + resTexRC.y;\n\n ";
  return flatten + unflatten + "\n\n ivec6 result = ivec6(r, c, d, d2, d3, d4);\n return result;\n }\n ";
}
/**
 * GLSL `getOutputCoords()` for packed rank-2 outputs. It returns the even
 * (r, c) of the top-left element held by the current texel.
 */
function getOutputPacked2DCoords(shape, texShape) {
  const packedRows = Math.ceil(texShape[0] / 2);
  const packedCols = Math.ceil(texShape[1] / 2);
  const gridDims = packedRows + ", " + packedCols;
  if (tf.util.arraysEqual(shape, texShape)) {
    // Logical and texture layouts coincide: double the texel coordinate.
    return "\n ivec2 getOutputCoords() {\n return 2 * ivec2(resultUV.yx * vec2(" + gridDims + "));\n }\n ";
  }
  const rowTexels = Math.ceil(shape[1] / 2);
  return "\n ivec2 getOutputCoords() {\n ivec2 resTexRC = ivec2(resultUV.yx *\n vec2(" + gridDims + "));\n\n int index = resTexRC.x * " + packedCols + " + resTexRC.y;\n int r = 2 * (index / " + rowTexels + ");\n int c = imod(index, " + rowTexels + ") * 2;\n\n return ivec2(r, c);\n }\n ";
}
/**
 * GLSL `getOutputCoords()` for unpacked rank-2 outputs. When the logical and
 * texture shapes match, the texel position is the coordinate. Otherwise the
 * texel position is flattened and then split according to the output layout.
 */
function getOutput2DCoords(shape, texShape) {
  if (tf.util.arraysEqual(shape, texShape)) {
    return "\n ivec2 getOutputCoords() {\n return ivec2(resultUV.yx * vec2(" + texShape[0] + ", " + texShape[1] + "));\n }\n ";
  }
  const flatten = "\n ivec2 getOutputCoords() {\n ivec2 resTexRC = ivec2(resultUV.yx *\n vec2(" + texShape[0] + ", " + texShape[1] + "));\n int index = resTexRC.x * " + texShape[1] + " + resTexRC.y;\n ";
  let split;
  if (shape[1] === 1) {
    split = "return ivec2(index, 0);\n }\n ";
  } else if (shape[0] === 1) {
    split = "return ivec2(0, index);\n }\n ";
  } else {
    split = "int r = index / " + shape[1] + ";\n int c = index - r * " + shape[1] + ";\n return ivec2(r, c);\n }\n ";
  }
  return flatten + split;
}
/** Name of the int uniform holding an input texture's flat offset (declared in makeShader). */
function getFlatOffsetUniformName(texName) {
  return `offset${texName}`;
}
/** Packed rank-0 sampler: `get<Name>()` reads the single texel at the texture center. */
function getPackedSamplerScalar(inputInfo) {
  const { name: texName } = inputInfo;
  const funcName = `get${texName.charAt(0).toUpperCase()}${texName.slice(1)}`;
  const { texture2D } = getGlslDifferences();
  return "\n vec4 " + funcName + "() {\n return " + texture2D + "(" + texName + ", halfCR);\n }\n ";
}
/**
 * Unpacked rank-0 sampler `get<Name>()`:
 *  - uniform inputs return the float uniform directly;
 *  - a 1x1 texture reads its only texel;
 *  - otherwise it reads the texel at the input's flat offset.
 */
function getSamplerScalar(inputInfo) {
  const texName = inputInfo.name;
  const funcName = "get" + texName.charAt(0).toUpperCase() + texName.slice(1);
  const { shapeInfo } = inputInfo;
  if (shapeInfo.isUniform) {
    return "float " + funcName + "() {return " + texName + ";}";
  }
  const texNumR = shapeInfo.texShape[0];
  const texNumC = shapeInfo.texShape[1];
  if (texNumR === 1 && texNumC === 1) {
    return "\n float " + funcName + "() {\n return sampleTexture(" + texName + ", halfCR);\n }\n ";
  }
  const offset = getFlatOffsetUniformName(texName);
  return "\n float " + funcName + "() {\n vec2 uv = uvFromFlat(" + texNumR + ", " + texNumC + ", " + offset + ");\n return sampleTexture(" + texName + ", uv);\n }\n ";
}
/** Packed rank-1 sampler `get<Name>(int index)`: reads the texel holding `index` via packedUVfrom1D. */
function getPackedSampler1D(inputInfo) {
  const texName = inputInfo.name;
  const funcName = "get" + texName.charAt(0).toUpperCase() + texName.slice(1);
  const { texShape } = inputInfo.shapeInfo;
  // The packed texel grid is ceil(extent / 2) in each direction.
  const packedRows = Math.ceil(texShape[0] / 2);
  const packedCols = Math.ceil(texShape[1] / 2);
  const { texture2D } = getGlslDifferences();
  return "\n vec4 " + funcName + "(int index) {\n vec2 uv = packedUVfrom1D(\n " + packedRows + ", " + packedCols + ", index);\n return " + texture2D + "(" + texName + ", uv);\n }\n ";
}
/**
 * Unpacked rank-1 sampler `get<Name>(int index)`:
 *  - uniform inputs delegate to getUniformSampler;
 *  - a 1x1 texture reads its only texel;
 *  - single-column or single-row textures compute the UV directly;
 *  - otherwise uvFromFlat is used.
 * All non-uniform variants except the 1x1 one add the input's flat offset.
 */
function getSampler1D(inputInfo) {
  const texName = inputInfo.name;
  const funcName = "get" + texName.charAt(0).toUpperCase() + texName.slice(1);
  const signature = "\n float " + funcName + "(int index) {\n ";
  if (inputInfo.shapeInfo.isUniform) {
    return signature + getUniformSampler(inputInfo) + "\n }\n ";
  }
  const texShape = inputInfo.shapeInfo.texShape;
  const rows = texShape[0];
  const cols = texShape[1];
  if (cols === 1 && rows === 1) {
    return signature + "return sampleTexture(" + texName + ", halfCR);\n }\n ";
  }
  const offset = getFlatOffsetUniformName(texName);
  let uvExpr;
  if (cols === 1) {
    uvExpr = "vec2(0.5, (float(index + " + offset + ") + 0.5) / " + rows + ".0)";
  } else if (rows === 1) {
    uvExpr = "vec2((float(index + " + offset + ") + 0.5) / " + cols + ".0, 0.5)";
  } else {
    uvExpr = "uvFromFlat(" + rows + ", " + cols + ", index + " + offset + ")";
  }
  return signature + "vec2 uv = " + uvExpr + ";\n return sampleTexture(" + texName + ", uv);\n }\n ";
}
/**
 * Packed rank-2 sampler `get<Name>(int row, int col)`. When the texture shape
 * equals the logical shape, the UV comes straight from (col, row); otherwise
 * packedUVfrom2D locates the 2x2 texel.
 */
function getPackedSampler2D(inputInfo) {
  const shape = inputInfo.shapeInfo.logicalShape;
  const texName = inputInfo.name;
  const funcName = "get" + texName.charAt(0).toUpperCase() + texName.slice(1);
  const texShape = inputInfo.shapeInfo.texShape;
  const rows = texShape[0];
  const cols = texShape[1];
  const sampleFn = getGlslDifferences().texture2D;
  const sameLayout = texShape != null && tf.util.arraysEqual(shape, texShape);
  const uvLine = sameLayout
    ? "vec2 uv = (vec2(col, row) + halfCR) / vec2(" + cols + ".0, " + rows + ".0);\n"
    : "vec2 uv = packedUVfrom2D(" + Math.ceil(shape[1] / 2) + ", " + Math.ceil(rows / 2) + ", " + Math.ceil(cols / 2) + ", row, col);";
  return "\n vec4 " + funcName + "(int row, int col) {\n " + uvLine + "\n return " + sampleFn + "(" + texName + ", uv);\n }\n ";
}
/**
 * Unpacked rank-2 sampler `get<Name>(int row, int col)`. Variants, in priority order:
 *  - texture shape equals logical shape: direct UV from (col, row);
 *  - unit dimensions present: delegate to the squeezed lower-rank sampler;
 *  - uniform input: compute a flat index and use getUniformSampler;
 *  - single-column / single-row texture: float dot-product index;
 *  - otherwise: integer flat index plus offset, via uvFromFlat.
 */
function getSampler2D(inputInfo) {
  const shape = inputInfo.shapeInfo.logicalShape;
  const texName = inputInfo.name;
  const funcName = "get" + texName.charAt(0).toUpperCase() + texName.slice(1);
  const texShape = inputInfo.shapeInfo.texShape;
  const signature = "\n float " + funcName + "(int row, int col) {\n ";
  const sampleTail = "\n return sampleTexture(" + texName + ", uv);\n }\n ";
  if (texShape != null && tf.util.arraysEqual(shape, texShape)) {
    return signature + "vec2 uv = (vec2(col, row) + halfCR) / vec2(" + texShape[1] + ".0, " + texShape[0] + ".0);" + sampleTail;
  }
  const squeezed = tf.util.squeezeShape(shape);
  if (squeezed.newShape.length < shape.length) {
    const inner = getSamplerFromInInfo(squeezeInputInfo(inputInfo, squeezed.newShape));
    return "\n " + inner + signature + "return " + funcName + "(" + getSqueezedParams(["row", "col"], squeezed.keptDims) + ");\n }\n ";
  }
  if (inputInfo.shapeInfo.isUniform) {
    return signature + "int index = round(dot(vec2(row, col), vec2(" + shape[1] + ", 1)));\n " + getUniformSampler(inputInfo) + "\n }\n ";
  }
  const texNumR = texShape[0];
  const texNumC = texShape[1];
  const offset = getFlatOffsetUniformName(texName);
  const floatIndex = "float index = dot(vec3(row, col, " + offset + "), vec3(" + shape[1] + ", 1, 1));\n ";
  if (texNumC === 1) {
    return signature + floatIndex + "vec2 uv = vec2(0.5, (index + 0.5) / " + texNumR + ".0);" + sampleTail;
  }
  if (texNumR === 1) {
    return signature + floatIndex + "vec2 uv = vec2((index + 0.5) / " + texNumC + ".0, 0.5);" + sampleTail;
  }
  // This variant's source ends in "}\n" without the trailing space the others carry.
  return signature + "// Explicitly use integer operations as dot() only works on floats.\n int index = row * " + shape[1] + " + col + " + offset + ";\n vec2 uv = uvFromFlat(" + texNumR + ", " + texNumC + ", index);\n return sampleTexture(" + texName + ", uv);\n }\n";
}
/**
 * Packed rank-3 sampler `get<Name>(int b, int row, int col)`. A unit batch
 * dimension delegates to the squeezed rank-2 sampler. Otherwise packedUVfrom3D
 * locates the texel holding (b, row, col).
 */
function getPackedSampler3D(inputInfo) {
  const shape = inputInfo.shapeInfo.logicalShape;
  const texName = inputInfo.name;
  const funcName = "get" + texName.charAt(0).toUpperCase() + texName.slice(1);
  const texShape = inputInfo.shapeInfo.texShape;
  const packedRows = Math.ceil(texShape[0] / 2);
  const packedCols = Math.ceil(texShape[1] / 2);
  const signature = "\n vec4 " + funcName + "(int b, int row, int col) {\n ";
  if (shape[0] === 1) {
    const inner = getPackedSamplerFromInInfo(squeezeInputInfo(inputInfo, shape.slice(1)));
    return "\n " + inner + signature + "return " + funcName + "(" + getSqueezedParams(["b", "row", "col"], [1, 2]) + ");\n }\n ";
  }
  const valuesPerRow = Math.ceil(shape[2] / 2);
  const texelsInBatch = valuesPerRow * Math.ceil(shape[1] / 2);
  const sampleFn = getGlslDifferences().texture2D;
  return signature + "vec2 uv = packedUVfrom3D(\n " + packedRows + ", " + packedCols + ", " + texelsInBatch + ", " + valuesPerRow + ", b, row, col);\n return " + sampleFn + "(" + texName + ", uv);\n }\n ";
}
/**
 * Unpacked rank-3 sampler `get<Name>(int row, int col, int depth)`. Variants,
 * in priority order:
 *  - unit dimensions present: delegate to the squeezed lower-rank sampler;
 *  - uniform input: compute a flat index and use getUniformSampler;
 *  - no flat offset and texture width == rows stride: texel row = logical row;
 *  - no flat offset and texture width == depth: texel column = depth;
 *  - otherwise: integer flat index plus offset, via uvFromFlat.
 */
function getSampler3D(inputInfo) {
  const shape = inputInfo.shapeInfo.logicalShape;
  const texName = inputInfo.name;
  const funcName = "get" + texName.charAt(0).toUpperCase() + texName.slice(1);
  // Row-major strides of the logical [row, col, depth] layout.
  const rowStride = shape[1] * shape[2];
  const colStride = shape[2];
  const signature = "\n float " + funcName + "(int row, int col, int depth) {\n ";
  const squeezed = tf.util.squeezeShape(shape);
  if (squeezed.newShape.length < shape.length) {
    const inner = getSamplerFromInInfo(squeezeInputInfo(inputInfo, squeezed.newShape));
    return "\n " + inner + signature + "return " + funcName + "(" + getSqueezedParams(["row", "col", "depth"], squeezed.keptDims) + ");\n }\n ";
  }
  if (inputInfo.shapeInfo.isUniform) {
    return signature + "int index = round(dot(vec3(row, col, depth),\n vec3(" + rowStride + ", " + colStride + ", 1)));\n " + getUniformSampler(inputInfo) + "\n }\n ";
  }
  const texNumR = inputInfo.shapeInfo.texShape[0];
  const texNumC = inputInfo.shapeInfo.texShape[1];
  const hasFlatOffset = inputInfo.shapeInfo.flatOffset != null;
  const sampleTail = "\n return sampleTexture(" + texName + ", uv);\n }\n ";
  if (!hasFlatOffset && texNumC === rowStride) {
    return signature + "float texR = float(row);\n float texC = dot(vec2(col, depth), vec2(" + colStride + ", 1));\n vec2 uv = (vec2(texC, texR) + halfCR) /\n vec2(" + texNumC + ".0, " + texNumR + ".0);" + sampleTail;
  }
  if (!hasFlatOffset && texNumC === colStride) {
    return signature + "float texR = dot(vec2(row, col), vec2(" + shape[1] + ", 1));\n float texC = float(depth);\n vec2 uv = (vec2(texC, texR) + halfCR) / vec2(" + texNumC + ".0, " + texNumR + ".0);" + sampleTail;
  }
  const offset = getFlatOffsetUniformName(texName);
  return signature + "// Explicitly use integer operations as dot() only works on floats.\n int index = row * " + rowStride + " + col * " + colStride + " + depth + " + offset + ";\n vec2 uv = uvFromFlat(" + texNumR + ", " + texNumC + ", index);" + sampleTail;
}
/**
 * Emits GLSL for `vec4 get<Name>(int bK, ..., int b2, int b, int row, int col)`,
 * which fetches the packed texel of a packed input of rank >= 3.
 *
 * Each texel covers a 2x2 block of the two innermost dimensions. The texel
 * index is therefore a sum of per-batch-dimension strides (in texels) plus
 * `(row / 2) * valuesPerRow + (col / 2)`, and is converted to a UV inside the
 * packed texture.
 *
 * @param {Object} inputInfo - reads `name`, `shapeInfo.logicalShape` and
 *     `shapeInfo.texShape`.
 * @returns {string} GLSL source for the sampler function.
 */
function getPackedSamplerND(inputInfo) {
  const texName = inputInfo.name;
  const { logicalShape: shape, texShape } = inputInfo.shapeInfo;
  const rank = shape.length;
  const funcName = `get${texName.charAt(0).toUpperCase()}${texName.slice(1)}`;
  const texNumR = Math.ceil(texShape[0] / 2);
  const texNumC = Math.ceil(texShape[1] / 2);
  const valuesPerRow = Math.ceil(shape[rank - 1] / 2);
  let batchStride = valuesPerRow * Math.ceil(shape[rank - 2] / 2);

  // Outer batch dimensions are prepended: b2 is the next-outer one, and so on.
  const paramDecls = ["int b", "int row", "int col"];
  const indexTerms = [`b * ${batchStride}`, `(row / 2) * ${valuesPerRow}`, "(col / 2)"];
  for (let dim = 2; dim < rank - 1; dim++) {
    batchStride *= shape[rank - dim - 1];
    paramDecls.unshift(`int b${dim}`);
    indexTerms.unshift(`b${dim} * ${batchStride}`);
  }

  const { texture2D } = getGlslDifferences();
  return `\n vec4 ${funcName}(${paramDecls.join(", ")}) {\n int index = ${indexTerms.join(" + ")};\n int texR = index / ${texNumC};\n int texC = index - texR * ${texNumC};\n vec2 uv = (vec2(texC, texR) + halfCR) / vec2(${texNumC}, ${texNumR});\n return ${texture2D}(${texName}, uv);\n }\n `;
}
/**
 * Emits GLSL for `float get<Name>(int row, int col, int depth, int depth2)`,
 * which reads one element of the unpacked rank-4 input described by
 * `inputInfo`.
 *
 * Strategies, tried in order:
 *  1. tf.util.squeezeShape drops size-1 dimensions: emit the lower-rank
 *     sampler plus a forwarding wrapper with the rank-4 signature.
 *  2. Uniform input: row-major flat index resolved by getUniformSampler.
 *  3. Texture width equals stride0 and there is no flat offset: `row` is the
 *     texture row.
 *  4. Texture width equals stride2 (the last dimension) and there is no flat
 *     offset: `depth2` is the texture column.
 *  5. Otherwise: integer flat index plus the flat-offset uniform, mapped to a
 *     UV by the `uvFromFlat` GLSL helper.
 *
 * @param {Object} inputInfo - reads `name` and `shapeInfo.logicalShape`,
 *     `shapeInfo.isUniform`, `shapeInfo.texShape`, `shapeInfo.flatOffset`.
 * @returns {string} GLSL source for the sampler function.
 */
function getSampler4D(inputInfo) {
var shape = inputInfo.shapeInfo.logicalShape;
var texName = inputInfo.name;
var funcName = "get" + texName.charAt(0).toUpperCase() + texName.slice(1);
// Row-major strides of the logical shape.
var stride2 = shape[3];
var stride1 = shape[2] * stride2;
var stride0 = shape[1] * stride1;
var _a = tf.util.squeezeShape(shape), newShape = _a.newShape, keptDims = _a.keptDims;
if (newShape.length < shape.length) {
var newInputInfo = squeezeInputInfo(inputInfo, newShape);
var params = ["row", "col", "depth", "depth2"];
return "\n " + getSamplerFromInInfo(newInputInfo) + "\n float " + funcName + "(int row, int col, int depth, int depth2) {\n return " + funcName + "(" + getSqueezedParams(params, keptDims) + ");\n }\n ";
}
if (inputInfo.shapeInfo.isUniform) {
return "\n float " + funcName + "(int row, int col, int depth, int depth2) {\n int index = round(dot(vec4(row, col, depth, depth2),\n vec4(" + stride0 + ", " + stride1 + ", " + stride2 + ", 1)));\n " + getUniformSampler(inputInfo) + "\n }\n ";
}
var flatOffset = inputInfo.shapeInfo.flatOffset;
var texShape = inputInfo.shapeInfo.texShape;
var texNumR = texShape[0];
var texNumC = texShape[1];
// Fast path: one logical `row` per texture row.
if (texNumC === stride0 && flatOffset == null) {
return "\n float " + funcName + "(int row, int col, int depth, int depth2) {\n float texR = float(row);\n float texC =\n dot(vec3(col, depth, depth2),\n vec3(" + stride1 + ", " + stride2 + ", 1));\n vec2 uv = (vec2(texC, texR) + halfCR) /\n vec2(" + texNumC + ".0, " + texNumR + ".0);\n return sampleTexture(" + texName + ", uv);\n }\n ";
}
// Fast path: the innermost dimension spans exactly one texture row.
if (texNumC === stride2 && flatOffset == null) {
return "\n float " + funcName + "(int row, int col, int depth, int depth2) {\n float texR = dot(vec3(row, col, depth),\n vec3(" + shape[1] * shape[2] + ", " + shape[2] + ", 1));\n float texC = float(depth2);\n vec2 uv = (vec2(texC, texR) + halfCR) /\n vec2(" + texNumC + ".0, " + texNumR + ".0);\n return sampleTexture(" + texName + ", uv);\n }\n ";
}
// General path: flat index; the flat-offset uniform is added inside uvFromFlat's argument.
var offset = getFlatOffsetUniformName(texName);
return "\n float " + funcName + "(int row, int col, int depth, int depth2) {\n // Explicitly use integer operations as dot() only works on floats.\n int index = row * " + stride0 + " + col * " + stride1 + " +\n depth * " + stride2 + " + depth2;\n vec2 uv = uvFromFlat(" + texNumR + ", " + texNumC + ", index + " + offset + ");\n return sampleTexture(" + texName + ", uv);\n }\n ";
}
/**
 * Emits GLSL for
 * `float get<Name>(int row, int col, int depth, int depth2, int depth3)`,
 * which reads one element of the unpacked rank-5 input described by
 * `inputInfo`.
 *
 * Strategies, tried in order:
 *  1. tf.util.squeezeShape drops size-1 dimensions: emit the lower-rank
 *     sampler plus a forwarding wrapper with the rank-5 signature.
 *  2. Uniform input: integer row-major flat index resolved by
 *     getUniformSampler.
 *  3. Texture width equals stride0 and there is no flat offset: `row` is the
 *     texture row.
 *  4. Texture width equals stride3 (the last dimension) and there is no flat
 *     offset: `depth3` is the texture column.
 *  5. Otherwise: integer flat index plus the flat-offset uniform, mapped to a
 *     UV by the `uvFromFlat` GLSL helper.
 *
 * @param {Object} inputInfo - reads `name` and `shapeInfo.logicalShape`,
 *     `shapeInfo.isUniform`, `shapeInfo.texShape`, `shapeInfo.flatOffset`.
 * @returns {string} GLSL source for the sampler function.
 */
function getSampler5D(inputInfo) {
var shape = inputInfo.shapeInfo.logicalShape;
var texName = inputInfo.name;
var funcName = "get" + texName.charAt(0).toUpperCase() + texName.slice(1);
// Row-major strides of the logical shape.
var stride3 = shape[4];
var stride2 = shape[3] * stride3;
var stride1 = shape[2] * stride2;
var stride0 = shape[1] * stride1;
var _a = tf.util.squeezeShape(shape), newShape = _a.newShape, keptDims = _a.keptDims;
if (newShape.length < shape.length) {
var newInputInfo = squeezeInputInfo(inputInfo, newShape);
var params = ["row", "col", "depth", "depth2", "depth3"];
return "\n " + getSamplerFromInInfo(newInputInfo) + "\n float " + funcName + "(int row, int col, int depth, int depth2, int depth3) {\n return " + funcName + "(" + getSqueezedParams(params, keptDims) + ");\n }\n ";
}
if (inputInfo.shapeInfo.isUniform) {
// GLSL ES performs no implicit int/float conversion. The uniform sampler
// compares `index` against an int loop counter (`i == index`), so `index`
// must be an int, and the int `depth3` must be converted before it is added
// to the float dot product. This mirrors the rank-4 and rank-6 samplers,
// which also declare `int index = round(...)`.
return "\n float " + funcName + "(int row, int col, int depth, int depth2, int depth3) {\n int index = round(dot(\n vec4(row, col, depth, depth2),\n vec4(" + stride0 + ", " + stride1 + ", " + stride2 + ", " + stride3 + ")) +\n float(depth3));\n " + getUniformSampler(inputInfo) + "\n }\n ";
}
var flatOffset = inputInfo.shapeInfo.flatOffset;
var texShape = inputInfo.shapeInfo.texShape;
var texNumR = texShape[0];
var texNumC = texShape[1];
// Fast path: one logical `row` per texture row.
if (texNumC === stride0 && flatOffset == null) {
return "\n float " + funcName + "(int row, int col, int depth, int depth2, int depth3) {\n int texR = row;\n float texC = dot(vec4(col, depth, depth2, depth3),\n vec4(" + stride1 + ", " + stride2 + ", " + stride3 + ", 1));\n vec2 uv = (vec2(texC, texR) + halfCR) /\n vec2(" + texNumC + ".0, " + texNumR + ".0);\n return sampleTexture(" + texName + ", uv);\n }\n ";
}
// Fast path: the innermost dimension spans exactly one texture row.
if (texNumC === stride3 && flatOffset == null) {
return "\n float " + funcName + "(int row, int col, int depth, int depth2, int depth3) {\n float texR = dot(\n vec4(row, col, depth, depth2),\n vec4(" + shape[1] * shape[2] * shape[3] + ",\n " + shape[2] * shape[3] + ", " + shape[3] + ", 1));\n int texC = depth3;\n vec2 uv = (vec2(texC, texR) + halfCR) /\n vec2(" + texNumC + ".0, " + texNumR + ".0);\n return sampleTexture(" + texName + ", uv);\n }\n ";
}
// General path: flat index plus the flat-offset uniform, converted to a UV.
var offset = getFlatOffsetUniformName(texName);
return "\n float " + funcName + "(int row, int col, int depth, int depth2, int depth3) {\n // Explicitly use integer operations as dot() only works on floats.\n int index = row * " + stride0 + " + col * " + stride1 + " + depth * " + stride2 + " +\n depth2 * " + stride3 + " + depth3 + " + offset + ";\n vec2 uv = uvFromFlat(" + texNumR + ", " + texNumC + ", index);\n return sampleTexture(" + texName + ", uv);\n }\n ";
}
/**
 * Emits GLSL for `float get<Name>(int row, int col, int depth, int depth2,
 * int depth3, int depth4)`, which reads one element of the unpacked rank-6
 * input described by `inputInfo`.
 *
 * Strategies, tried in order:
 *  1. tf.util.squeezeShape drops size-1 dimensions: emit the lower-rank
 *     sampler plus a forwarding wrapper with the rank-6 signature. This check
 *     runs before the strides are computed.
 *  2. Uniform input: integer row-major flat index (two dot products, since
 *     GLSL vectors hold at most 4 components) resolved by getUniformSampler.
 *  3. Texture width equals stride0 and there is no flat offset: `row` is the
 *     texture row.
 *  4. Texture width equals stride4 (the last dimension) and there is no flat
 *     offset: `depth4` is the texture column.
 *  5. Otherwise: integer flat index plus the flat-offset uniform, mapped to a
 *     UV by the `uvFromFlat` GLSL helper.
 *
 * @param {Object} inputInfo - reads `name` and `shapeInfo.logicalShape`,
 *     `shapeInfo.isUniform`, `shapeInfo.texShape`, `shapeInfo.flatOffset`.
 * @returns {string} GLSL source for the sampler function.
 */
function getSampler6D(inputInfo) {
var shape = inputInfo.shapeInfo.logicalShape;
var texName = inputInfo.name;
var funcName = "get" + texName.charAt(0).toUpperCase() + texName.slice(1);
var _a = tf.util.squeezeShape(shape), newShape = _a.newShape, keptDims = _a.keptDims;
if (newShape.length < shape.length) {
var newInputInfo = squeezeInputInfo(inputInfo, newShape);
var params = ["row", "col", "depth", "depth2", "depth3", "depth4"];
return "\n " + getSamplerFromInInfo(newInputInfo) + "\n float " + funcName + "(int row, int col, int depth,\n int depth2, int depth3, int depth4) {\n return " + funcName + "(" + getSqueezedParams(params, keptDims) + ");\n }\n ";
}
// Row-major strides of the logical shape.
var stride4 = shape[5];
var stride3 = shape[4] * stride4;
var stride2 = shape[3] * stride3;
var stride1 = shape[2] * stride2;
var stride0 = shape[1] * stride1;
if (inputInfo.shapeInfo.isUniform) {
return "\n float " + funcName + "(int row, int col, int depth,\n int depth2, int depth3, int depth4) {\n int index = round(dot(\n vec4(row, col, depth, depth2),\n vec4(" + stride0 + ", " + stride1 + ", " + stride2 + ", " + stride3 + ")) +\n dot(\n vec2(depth3, depth4),\n vec2(" + stride4 + ", 1)));\n " + getUniformSampler(inputInfo) + "\n }\n ";
}
var flatOffset = inputInfo.shapeInfo.flatOffset;
var texShape = inputInfo.shapeInfo.texShape;
var texNumR = texShape[0];
var texNumC = texShape[1];
// Fast path: one logical `row` per texture row.
if (texNumC === stride0 && flatOffset == null) {
return "\n float " + funcName + "(int row, int col, int depth,\n int depth2, int depth3, int depth4) {\n int texR = row;\n float texC = dot(vec4(col, depth, depth2, depth3),\n vec4(" + stride1 + ", " + stride2 + ", " + stride3 + ", " + stride4 + ")) +\n float(depth4);\n vec2 uv = (vec2(texC, texR) + halfCR) /\n vec2(" + texNumC + ".0, " + texNumR + ".0);\n return sampleTexture(" + texName + ", uv);\n }\n ";
}
// Fast path: the innermost dimension spans exactly one texture row.
if (texNumC === stride4 && flatOffset == null) {
return "\n float " + funcName + "(int row, int col, int depth,\n int depth2, int depth3, int depth4) {\n float texR = dot(vec4(row, col, depth, depth2),\n vec4(" + shape[1] * shape[2] * shape[3] * shape[4] + ",\n " + shape[2] * shape[3] * shape[4] + ",\n " + shape[3] * shape[4] + ",\n " + shape[4] + ")) + float(depth3);\n int texC = depth4;\n vec2 uv = (vec2(texC, texR) + halfCR) /\n vec2(" + texNumC + ".0, " + texNumR + ".0);\n return sampleTexture(" + texName + ", uv);\n }\n ";
}
// General path: flat index plus the flat-offset uniform, converted to a UV.
var offset = getFlatOffsetUniformName(texName);
return "\n float " + funcName + "(int row, int col, int depth,\n int depth2, int depth3, int depth4) {\n // Explicitly use integer operations as dot() only works on floats.\n int index = row * " + stride0 + " + col * " + stride1 + " + depth * " + stride2 + " +\n depth2 * " + stride3 + " + depth3 * " + stride4 + " + depth4 + " + offset + ";\n vec2 uv = uvFromFlat(" + texNumR + ", " + texNumC + ", index);\n return sampleTexture(" + texName + ", uv);\n }\n ";
}
/**
 * Emits the GLSL statements that return element `index` of a uniform input.
 * The generated code reads an `index` variable that the enclosing sampler
 * function must declare. Inputs with fewer than two elements are returned
 * directly, without indexing.
 *
 * @param {Object} inputInfo - reads `name` and `shapeInfo.logicalShape`.
 * @returns {string} GLSL statements (not a complete function).
 */
function getUniformSampler(inputInfo) {
  const uniformName = inputInfo.name;
  const elementCount = tf.util.sizeFromShape(inputInfo.shapeInfo.logicalShape);
  // Larger uniforms are read via a linear scan that compares each position with `index`.
  return elementCount < 2
    ? `return ${uniformName};`
    : `\n for (int i = 0; i < ${elementCount}; i++) {\n if (i == index) {\n return ${uniformName}[i];\n }\n }\n `;
}
/**
 * Emits GLSL for `vec4 get<Name>AtOutCoords()`. It reads the packed input
 * texel that corresponds to the current output coordinates, honouring
 * broadcasting.
 *
 * Broadcast dimensions of the output coordinates are zeroed before the
 * lookup. The fetched vec4 is then re-swizzled so that a broadcast row or
 * column (or a scalar / rank-1 input) fills all four packed channels.
 *
 * @param {Object} inputInfo - reads `name` and `shapeInfo.logicalShape`.
 * @param {Object} outShapeInfo - reads `logicalShape`.
 * @returns {string} GLSL source for the accessor.
 */
function getPackedSamplerAtOutputCoords(inputInfo, outShapeInfo) {
  const texName = inputInfo.name;
  const texFuncSnippet = texName.charAt(0).toUpperCase() + texName.slice(1);
  const funcName = `get${texFuncSnippet}AtOutCoords`;
  const inShape = inputInfo.shapeInfo.logicalShape;
  const outShape = outShapeInfo.logicalShape;
  const inRank = inShape.length;
  const outRank = outShape.length;
  const broadcastDims = getBroadcastDims(inShape, outShape);
  const type = getCoordsDataType(outRank);
  const rankDiff = outRank - inRank;
  const fields = ["x", "y", "z", "w", "u", "v"];

  const zeroBroadcastCoords = () => {
    if (inRank === 0) {
      return "";
    }
    if (outRank < 2 && broadcastDims.length >= 1) {
      return "coords = 0;";
    }
    return broadcastDims.map((dim) => `coords.${fields[dim + rankDiff]} = 0;`).join("\n");
  };
  const coordsSnippet = zeroBroadcastCoords();

  const unpackedCoordsSnippet = outRank < 2 && inRank > 0
    ? "coords"
    : inShape.map((_, i) => `coords.${fields[i + rankDiff]}`).join(", ");

  const isInputScalar = tf.util.sizeFromShape(inShape) === 1;
  const isOutputScalar = tf.util.sizeFromShape(outShape) === 1;

  let output = "return outputValue;";
  if (inRank === 1 && !isInputScalar && !isOutputScalar) {
    output = "\n return vec4(outputValue.xy, outputValue.xy);\n ";
  } else if (isInputScalar && !isOutputScalar) {
    output = outRank === 1
      ? "\n return vec4(outputValue.x, outputValue.x, 0., 0.);\n "
      : "\n return vec4(outputValue.x);\n ";
  } else if (broadcastDims.length) {
    const rowsBroadcast = broadcastDims.includes(inRank - 2);
    const colsBroadcast = broadcastDims.includes(inRank - 1);
    if (rowsBroadcast && colsBroadcast) {
      output = "return vec4(outputValue.x);";
    } else if (rowsBroadcast) {
      output = "return vec4(outputValue.x, outputValue.y, outputValue.x, outputValue.y);";
    } else if (colsBroadcast) {
      output = "return vec4(outputValue.xx, outputValue.zz);";
    }
  }

  return `\n vec4 ${funcName}() {\n ${type} coords = getOutputCoords();\n ${coordsSnippet}\n vec4 outputValue = get${texFuncSnippet}(${unpackedCoordsSnippet});\n ${output}\n }\n `;
}
/**
 * Emits GLSL for `float get<Name>AtOutCoords()`. It reads the unpacked input
 * element that corresponds to the current output coordinates, honouring
 * broadcasting.
 *
 * If the input is a non-uniform texture with the same rank and texture shape
 * as the output and no flat offset, it is sampled directly at `resultUV`.
 * Otherwise broadcast dimensions are zeroed and the input's own sampler is
 * called with the trailing output coordinates.
 *
 * @param {Object} inputInfo - reads `name` and `shapeInfo.logicalShape`,
 *     `shapeInfo.texShape`, `shapeInfo.isUniform`, `shapeInfo.flatOffset`.
 * @param {Object} outShapeInfo - reads `logicalShape` and `texShape`.
 * @returns {string} GLSL source for the accessor.
 */
function getSamplerAtOutputCoords(inputInfo, outShapeInfo) {
  const texName = inputInfo.name;
  const inShapeInfo = inputInfo.shapeInfo;
  const texFuncSnippet = texName.charAt(0).toUpperCase() + texName.slice(1);
  const funcName = `get${texFuncSnippet}AtOutCoords`;
  const inShape = inShapeInfo.logicalShape;
  const outShape = outShapeInfo.logicalShape;
  const inRank = inShape.length;
  const outRank = outShape.length;

  // Same layout as the output: the output UV addresses the input directly.
  const sameTextureLayout = !inShapeInfo.isUniform &&
    inRank === outRank &&
    inShapeInfo.flatOffset == null &&
    tf.util.arraysEqual(inShapeInfo.texShape, outShapeInfo.texShape);
  if (sameTextureLayout) {
    return `\n float ${funcName}() {\n return sampleTexture(${texName}, resultUV);\n }\n `;
  }

  const type = getCoordsDataType(outRank);
  const broadcastDims = getBroadcastDims(inShape, outShape);
  const rankDiff = outRank - inRank;
  const fields = ["x", "y", "z", "w", "u", "v"];

  const zeroBroadcastCoords = () => {
    if (inRank === 0) {
      return "";
    }
    if (outRank < 2 && broadcastDims.length >= 1) {
      return "coords = 0;";
    }
    return broadcastDims.map((dim) => `coords.${fields[dim + rankDiff]} = 0;`).join("\n");
  };
  const coordsSnippet = zeroBroadcastCoords();

  const unpackedCoordsSnippet = outRank < 2 && inRank > 0
    ? "coords"
    : inShape.map((_, i) => `coords.${fields[i + rankDiff]}`).join(", ");

  return `\n float ${funcName}() {\n ${type} coords = getOutputCoords();\n ${coordsSnippet}\n return get${texFuncSnippet}(${unpackedCoordsSnippet});\n }\n `;
}
/**
 * Maps a tensor rank to the GLSL type used to hold its coordinates.
 *
 * @param {number} rank - tensor rank; ranks <= 1 use a plain `int`.
 * @returns {string} "int" or "ivec2" … "ivec6".
 * @throws {Error} for ranks above 6 (or any other unsupported value).
 */
function getCoordsDataType(rank) {
  if (rank <= 1) {
    return "int";
  }
  // `switch` compares with ===, so only the numbers 2..6 are accepted here.
  switch (rank) {
    case 2:
    case 3:
    case 4:
    case 5:
    case 6:
      return `ivec${rank}`;
    default:
      throw new Error(`GPU for rank ${rank} is not yet supported`);
  }
}
/**
 * Returns a deep copy of `inInfo` whose `shapeInfo.logicalShape` is replaced
 * by `squeezedShape`. The caller's object is left untouched.
 *
 * The copy is a JSON round-trip, so properties whose value is `undefined`
 * (and any non-JSON values) are not carried over.
 *
 * @param {Object} inInfo - input description with a `shapeInfo` object.
 * @param {number[]} squeezedShape - the new logical shape.
 * @returns {Object} the modified copy.
 */
function squeezeInputInfo(inInfo, squeezedShape) {
  const serialized = JSON.stringify(inInfo);
  const clone = JSON.parse(serialized);
  clone.shapeInfo.logicalShape = squeezedShape;
  return clone;
}
/**
 * Selects the coordinate names at `keptDims` and joins them as a GLSL
 * argument list.
 *
 * @param {string[]} params - coordinate names of the full-rank signature.
 * @param {number[]} keptDims - indices of the dimensions that survive squeezing.
 * @returns {string} e.g. "row, depth".
 */
function getSqueezedParams(params, keptDims) {
  const selected = [];
  for (const dim of keptDims) {
    selected.push(params[dim]);
  }
  return selected.join(", ");
}
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * GPU program for packed argMax / argMin along the last axis.
 *
 * Constructor arguments:
 *  - shape: input shape; the assert below requires rank > 2.
 *  - windowSize: number of consecutive last-axis elements each output value
 *    scans.
 *  - op: "max" selects greaterThan, any other value selects lessThan.
 *  - firstPass: when false, an extra input "bestIndicesA" supplies candidate
 *    indices from a previous pass, which are used to fetch values from A.
 *
 * The output shape is `shape` without its last axis, plus a trailing axis of
 * ceil(lastDim / windowSize) when that is greater than 1. Each packed output
 * texel holds four results, at (row, col), (row, col+1), (row+1, col) and
 * (row+1, col+1), in r, g, b, a order. NaN candidates never replace the
 * current best value.
 */
var ArgMinMaxPackedProgram = function() {
function ArgMinMaxPackedProgram2(shape, windowSize, op, firstPass) {
this.variableNames = ["A"];
this.packedInputs = true;
this.packedOutput = true;
tf.util.assert(shape.length > 2, function() {
return "Packed arg" + (op.charAt(0).toUpperCase() + op.slice(1)) + " supports only inputs with rank above 2.";
});
var inSize = shape[shape.length - 1];
var outSize = Math.ceil(inSize / windowSize);
this.outputShape = shape.slice(0, -1);
if (outSize > 1) {
this.outputShape.push(outSize);
}
if (!firstPass) {
this.variableNames.push("bestIndicesA");
}
var outShape = this.outputShape;
var rank = outShape.length;
var dtype = getCoordsDataType(rank);
var coords2 = getChannels("coords", rank);
// Build the four source locations (R, G, A, B) by stepping the last two output
// coordinates; when outSize === 1 a trailing 0 is appended for the reduced axis.
var sourceLocSetup;
var sourceRank;
if (outSize === 1) {
sourceRank = rank + 1;
var sourceLocDType = getCoordsDataType(sourceRank);
sourceLocSetup = "\n " + sourceLocDType + " sourceLocR = " + sourceLocDType + "(" + coords2.join() + ", 0);\n ++" + coords2[rank - 1] + ";\n " + sourceLocDType + " sourceLocG = " + sourceLocDType + "(" + coords2.join() + ", 0);\n ++" + coords2[rank - 2] + ";\n " + sourceLocDType + " sourceLocA = " + sourceLocDType + "(" + coords2.join() + ", 0);\n --" + coords2[rank - 1] + ";\n " + sourceLocDType + " sourceLocB = " + sourceLocDType + "(" + coords2.join() + ", 0);\n --" + coords2[rank - 2] + ";";
} else {
sourceRank = rank;
sourceLocSetup = "\n " + dtype + " sourceLocR = coords;\n ++" + coords2[rank - 1] + ";\n " + dtype + " sourceLocG = coords;\n ++" + coords2[rank - 2] + ";\n " + dtype + " sourceLocA = coords;\n --" + coords2[rank - 1] + ";\n " + dtype + " sourceLocB = coords;\n --" + coords2[rank - 2] + ";";
}
var channels = ["x", "y", "z", "w", "u", "v"].slice(0, sourceRank);
var inChannel = "." + channels[sourceRank - 1];
var intChannels = channels.map(function(x) {
return "int " + x;
});
// For each channel: the leading source coordinates plus the scanned last-axis index (inIdx).
var srcRCoords = getChannels("sourceLocR", sourceRank - 1).concat("inIdx.r");
var srcGCoords = getChannels("sourceLocG", sourceRank - 1).concat("inIdx.g");
var srcBCoords = getChannels("sourceLocB", sourceRank - 1).concat("inIdx.b");
var srcACoords = getChannels("sourceLocA", sourceRank - 1).concat("inIdx.a");
var compOp = op === "max" ? "greaterThan" : "lessThan";
// In later passes, the candidate index is looked up in bestIndicesA before A is read.
var fetchCandidateIdx = firstPass ? "" : "\n inIdx = round(vec4(getBestIndicesAChannel(" + srcRCoords.join() + "),\n getBestIndicesAChannel(" + srcGCoords.join() + "),\n getBestIndicesAChannel(" + srcBCoords.join() + "),\n getBestIndicesAChannel(" + srcACoords.join() + ")));";
// Channels past the last output row/column are filled with 0 instead of being read.
var fetchValue = "vec4(\n getAChannel(" + srcRCoords.join() + "),\n hasNextCol ? getAChannel(" + srcGCoords.join() + ") : 0.,\n hasNextRow ? getAChannel(" + srcBCoords.join() + ") : 0.,\n hasNextRow && hasNextCol ? getAChannel(" + srcACoords.join() + ") : 0.)";
var getBestIndicesAChannelSnippet = firstPass ? "" : "\n float getBestIndicesAChannel(" + intChannels.join() + ") {\n return getChannel(getBestIndicesA(" + channels.join() + "),\n vec2(" + channels.slice(-2).join() + "));\n }";
// NOTE(review): bestValue/bestIndex are initialised from A at srcIdx before fetchCandidateIdx runs.
this.userCode = "\n float getAChannel(" + intChannels.join() + ") {\n return getChannel(getA(" + channels.join() + "),\n vec2(" + channels.slice(-2).join() + "));\n }\n " + getBestIndicesAChannelSnippet + "\n void main() {\n " + dtype + " coords = getOutputCoords();\n bool hasNextCol = " + coords2[rank - 1] + " < " + (outShape[rank - 1] - 1) + ";\n bool hasNextRow = " + coords2[rank - 2] + " < " + (outShape[rank - 2] - 1) + ";\n " + sourceLocSetup + "\n ivec4 srcIdx = ivec4(sourceLocR" + inChannel + ", sourceLocG" + inChannel + ",\n sourceLocB" + inChannel + ", sourceLocA" + inChannel + ") * " + windowSize + ";\n ivec4 inIdx = srcIdx;\n vec4 bestIndex = vec4(inIdx);\n vec4 bestValue = " + fetchValue + ";\n\n for (int i = 0; i < " + windowSize + "; i++) {\n inIdx = srcIdx;\n " + fetchCandidateIdx + "\n vec4 candidate = " + fetchValue + ";\n bvec4 nan = isnan(candidate);\n bvec4 replace = bvec4(\n vec4(" + compOp + "(candidate, bestValue)) * (vec4(1.0) - vec4(nan)));\n\n bestValue = vec4(replace.x ? candidate.x : bestValue.x,\n replace.y ? candidate.y : bestValue.y,\n replace.z ? candidate.z : bestValue.z,\n replace.w ? candidate.w : bestValue.w);\n bestIndex = mix(bestIndex, vec4(inIdx), vec4(replace));\n srcIdx++;\n }\n setOutput(bestIndex);\n }\n ";
}
return ArgMinMaxPackedProgram2;
}();
/**
* @license
* Copyright 2017 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * GPU program for the gradient of 2-D average pooling.
 *
 * The output has shape `convInfo.inShape`. For each output (input-gradient)
 * position, the shader walks the dilated filter window. It keeps only the dy
 * positions whose row/column quotient by the stride is an integer inside
 * [0, outHeight) x [0, outWidth), and sums dy * 1/(filterHeight * filterWidth).
 */
var AvgPool2DBackpropProgram = function() {
function AvgPool2DBackpropProgram2(convInfo) {
this.variableNames = ["dy"];
this.outputShape = convInfo.inShape;
var filterHeight = convInfo.filterHeight;
var filterWidth = convInfo.filterWidth;
var strideHeight = convInfo.strideHeight;
var strideWidth = convInfo.strideWidth;
var dilationHeight = convInfo.dilationHeight;
var dilationWidth = convInfo.dilationWidth;
var effectiveFilterHeight = convInfo.effectiveFilterHeight;
var effectiveFilterWidth = convInfo.effectiveFilterWidth;
// Padding as seen from the transposed (backward) direction.
var padTop = effectiveFilterHeight - 1 - convInfo.padInfo.top;
var padLeft = effectiveFilterWidth - 1 - convInfo.padInfo.left;
// Every dy element contributes equally: 1 / (window size, undilated).
var avgMultiplier = 1 / (filterHeight * filterWidth);
this.userCode = "\n const ivec2 pads = ivec2(" + padTop + ", " + padLeft + ");\n const float avgMultiplier = float(" + avgMultiplier + ");\n\n void main() {\n ivec4 coords = getOutputCoords();\n int b = coords[0];\n int d = coords[3];\n\n ivec2 dyRCCorner = coords.yz - pads;\n int dyRCorner = dyRCCorner.x;\n int dyCCorner = dyRCCorner.y;\n\n // Convolve dy(?, ?, d) with pos mask(:, :, d) to get dx(xR, xC, d).\n // ? = to be determined. : = across all values in that axis.\n float dotProd = 0.0;\n for (int wR = 0; wR < " + effectiveFilterHeight + ";\n wR += " + dilationHeight + ") {\n float dyR = float(dyRCorner + wR) / " + strideHeight + ".0;\n\n if (dyR < 0.0 || dyR >= " + convInfo.outHeight + ".0 || fract(dyR) > 0.0) {\n continue;\n }\n int idyR = int(dyR);\n\n for (int wC = 0; wC < " + effectiveFilterWidth + ";\n wC+= " + dilationWidth + ") {\n float dyC = float(dyCCorner + wC) / " + strideWidth + ".0;\n\n if (dyC < 0.0 || dyC >= " + convInfo.outWidth + ".0 ||\n fract(dyC) > 0.0) {\n continue;\n }\n int idyC = int(dyC);\n\n float dyValue = getDy(b, idyR, idyC, d);\n\n dotProd += dyValue * avgMultiplier;\n }\n }\n setOutput(dotProd);\n }\n ";
}
return AvgPool2DBackpropProgram2;
}();
/**
 * GPU program for the gradient of 3-D average pooling.
 *
 * This is the 3-D analogue of AvgPool2DBackpropProgram. For each output
 * (input-gradient) position with coordinates (batch, d, r, c, ch), it sums
 * dy * 1/(filterDepth * filterHeight * filterWidth) over the dy positions
 * reachable through the dilated window whose stride quotient is an in-range
 * integer.
 */
var AvgPool3DBackpropProgram = function() {
function AvgPool3DBackpropProgram2(convInfo) {
this.variableNames = ["dy"];
this.outputShape = convInfo.inShape;
var filterDepth = convInfo.filterDepth;
var filterHeight = convInfo.filterHeight;
var filterWidth = convInfo.filterWidth;
var strideDepth = convInfo.strideDepth;
var strideHeight = convInfo.strideHeight;
var strideWidth = convInfo.strideWidth;
var dilationDepth = convInfo.dilationDepth;
var dilationHeight = convInfo.dilationHeight;
var dilationWidth = convInfo.dilationWidth;
var effectiveFilterDepth = convInfo.effectiveFilterDepth;
var effectiveFilterHeight = convInfo.effectiveFilterHeight;
var effectiveFilterWidth = convInfo.effectiveFilterWidth;
// Padding as seen from the transposed (backward) direction.
var padFront = effectiveFilterDepth - 1 - convInfo.padInfo.front;
var padTop = effectiveFilterHeight - 1 - convInfo.padInfo.top;
var padLeft = effectiveFilterWidth - 1 - convInfo.padInfo.left;
// Every dy element contributes equally: 1 / (window volume, undilated).
var avgMultiplier = 1 / (filterDepth * filterHeight * filterWidth);
this.userCode = "\n const ivec3 pads = ivec3(" + padFront + ", " + padTop + ", " + padLeft + ");\n const float avgMultiplier = float(" + avgMultiplier + ");\n\n void main() {\n ivec5 coords = getOutputCoords();\n int batch = coords.x;\n int ch = coords.u;\n\n ivec3 dyCorner = ivec3(coords.y, coords.z, coords.w) - pads;\n int dyDCorner = dyCorner.x;\n int dyRCorner = dyCorner.y;\n int dyCCorner = dyCorner.z;\n\n // Convolve dy(?, ?, ?, d) with pos mask(:, :, :, ch) to get\n // dx(xD, xR, xC, ch).\n // ? = to be determined. : = across all values in that axis.\n float dotProd = 0.0;\n\n for (int wD = 0; wD < " + effectiveFilterDepth + ";\n wD += " + dilationDepth + ") {\n float dyD = float(dyDCorner + wD) / " + strideDepth + ".0;\n\n if (dyD < 0.0 || dyD >= " + convInfo.outDepth + ".0 || fract(dyD) > 0.0) {\n continue;\n }\n int idyD = int(dyD);\n\n for (int wR = 0; wR < " + effectiveFilterHeight + ";\n wR += " + dilationHeight + ") {\n float dyR = float(dyRCorner + wR) / " + strideHeight + ".0;\n\n if (dyR < 0.0 || dyR >= " + convInfo.outHeight + ".0 ||\n fract(dyR) > 0.0) {\n continue;\n }\n int idyR = int(dyR);\n\n for (int wC = 0; wC < " + effectiveFilterWidth + ";\n wC += " + dilationWidth + ") {\n float dyC = float(dyCCorner + wC) / " + strideWidth + ".0;\n\n if (dyC < 0.0 || dyC >= " + convInfo.outWidth + ".0 ||\n fract(dyC) > 0.0) {\n continue;\n }\n int idyC = int(dyC);\n\n float dyValue = getDy(batch, idyD, idyR, idyC, ch);\n\n dotProd += dyValue * avgMultiplier;\n }\n }\n }\n setOutput(dotProd);\n }\n ";
}
return AvgPool3DBackpropProgram2;
}();
/**
* @license
* Copyright 2017 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/*
 * GLSL bodies for `float binaryOperation(float a, float b)` (one value per
 * texel). Presumably these are passed as `op` to BinaryOpProgram. `NAN` and
 * `idiv` come from shader code not visible here.
 */
// NaN propagation: return whichever operand is NaN, checking `a` first.
var CHECK_NAN_SNIPPET = "\n if (isnan(a)) return a;\n if (isnan(b)) return b;\n";
// Integer division of the rounded operands; NAN when the divisor rounds to 0.
var INT_DIV = "\n float s = sign(a) * sign(b);\n int ia = round(a);\n int ib = round(b);\n if (ib != 0) {\n // Windows (D3D) wants guaranteed non-zero int division at compile-time.\n return float(idiv(ia, ib, s));\n } else {\n return NAN;\n }\n";
// a^b: NAN for a negative base with a non-integral exponent, 1 when b == 0,
// and the sign of a is kept when round(mod(b, 2.0)) == 1 (odd exponents).
var POW = "\nif(a < 0.0 && floor(b) < b){\n return NAN;\n}\nif (b == 0.0) {\n return 1.0;\n}\nreturn (round(mod(b, 2.0)) != 1) ?\n pow(abs(a), b) : sign(a) * pow(abs(a), b);\n";
var EQUAL = "return float(a == b);";
var LESS = "return float(a < b);";
var LESS_EQUAL = "return float(a <= b);";
var GREATER = "return float(a > b);";
var GREATER_EQUAL = "return float(a >= b);";
// Logical ops treat any value >= 1.0 as true.
var LOGICAL_AND = "return float(a >= 1.0 && b >= 1.0);";
var LOGICAL_OR = "return float(a >= 1.0 || b >= 1.0);";
var MAX = CHECK_NAN_SNIPPET + "\n return max(a, b);\n";
var MIN = CHECK_NAN_SNIPPET + "\n return min(a, b);\n";
var MOD = "if (b == 0.0) return NAN;\n return mod(a, b);";
// ELU gradient, presumably with a = dy and b = y (the ELU output): dy where
// y >= 0 and dy * (y + 1) otherwise. The threshold must be 0 to match the
// packed ELU_DER$1 (greaterThanEqual(b, vec4(0.))). A threshold of 1.0 would
// wrongly scale gradients for outputs in [0, 1).
var ELU_DER = "return (b >= 0.0) ? a : a * (b + 1.0);";
var PRELU = "return (a < 0.) ? b * a : a;";
/**
 * Unpacked element-wise binary GPU program.
 *
 * `op` is the GLSL body of `float binaryOperation(float a, float b)`. The two
 * inputs "A" and "B" are read at the (broadcast) output coordinates. The
 * output shape comes from tf.backend_util.assertAndGetBroadcastShape.
 */
var BinaryOpProgram = (() => {
  function BinaryOpProgram2(op, aShape, bShape) {
    this.variableNames = ["A", "B"];
    this.outputShape = tf.backend_util.assertAndGetBroadcastShape(aShape, bShape);
    this.userCode = `\n float binaryOperation(float a, float b) {\n ${op}\n }\n\n void main() {\n float a = getAAtOutCoords();\n float b = getBAtOutCoords();\n setOutput(binaryOperation(a, b));\n }\n `;
  }
  return BinaryOpProgram2;
})();
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/*
 * GLSL bodies for `vec4 binaryOperation(vec4 a, vec4 b)` (four packed values
 * per texel). Presumably these are passed as `op` to BinaryOpPackedProgram.
 * `NAN` and `idiv` come from shader code not visible here.
 */
// Overwrites each channel of `result` with NAN where the matching `isNaN`
// channel is > 0; the including snippet must declare both vec4s.
var CHECK_NAN_SNIPPET$1 = "\n result.r = isNaN.r > 0. ? NAN : result.r;\n result.g = isNaN.g > 0. ? NAN : result.g;\n result.b = isNaN.b > 0. ? NAN : result.b;\n result.a = isNaN.a > 0. ? NAN : result.a;\n";
// Per-channel integer division of the rounded operands; channels whose divisor
// rounds to 0 are left at 0.
// NOTE(review): the unpacked INT_DIV returns NAN in that case; confirm the
// difference is intended.
var INT_DIV$1 = "\n ivec4 ia = round(a);\n ivec4 ib = round(b);\n bvec4 cond = notEqual(ib, ivec4(0));\n ivec4 result = ivec4(0);\n vec4 s = sign(a) * sign(b);\n\n // Windows (D3D) wants guaranteed non-zero int division at compile-time.\n if (cond[0]) {\n result[0] = idiv(ia[0], ib[0], s[0]);\n }\n if (cond[1]) {\n result[1] = idiv(ia[1], ib[1], s[1]);\n }\n if (cond[2]) {\n result[2] = idiv(ia[2], ib[2], s[2]);\n }\n if (cond[3]) {\n result[3] = idiv(ia[3], ib[3], s[3]);\n }\n return vec4(result);\n";
// Packed a^b: the sign of a is kept for odd exponents, b == 0 gives 1, and a
// negative base with a non-integral exponent gives NAN.
var POW$1 = "\n // isModRound1 has 1 for components with round(mod(b, 2.0)) == 1, 0 otherwise.\n vec4 isModRound1 = vec4(equal(round(mod(b, 2.0)), ivec4(1)));\n vec4 multiplier = sign(a) * isModRound1 + (vec4(1.0) - isModRound1);\n vec4 result = multiplier * pow(abs(a), b);\n\n // Ensure that a^0 = 1, including 0^0 = 1 as this correspond to TF and JS\n bvec4 isExpZero = equal(b, vec4(0.0));\n result.r = isExpZero.r ? 1.0 : result.r;\n result.g = isExpZero.g ? 1.0 : result.g;\n result.b = isExpZero.b ? 1.0 : result.b;\n result.a = isExpZero.a ? 1.0 : result.a;\n\n vec4 isNaN = vec4(lessThan(a, vec4(0.0))) * vec4(lessThan(floor(b), b));\n " + CHECK_NAN_SNIPPET$1 + "\n return result;\n";
// Branch-free PReLU: b * a where a < 0, otherwise a.
var PRELU$1 = "\n vec4 aLessThanZero = vec4(lessThan(a, vec4(0.)));\n return (aLessThanZero * (b * a)) + ((vec4(1.0) - aLessThanZero) * a);\n";
// Branch-free ELU gradient: a where b >= 0, otherwise a * (b + 1).
var ELU_DER$1 = "\n vec4 bGTEZero = vec4(greaterThanEqual(b, vec4(0.)));\n return (bGTEZero * a) + ((vec4(1.0) - bGTEZero) * (a * (b + vec4(1.0))));\n";
var EQUAL$1 = "\n return vec4(equal(a, b));\n";
var LESS$1 = "\n return vec4(lessThan(a, b));\n";
var LESS_EQUAL$1 = "\n return vec4(lessThanEqual(a, b));\n";
var GREATER$1 = "\n return vec4(greaterThan(a, b));\n";
var GREATER_EQUAL$1 = "\n return vec4(greaterThanEqual(a, b));\n";
// Logical ops treat any value >= 1.0 as true.
var LOGICAL_AND$1 = "\n return vec4(\n vec4(greaterThanEqual(a, vec4(1.0))) *\n vec4(greaterThanEqual(b, vec4(1.0))));\n";
var LOGICAL_OR$1 = "\n return min(\n vec4(greaterThanEqual(a, vec4(1.0))) +\n vec4(greaterThanEqual(b, vec4(1.0))),\n vec4(1.0));\n";
// max/min/mod yield NAN in channels where an operand is NaN (mod: where b == 0).
var MAX$1 = "\n vec4 result = vec4(max(a, b));\n vec4 isNaN = min(vec4(isnan(a)) + vec4(isnan(b)), vec4(1.0));\n " + CHECK_NAN_SNIPPET$1 + "\n return result;\n";
var MIN$1 = "\n vec4 result = vec4(min(a, b));\n vec4 isNaN = min(vec4(isnan(a)) + vec4(isnan(b)), vec4(1.0));\n " + CHECK_NAN_SNIPPET$1 + "\n return result;\n";
var MOD$1 = "\n vec4 result = mod(a, b);\n vec4 isNaN = vec4(equal(b, vec4(0.0)));\n " + CHECK_NAN_SNIPPET$1 + "\n return result;\n";
/**
 * Packed element-wise binary GPU program.
 *
 * `op` is the GLSL body of `vec4 binaryOperation(vec4 a, vec4 b)`. Both inputs
 * and the output are packed. When `checkOutOfBounds` is truthy, channels of
 * the result texel that fall outside the output shape are forced to 0.
 */
var BinaryOpPackedProgram = (() => {
  // Builds the GLSL that zeroes result channels lying past the end of `outShape`.
  function buildOutOfBoundsGuard(outShape) {
    const rank = outShape.length;
    if (rank === 0 || tf.util.sizeFromShape(outShape) === 1) {
      // Single-element output: keep only the first channel.
      return "\n result.y = 0.;\n result.z = 0.;\n result.w = 0.;\n ";
    }
    const coordsDecl = `\n ${getCoordsDataType(rank)} coords = getOutputCoords();\n `;
    if (rank === 1) {
      return `${coordsDecl}\n result.y = (coords + 1) >= ${outShape[0]} ? 0. : result.y;\n result.z = 0.;\n result.w = 0.;\n `;
    }
    const channels = getChannels("coords", rank);
    return `${coordsDecl}\n bool nextRowOutOfBounds =\n (${channels[rank - 2]} + 1) >= ${outShape[rank - 2]};\n bool nextColOutOfBounds =\n (${channels[rank - 1]} + 1) >= ${outShape[rank - 1]};\n result.y = nextColOutOfBounds ? 0. : result.y;\n result.z = nextRowOutOfBounds ? 0. : result.z;\n result.w = nextColOutOfBounds || nextRowOutOfBounds ? 0. : result.w;\n `;
  }

  function BinaryOpPackedProgram2(op, aShape, bShape, checkOutOfBounds) {
    this.variableNames = ["A", "B"];
    this.supportsBroadcasting = true;
    this.packedInputs = true;
    this.packedOutput = true;
    this.outputShape = tf.backend_util.assertAndGetBroadcastShape(aShape, bShape);
    const boundsGuard = checkOutOfBounds ? buildOutOfBoundsGuard(this.outputShape) : "";
    this.userCode = `\n vec4 binaryOperation(vec4 a, vec4 b) {\n ${op}\n }\n\n void main() {\n vec4 a = getAAtOutCoords();\n vec4 b = getBAtOutCoords();\n\n vec4 result = binaryOperation(a, b);\n ${boundsGuard}\n\n setOutput(result);\n }\n `;
  }
  return BinaryOpPackedProgram2;
})();
/**
* @license
* Copyright 2017 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var ClipProgram = function() {
  // The shader text is the same for every instance: the bounds are supplied
  // as uniforms instead of being baked into the source.
  var CLIP_SHADER = "\n uniform float minVal;\n uniform float maxVal;\n\n void main() {\n float value = getAAtOutCoords();\n if (isnan(value)) {\n setOutput(value);\n return;\n }\n\n setOutput(clamp(value, minVal, maxVal));\n }\n ";

  /**
   * Unpacked WebGL program that clamps every element of input A into
   * [minVal, maxVal]. NaN elements are written through unchanged.
   * @param {number[]} aShape - input shape, also used as the output shape.
   */
  function ClipProgram2(aShape) {
    this.variableNames = ["A"];
    this.outputShape = aShape;
    this.userCode = CLIP_SHADER;
  }

  /**
   * Returns a callback that uploads `min` and `max` into the `minVal` and
   * `maxVal` uniforms. The uniform locations are looked up the first time the
   * callback runs (while `minLoc` is null/undefined) and cached on the program.
   * @param {number} min - lower clip bound.
   * @param {number} max - upper clip bound.
   * @returns {function(Object, Object): void} setup callback taking
   *     (gpgpu, webGLProgram).
   */
  ClipProgram2.prototype.getCustomSetupFunc = function(min, max) {
    var program = this;
    return function(gpgpu, webGLProgram) {
      if (program.minLoc == null) {
        program.minLoc = gpgpu.getUniformLocationNoThrow(webGLProgram, "minVal");
        program.maxLoc = gpgpu.getUniformLocationNoThrow(webGLProgram, "maxVal");
      }
      var gl = gpgpu.gl;
      gl.uniform1f(program.minLoc, min);
      gl.uniform1f(program.maxLoc, max);
    };
  };

  return ClipProgram2;
}();
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var ClipPackedProgram = function() {
  /**
   * Packed (vec4 per texel) WebGL program that clamps every element of input A
   * into [minVal, maxVal]. NaN elements are written through unchanged.
   *
   * Lanes are handled independently, matching the per-element behaviour of the
   * unpacked ClipProgram. Previously a NaN in any lane short-circuited the whole
   * texel via `any(isnan(value))`, which left the other three elements
   * unclamped.
   * @param {number[]} aShape - input shape, also used as the output shape.
   */
  function ClipPackedProgram2(aShape) {
    this.variableNames = ["A"];
    this.packedInputs = true;
    this.packedOutput = true;
    this.outputShape = aShape;
    this.userCode = "\n uniform float minVal;\n uniform float maxVal;\n\n void main() {\n vec4 value = getAAtOutCoords();\n\n // Clamp all lanes, then restore NaN lanes one by one so a NaN does not\n // stop its neighbours in the texel from being clamped.\n vec4 result = clamp(value, vec4(minVal), vec4(maxVal));\n bvec4 isNaN = isnan(value);\n result.r = isNaN.r ? value.r : result.r;\n result.g = isNaN.g ? value.g : result.g;\n result.b = isNaN.b ? value.b : result.b;\n result.a = isNaN.a ? value.a : result.a;\n\n setOutput(result);\n }\n ";
  }

  /**
   * Returns a callback that uploads `min` and `max` into the `minVal` and
   * `maxVal` uniforms. The uniform locations are looked up the first time the
   * callback runs (while `minLoc` is null/undefined) and cached on the program.
   * @param {number} min - lower clip bound.
   * @param {number} max - upper clip bound.
   * @returns {function(Object, Object): void} setup callback taking
   *     (gpgpu, webGLProgram).
   */
  ClipPackedProgram2.prototype.getCustomSetupFunc = function(min, max) {
    var _this = this;
    return function(gpgpu, webGLProgram) {
      if (_this.minLoc == null) {
        _this.minLoc = gpgpu.getUniformLocationNoThrow(webGLProgram, "minVal");
        _this.maxLoc = gpgpu.getUniformLocationNoThrow(webGLProgram, "maxVal");
      }
      gpgpu.gl.uniform1f(_this.minLoc, min);
      gpgpu.gl.uniform1f(_this.maxLoc, max);
    };
  };

  return ClipPackedProgram2;
}();
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var ComplexAbsProgram = function() {
  // Constant shader text shared by all instances. It computes
  // |re + i*im| = max * sqrt(1 + (min/max)^2), which avoids underflow in
  // `length`, and returns 0 when both parts are 0.
  var COMPLEX_ABS_SHADER = "\n void main() {\n float re = abs(getRealAtOutCoords());\n float im = abs(getImagAtOutCoords());\n float mx = max(re, im);\n\n // sadly the length function in glsl is not underflow-safe\n // (at least not on Intel GPUs). So the safe solution is\n // to ensure underflow-safety in all cases.\n setOutput(\n mx == 0.0 ? 0.0 : mx * length(vec2(1, min(re, im)/mx))\n );\n }\n ";

  /**
   * Unpacked WebGL program producing the magnitude of a complex tensor
   * stored as separate `real` and `imag` inputs.
   * @param {number[]} shape - shape of the inputs and of the output.
   */
  function ComplexAbsProgram2(shape) {
    this.variableNames = ["real", "imag"];
    this.outputShape = shape;
    this.userCode = COMPLEX_ABS_SHADER;
  }

  return ComplexAbsProgram2;
}();
/**
* @license
* Copyright 2017 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var Conv2DDerFilterProgram = function() {
  /**
   * Unpacked WebGL program computing dW, the gradient of a 2D convolution with
   * respect to its filter. Each output element (wR, wC, d1, d2) sums x * dy
   * over the batch and over every output position (yR, yC) whose input tap
   * (xR, xC) lies inside the input; taps that land in padding are skipped.
   * The layout test is spliced in as a GLSL `true`/`false` literal, so
   * "channelsLast" vs channels-first is fixed when the shader is generated.
   * @param {Object} convInfo - convolution geometry: filterShape, batchSize,
   *     in/out heights and widths, strides, padInfo and dataFormat.
   */
  function Conv2DDerFilterProgram2(convInfo) {
    var { strideHeight, strideWidth, padInfo } = convInfo;
    var { top: padTop, left: padLeft } = padInfo;
    var isChannelsLast = convInfo.dataFormat === "channelsLast";
    this.variableNames = ["x", "dy"];
    this.outputShape = convInfo.filterShape;
    this.userCode = "\n void main() {\n ivec4 coords = getOutputCoords();\n int wR = coords.x;\n int wC = coords.y;\n int d1 = coords.z;\n int d2 = coords.w;\n\n // Convolve x(?, ?, d1) with dy(:, :, d2) to get dw(wR, wC, d1, d2).\n // ? = to be determined. : = across all values in that axis.\n float dotProd = 0.0;\n\n for (int b = 0; b < " + convInfo.batchSize + "; b++) {\n for (int yR = 0; yR < " + convInfo.outHeight + "; yR++) {\n int xR = wR + yR * " + strideHeight + " - " + padTop + ";\n\n if (xR < 0 || xR >= " + convInfo.inHeight + ") {\n continue;\n }\n\n for (int yC = 0; yC < " + convInfo.outWidth + "; yC++) {\n int xC = wC + yC * " + strideWidth + " - " + padLeft + ";\n\n if (xC < 0 || xC >= " + convInfo.inWidth + ") {\n continue;\n }\n\n if (" + isChannelsLast + ") {\n float dyValue = getDy(b, yR, yC, d2);\n float xValue = getX(b, xR, xC, d1);\n dotProd += (xValue * dyValue);\n } else {\n float dyValue = getDy(b, d2, yR, yC);\n float xValue = getX(b, d1, xR, xC);\n dotProd += (xValue * dyValue);\n }\n\n }\n }\n }\n setOutput(dotProd);\n }\n ";
  }

  return Conv2DDerFilterProgram2;
}();
var Conv2DDerInputProgram = function() {
  /**
   * Unpacked WebGL program computing dx, the gradient of a 2D convolution with
   * respect to its input. Each output element correlates dy with the spatially
   * flipped filter (wRPerm/wCPerm). dy positions that are negative, past the
   * output, or not on a stride boundary (fractional after division) are skipped.
   * @param {Object} convInfo - convolution geometry: inShape, filter sizes,
   *     strides, padInfo, out height/width, outChannels and dataFormat.
   */
  function Conv2DDerInputProgram2(convInfo) {
    this.variableNames = ["dy", "W"];
    this.outputShape = convInfo.inShape;
    var { filterHeight, filterWidth, strideHeight, strideWidth, padInfo } = convInfo;
    var isChannelsLast = convInfo.dataFormat === "channelsLast";
    // Correlating with the flipped filter needs "full" padding minus the
    // forward pass's padding.
    var padTop = filterHeight - 1 - padInfo.top;
    var padLeft = filterWidth - 1 - padInfo.left;
    // Indices of the row, column and channel axes within the 4D output coords.
    var [rowDim, colDim, channelDim] = isChannelsLast ? [1, 2, 3] : [2, 3, 1];
    this.userCode = "\n const ivec2 pads = ivec2(" + padTop + ", " + padLeft + ");\n\n void main() {\n ivec4 coords = getOutputCoords();\n int batch = coords[0];\n int d1 = coords[" + channelDim + "];\n\n ivec2 dyCorner = ivec2(coords[" + rowDim + "], coords[" + colDim + "]) - pads;\n int dyRCorner = dyCorner.x;\n int dyCCorner = dyCorner.y;\n\n // Convolve dy(?, ?, d2) with w(:, :, d1, d2) to compute dx(xR, xC, d1).\n // ? = to be determined. : = across all values in that axis.\n float dotProd = 0.0;\n for (int wR = 0; wR < " + filterHeight + "; wR++) {\n float dyR = float(dyRCorner + wR) / " + strideHeight + ".0;\n\n if (dyR < 0.0 || dyR >= " + convInfo.outHeight + ".0 || fract(dyR) > 0.0) {\n continue;\n }\n int idyR = int(dyR);\n\n int wRPerm = " + filterHeight + " - 1 - wR;\n\n for (int wC = 0; wC < " + filterWidth + "; wC++) {\n float dyC = float(dyCCorner + wC) / " + strideWidth + ".0;\n\n if (dyC < 0.0 || dyC >= " + convInfo.outWidth + ".0 ||\n fract(dyC) > 0.0) {\n continue;\n }\n int idyC = int(dyC);\n\n int wCPerm = " + filterWidth + " - 1 - wC;\n\n for (int d2 = 0; d2 < " + convInfo.outChannels + "; d2++) {\n\n if (" + isChannelsLast + ") {\n float xValue = getDy(batch, idyR, idyC, d2);\n float wValue = getW(wRPerm, wCPerm, d1, d2);\n dotProd += xValue * wValue;\n } else {\n float xValue = getDy(batch, d2, idyR, idyC);\n float wValue = getW(wRPerm, wCPerm, d1, d2);\n dotProd += xValue * wValue;\n }\n\n }\n }\n }\n setOutput(dotProd);\n }\n ";
  }

  return Conv2DDerInputProgram2;
}();
var Conv3DDerFilterProgram = function() {
  /**
   * Unpacked WebGL program computing dW for a 3D convolution. Each output
   * element (wF, wR, wC, d1, d2) sums x * dy over the batch and every output
   * position whose input tap (xF, xR, xC) lies inside the input.
   * @param {Object} convInfo - convolution geometry: filterShape, batchSize,
   *     in/out depths, heights and widths, strides and padInfo.
   */
  function Conv3DDerFilterProgram2(convInfo) {
    var { strideDepth, strideHeight, strideWidth } = convInfo;
    var { front: padFront, top: padTop, left: padLeft } = convInfo.padInfo;
    this.variableNames = ["x", "dy"];
    this.outputShape = convInfo.filterShape;
    this.userCode = "\n void main() {\n ivec5 coords = getOutputCoords();\n int wF = coords.x;\n int wR = coords.y;\n int wC = coords.z;\n int d1 = coords.w;\n int d2 = coords.u;\n\n float dotProd = 0.0;\n\n for (int b = 0; b < " + convInfo.batchSize + "; b++) {\n for (int yF = 0; yF < " + convInfo.outDepth + "; yF++) {\n int xF = wF + yF * " + strideDepth + " - " + padFront + ";\n\n if (xF < 0 || xF >= " + convInfo.inDepth + ") {\n continue;\n }\n\n for (int yR = 0; yR < " + convInfo.outHeight + "; yR++) {\n int xR = wR + yR * " + strideHeight + " - " + padTop + ";\n\n if (xR < 0 || xR >= " + convInfo.inHeight + ") {\n continue;\n }\n\n for (int yC = 0; yC < " + convInfo.outWidth + "; yC++) {\n int xC = wC + yC * " + strideWidth + " - " + padLeft + ";\n\n if (xC < 0 || xC >= " + convInfo.inWidth + ") {\n continue;\n }\n\n float dyValue = getDy(b, yF, yR, yC, d2);\n float xValue = getX(b, xF, xR, xC, d1);\n dotProd += (xValue * dyValue);\n }\n }\n }\n }\n setOutput(dotProd);\n }\n ";
  }

  return Conv3DDerFilterProgram2;
}();
var Conv3DDerInputProgram = function() {
  /**
   * Unpacked WebGL program computing dx for a 3D convolution by correlating dy
   * with the flipped filter (wFPerm/wRPerm/wCPerm). dy positions that are
   * negative, past the output, or not on a stride boundary are skipped.
   * @param {Object} convInfo - convolution geometry: inShape, filter sizes,
   *     strides, padInfo, out depth/height/width and outChannels.
   */
  function Conv3DDerInputProgram2(convInfo) {
    this.variableNames = ["dy", "W"];
    this.outputShape = convInfo.inShape;
    var {
      filterDepth, filterHeight, filterWidth,
      strideDepth, strideHeight, strideWidth,
      padInfo,
    } = convInfo;
    // "Full" padding minus the forward padding, for the flipped-filter correlation.
    var padFront = filterDepth - 1 - padInfo.front;
    var padTop = filterHeight - 1 - padInfo.top;
    var padLeft = filterWidth - 1 - padInfo.left;
    this.userCode = "\n const ivec3 pads = ivec3(" + padFront + ", " + padTop + ", " + padLeft + ");\n\n void main() {\n ivec5 coords = getOutputCoords();\n int batch = coords.x;\n int d1 = coords.u;\n\n\n ivec3 dyCorner = ivec3(coords.y, coords.z, coords.w) - pads;\n int dyFCorner = dyCorner.x;\n int dyRCorner = dyCorner.y;\n int dyCCorner = dyCorner.z;\n\n float dotProd = 0.0;\n for (int wF = 0; wF < " + filterDepth + "; wF++) {\n float dyF = float(dyFCorner + wF) / " + strideDepth + ".0;\n\n if (dyF < 0.0 || dyF >= " + convInfo.outDepth + ".0 || fract(dyF) > 0.0) {\n continue;\n }\n int idyF = int(dyF);\n\n int wFPerm = " + filterDepth + " - 1 - wF;\n\n for (int wR = 0; wR < " + filterHeight + "; wR++) {\n float dyR = float(dyRCorner + wR) / " + strideHeight + ".0;\n\n if (dyR < 0.0 || dyR >= " + convInfo.outHeight + ".0 ||\n fract(dyR) > 0.0) {\n continue;\n }\n int idyR = int(dyR);\n\n int wRPerm = " + filterHeight + " - 1 - wR;\n\n for (int wC = 0; wC < " + filterWidth + "; wC++) {\n float dyC = float(dyCCorner + wC) / " + strideWidth + ".0;\n\n if (dyC < 0.0 || dyC >= " + convInfo.outWidth + ".0 ||\n fract(dyC) > 0.0) {\n continue;\n }\n int idyC = int(dyC);\n\n int wCPerm = " + filterWidth + " - 1 - wC;\n\n for (int d2 = 0; d2 < " + convInfo.outChannels + "; d2++) {\n float xValue = getDy(batch, idyF, idyR, idyC, d2);\n float wValue = getW(wFPerm, wRPerm, wCPerm, d1, d2);\n dotProd += xValue * wValue;\n }\n }\n }\n }\n setOutput(dotProd);\n }\n ";
  }

  return Conv3DDerInputProgram2;
}();
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var DepthwiseConv2DDerFilterProgram = function() {
  /**
   * Unpacked WebGL program computing dW for a depthwise 2D convolution. Output
   * element (wR, wC, d1, dm) accumulates x[..., d1] * dy[..., d2] with
   * d2 = d1 * channelMul + dm, over the batch and all in-bounds output positions.
   * @param {Object} convInfo - convolution geometry: filterShape, strides,
   *     padInfo, batchSize, in/out heights and widths, in/out channel counts.
   */
  function DepthwiseConv2DDerFilterProgram2(convInfo) {
    var { strideHeight, strideWidth } = convInfo;
    var { top: padTop, left: padLeft } = convInfo.padInfo;
    // Output channels produced per input channel (the depth multiplier).
    var channelMul = convInfo.outChannels / convInfo.inChannels;
    this.variableNames = ["x", "dy"];
    this.outputShape = convInfo.filterShape;
    this.userCode = "\n void main() {\n ivec4 coords = getOutputCoords();\n int wR = coords.x;\n int wC = coords.y;\n int d1 = coords.z;\n int dm = coords.w;\n int d2 = d1 * " + channelMul + " + dm;\n\n float dotProd = 0.0;\n\n // TO DO: Vec4 over the batch size\n for (int b = 0; b < " + convInfo.batchSize + "; b++) {\n for (int yR = 0; yR < " + convInfo.outHeight + "; yR++) {\n int xR = wR + yR * " + strideHeight + " - " + padTop + ";\n\n if (xR < 0 || xR >= " + convInfo.inHeight + ") {\n continue;\n }\n\n for (int yC = 0; yC < " + convInfo.outWidth + "; yC++) {\n int xC = wC + yC * " + strideWidth + " - " + padLeft + ";\n\n if (xC < 0 || xC >= " + convInfo.inWidth + ") {\n continue;\n }\n\n float dyValue = getDy(b, yR, yC, d2);\n float xValue = getX(b, xR, xC, d1);\n dotProd += (xValue * dyValue);\n }\n }\n }\n setOutput(dotProd);\n }\n ";
  }

  return DepthwiseConv2DDerFilterProgram2;
}();
var DepthwiseConv2DDerInputProgram = function() {
  /**
   * Unpacked WebGL program computing dx for a depthwise 2D convolution. For
   * input channel d1 it correlates dy channels d1 * channelMul + dm with the
   * flipped filter slice W[:, :, d1, dm], skipping dy positions that are out of
   * range or not on a stride boundary.
   * @param {Object} convInfo - convolution geometry: inShape, filter sizes,
   *     strides, padInfo, out height/width, in/out channel counts.
   */
  function DepthwiseConv2DDerInputProgram2(convInfo) {
    this.variableNames = ["dy", "W"];
    this.outputShape = convInfo.inShape;
    var { filterHeight, filterWidth, strideHeight, strideWidth, padInfo } = convInfo;
    // "Full" padding minus the forward padding, for the flipped-filter correlation.
    var padTop = filterHeight - 1 - padInfo.top;
    var padLeft = filterWidth - 1 - padInfo.left;
    var channelMul = convInfo.outChannels / convInfo.inChannels;
    this.userCode = "\n const ivec2 pads = ivec2(" + padTop + ", " + padLeft + ");\n\n void main() {\n ivec4 coords = getOutputCoords();\n int batch = coords[0];\n int d1 = coords[3];\n ivec2 dyCorner = coords.yz - pads;\n int dyRCorner = dyCorner.x;\n int dyCCorner = dyCorner.y;\n\n float dotProd = 0.0;\n\n for (int wR = 0; wR < " + filterHeight + "; wR++) {\n float dyR = float(dyRCorner + wR) / " + strideHeight + ".0;\n\n if (dyR < 0.0 || dyR >= " + convInfo.outHeight + ".0 || fract(dyR) > 0.0) {\n continue;\n }\n int idyR = int(dyR);\n\n int wRPerm = " + filterHeight + " - 1 - wR;\n\n for (int wC = 0; wC < " + filterWidth + "; wC++) {\n float dyC = float(dyCCorner + wC) / " + strideWidth + ".0;\n\n if (dyC < 0.0 || dyC >= " + convInfo.outWidth + ".0 ||\n fract(dyC) > 0.0) {\n continue;\n }\n int idyC = int(dyC);\n\n int wCPerm = " + filterWidth + " - 1 - wC;\n\n // TO DO: Vec4 over the channelMul\n for (int dm = 0; dm < " + channelMul + "; dm++) {\n int d2 = d1 * " + channelMul + " + dm;\n float xValue = getDy(batch, idyR, idyC, d2);\n float wValue = getW(wRPerm, wCPerm, d1, dm);\n dotProd += xValue * wValue;\n }\n }\n }\n setOutput(dotProd);\n }\n ";
  }

  return DepthwiseConv2DDerInputProgram2;
}();
/**
* @license
* Copyright 2017 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var Conv2DProgram = function() {
/**
 * Unpacked WebGL program for a 2D convolution with optional fused bias and
 * activation. Each output element accumulates x * W over the filter window,
 * skipping taps whose input row/column falls outside the input (padding).
 * Input channels are reduced four at a time with vec4 dot products; the
 * remaining (inChannels % 4) channels use the vec2/vec3/scalar branches.
 * @param convInfo - convolution geometry (outShape, padInfo, strides,
 *     dilations, filter and input sizes, inChannels, dataFormat).
 * @param addBias - when truthy, adds a "bias" input read at output coords.
 * @param activation - optional GLSL body of the fused activation function.
 * @param hasPreluActivationWeights - when truthy, adds a
 *     "preluActivationWeights" input exposed to the activation body as `b`.
 */
function Conv2DProgram2(convInfo, addBias, activation, hasPreluActivationWeights) {
if (addBias === void 0) {
addBias = false;
}
if (activation === void 0) {
activation = null;
}
if (hasPreluActivationWeights === void 0) {
hasPreluActivationWeights = false;
}
this.variableNames = ["x", "W"];
this.outputShape = convInfo.outShape;
var padTop = convInfo.padInfo.top;
var padLeft = convInfo.padInfo.left;
var strideHeight = convInfo.strideHeight;
var strideWidth = convInfo.strideWidth;
var dilationHeight = convInfo.dilationHeight;
var dilationWidth = convInfo.dilationWidth;
var filterHeight = convInfo.filterHeight;
var filterWidth = convInfo.filterWidth;
// Channels handled by the vec4 loop, and the leftover count (0-3) handled
// by the tail branches of the shader.
var inputDepthNearestVec4 = Math.floor(convInfo.inChannels / 4) * 4;
var inputDepthVec4Remainder = convInfo.inChannels % 4;
var isChannelsLast = convInfo.dataFormat === "channelsLast";
// Positions of the row, column and channel axes within the 4D coords.
var rowDim = isChannelsLast ? 1 : 2;
var colDim = isChannelsLast ? 2 : 3;
var channelDim = isChannelsLast ? 3 : 1;
var activationSnippet = "", applyActivationSnippet = "";
if (activation) {
if (hasPreluActivationWeights) {
// PReLU-style bodies take the value as `a` and the per-element weight as `b`.
activationSnippet = "float activation(float a) {\n float b = getPreluActivationWeightsAtOutCoords();\n " + activation + "\n }";
} else {
activationSnippet = "\n float activation(float x) {\n " + activation + "\n }\n ";
}
applyActivationSnippet = "result = activation(result);";
}
var addBiasSnippet = addBias ? "result += getBiasAtOutCoords();" : "";
if (addBias) {
this.variableNames.push("bias");
}
if (hasPreluActivationWeights) {
this.variableNames.push("preluActivationWeights");
}
// JS booleans (isChannelsLast, the remainder tests) are spliced into the GLSL
// as `true`/`false` literals, so every such `if` in the shader has a constant
// condition.
this.userCode = "\n " + activationSnippet + "\n\n const ivec2 strides = ivec2(" + strideHeight + ", " + strideWidth + ");\n const ivec2 pads = ivec2(" + padTop + ", " + padLeft + ");\n\n void main() {\n ivec4 coords = getOutputCoords();\n int batch = coords[0];\n int d2 = coords[" + channelDim + "];\n\n ivec2 xRCCorner =\n ivec2(coords[" + rowDim + "], coords[" + colDim + "]) * strides - pads;\n int xRCorner = xRCCorner.x;\n int xCCorner = xRCCorner.y;\n\n // Convolve x(?, ?, d1) with w(:, :, d1, d2) to get y(yR, yC, d2).\n // ? = to be determined. : = across all values in that axis.\n float dotProd = 0.0;\n for (int wR = 0; wR < " + filterHeight + "; wR++) {\n int xR = xRCorner + wR * " + dilationHeight + ";\n\n if (xR < 0 || xR >= " + convInfo.inHeight + ") {\n continue;\n }\n\n for (int wC = 0; wC < " + filterWidth + "; wC++) {\n int xC = xCCorner + wC * " + dilationWidth + ";\n\n if (xC < 0 || xC >= " + convInfo.inWidth + ") {\n continue;\n }\n\n for (int d1 = 0; d1 < " + inputDepthNearestVec4 + "; d1 += 4) {\n vec4 wValues = vec4(\n getW(wR, wC, d1, d2),\n getW(wR, wC, d1 + 1, d2),\n getW(wR, wC, d1 + 2, d2),\n getW(wR, wC, d1 + 3, d2)\n );\n\n if (" + isChannelsLast + ") {\n vec4 xValues = vec4(\n getX(batch, xR, xC, d1),\n getX(batch, xR, xC, d1 + 1),\n getX(batch, xR, xC, d1 + 2),\n getX(batch, xR, xC, d1 + 3)\n );\n dotProd += dot(xValues, wValues);\n } else {\n vec4 xValues = vec4(\n getX(batch, d1, xR, xC),\n getX(batch, d1 + 1, xR, xC),\n getX(batch, d1 + 2, xR, xC),\n getX(batch, d1 + 3, xR, xC)\n );\n dotProd += dot(xValues, wValues);\n }\n }\n\n if (" + (inputDepthVec4Remainder === 1) + ") {\n\n if (" + isChannelsLast + ") {\n dotProd +=\n getX(batch, xR, xC, " + inputDepthNearestVec4 + ") *\n getW(wR, wC, " + inputDepthNearestVec4 + ", d2);\n } else {\n dotProd +=\n getX(batch, " + inputDepthNearestVec4 + ", xR, xC) *\n getW(wR, wC, " + inputDepthNearestVec4 + ", d2);\n }\n\n } else if (" + (inputDepthVec4Remainder === 2) + ") {\n vec2 wValues = 
vec2(\n getW(wR, wC, " + inputDepthNearestVec4 + ", d2),\n getW(wR, wC, " + inputDepthNearestVec4 + " + 1, d2)\n );\n\n if (" + isChannelsLast + ") {\n vec2 xValues = vec2(\n getX(batch, xR, xC, " + inputDepthNearestVec4 + "),\n getX(batch, xR, xC, " + inputDepthNearestVec4 + " + 1)\n );\n dotProd += dot(xValues, wValues);\n } else {\n vec2 xValues = vec2(\n getX(batch, " + inputDepthNearestVec4 + ", xR, xC),\n getX(batch, " + inputDepthNearestVec4 + " + 1, xR, xC)\n );\n dotProd += dot(xValues, wValues);\n }\n\n } else if (" + (inputDepthVec4Remainder === 3) + ") {\n vec3 wValues = vec3(\n getW(wR, wC, " + inputDepthNearestVec4 + ", d2),\n getW(wR, wC, " + inputDepthNearestVec4 + " + 1, d2),\n getW(wR, wC, " + inputDepthNearestVec4 + " + 2, d2)\n );\n\n if (" + isChannelsLast + ") {\n vec3 xValues = vec3(\n getX(batch, xR, xC, " + inputDepthNearestVec4 + "),\n getX(batch, xR, xC, " + inputDepthNearestVec4 + " + 1),\n getX(batch, xR, xC, " + inputDepthNearestVec4 + " + 2)\n );\n dotProd += dot(xValues, wValues);\n } else {\n vec3 xValues = vec3(\n getX(batch, " + inputDepthNearestVec4 + ", xR, xC),\n getX(batch, " + inputDepthNearestVec4 + " + 1, xR, xC),\n getX(batch, " + inputDepthNearestVec4 + " + 2, xR, xC)\n );\n dotProd += dot(xValues, wValues);\n }\n\n }\n }\n }\n\n float result = dotProd;\n " + addBiasSnippet + "\n " + applyActivationSnippet + "\n setOutput(result);\n }\n ";
}
return Conv2DProgram2;
}();
var Conv3DProgram = function() {
/**
 * Unpacked WebGL program for a 3D convolution. Each output element
 * (batch, yF, yR, yC, d2) accumulates x * W over the dilated filter window,
 * skipping taps whose input depth/row/column falls outside the input.
 * Input channels are reduced four at a time with vec4 dot products; the
 * remaining (inChannels % 4) channels use the scalar/vec2/vec3 branches.
 * @param convInfo - convolution geometry (outShape, padInfo, strides,
 *     dilations, filter and input sizes, inChannels).
 */
function Conv3DProgram2(convInfo) {
this.variableNames = ["x", "W"];
this.outputShape = convInfo.outShape;
var padFront = convInfo.padInfo.front;
var padTop = convInfo.padInfo.top;
var padLeft = convInfo.padInfo.left;
var strideDepth = convInfo.strideDepth;
var strideHeight = convInfo.strideHeight;
var strideWidth = convInfo.strideWidth;
var dilationDepth = convInfo.dilationDepth;
var dilationHeight = convInfo.dilationHeight;
var dilationWidth = convInfo.dilationWidth;
var filterDepth = convInfo.filterDepth;
var filterHeight = convInfo.filterHeight;
var filterWidth = convInfo.filterWidth;
// Channels handled by the vec4 loop, and the leftover count (0-3) handled
// by the tail branches of the shader.
var inputDepthNearestVec4 = Math.floor(convInfo.inChannels / 4) * 4;
var inputDepthVec4Remainder = convInfo.inChannels % 4;
// The remainder tests are spliced into the GLSL as `true`/`false` literals,
// so each tail branch has a constant condition.
this.userCode = "\n const ivec3 strides = ivec3(" + strideDepth + ", " + strideHeight + ", " + strideWidth + ");\n const ivec3 pads = ivec3(" + padFront + ", " + padTop + ", " + padLeft + ");\n\n void main() {\n ivec5 coords = getOutputCoords();\n int batch = coords.x;\n int d2 = coords.u;\n\n ivec3 xFRCCorner = ivec3(coords.y, coords.z, coords.w) * strides - pads;\n int xFCorner = xFRCCorner.x;\n int xRCorner = xFRCCorner.y;\n int xCCorner = xFRCCorner.z;\n\n // Convolve x(?, ?, ?, d1) with w(:, :, :, d1, d2) to get\n // y(yF, yR, yC, d2). ? = to be determined. : = across all\n // values in that axis.\n float dotProd = 0.0;\n for (int wF = 0; wF < " + filterDepth + "; wF++) {\n int xF = xFCorner + wF * " + dilationDepth + ";\n\n if (xF < 0 || xF >= " + convInfo.inDepth + ") {\n continue;\n }\n\n for (int wR = 0; wR < " + filterHeight + "; wR++) {\n int xR = xRCorner + wR * " + dilationHeight + ";\n\n if (xR < 0 || xR >= " + convInfo.inHeight + ") {\n continue;\n }\n\n for (int wC = 0; wC < " + filterWidth + "; wC++) {\n int xC = xCCorner + wC * " + dilationWidth + ";\n\n if (xC < 0 || xC >= " + convInfo.inWidth + ") {\n continue;\n }\n\n for (int d1 = 0; d1 < " + inputDepthNearestVec4 + "; d1 += 4) {\n vec4 xValues = vec4(\n getX(batch, xF, xR, xC, d1),\n getX(batch, xF, xR, xC, d1 + 1),\n getX(batch, xF, xR, xC, d1 + 2),\n getX(batch, xF, xR, xC, d1 + 3)\n );\n vec4 wValues = vec4(\n getW(wF, wR, wC, d1, d2),\n getW(wF, wR, wC, d1 + 1, d2),\n getW(wF, wR, wC, d1 + 2, d2),\n getW(wF, wR, wC, d1 + 3, d2)\n );\n\n dotProd += dot(xValues, wValues);\n }\n\n if (" + (inputDepthVec4Remainder === 1) + ") {\n dotProd +=\n getX(batch, xF, xR, xC, " + inputDepthNearestVec4 + ") *\n getW(wF, wR, wC, " + inputDepthNearestVec4 + ", d2);\n } else if (" + (inputDepthVec4Remainder === 2) + ") {\n vec2 xValues = vec2(\n getX(batch, xF, xR, xC, " + inputDepthNearestVec4 + "),\n getX(batch, xF, xR, xC, " + inputDepthNearestVec4 + " + 1)\n );\n vec2 wValues = vec2(\n getW(wF, wR, wC, 
" + inputDepthNearestVec4 + ", d2),\n getW(wF, wR, wC, " + inputDepthNearestVec4 + " + 1, d2)\n );\n dotProd += dot(xValues, wValues);\n } else if (" + (inputDepthVec4Remainder === 3) + ") {\n vec3 xValues = vec3(\n getX(batch, xF, xR, xC, " + inputDepthNearestVec4 + "),\n getX(batch, xF, xR, xC, " + inputDepthNearestVec4 + " + 1),\n getX(batch, xF, xR, xC, " + inputDepthNearestVec4 + " + 2)\n );\n vec3 wValues = vec3(\n getW(wF, wR, wC, " + inputDepthNearestVec4 + ", d2),\n getW(wF, wR, wC, " + inputDepthNearestVec4 + " + 1, d2),\n getW(wF, wR, wC, " + inputDepthNearestVec4 + " + 2, d2)\n );\n dotProd += dot(xValues, wValues);\n }\n }\n }\n }\n setOutput(dotProd);\n }\n ";
}
return Conv3DProgram2;
}();
/**
* @license
* Copyright 2017 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var DepthwiseConv2DProgram = function() {
  /**
   * Unpacked WebGL program for a depthwise 2D convolution with optional fused
   * bias and activation. Input and output are indexed as
   * (batch, row, col, channel). Output channel d2 reads input channel
   * d1 = d2 / channelMul and filter slice q = d2 - d1 * channelMul.
   * @param {Object} convInfo - convolution geometry: outShape, in height/width,
   *     padInfo, strides, dilations, filter sizes, in/out channel counts.
   * @param {boolean} [addBias=false] - when truthy, adds a "bias" input read at
   *     output coords.
   * @param {?string} [activation=null] - optional GLSL body of the fused activation.
   * @param {boolean} [hasPreluActivation=false] - when truthy, adds a
   *     "preluActivationWeights" input exposed to the activation body as `b`.
   */
  function DepthwiseConv2DProgram2(convInfo, addBias, activation, hasPreluActivation) {
    // Omitted flags arrive as undefined; every use below is a truthiness test,
    // so they behave exactly like the documented false/null defaults.
    this.variableNames = ["x", "W"];
    this.outputShape = convInfo.outShape;
    var xNumRows = convInfo.inHeight;
    var xNumCols = convInfo.inWidth;
    var { top: padTop, left: padLeft } = convInfo.padInfo;
    var {
      strideHeight, strideWidth,
      dilationHeight, dilationWidth,
      filterHeight, filterWidth,
    } = convInfo;
    // Output channels produced per input channel (the depth multiplier).
    var channelMul = convInfo.outChannels / convInfo.inChannels;

    var activationSnippet = "";
    var applyActivationSnippet = "";
    if (activation) {
      activationSnippet = hasPreluActivation
        ? "float activation(float a) {\n float b = getPreluActivationWeightsAtOutCoords();\n " + activation + "\n }"
        : "\n float activation(float x) {\n " + activation + "\n }\n ";
      applyActivationSnippet = "result = activation(result);";
    }

    var addBiasSnippet = "";
    if (addBias) {
      addBiasSnippet = "result += getBiasAtOutCoords();";
      this.variableNames.push("bias");
    }
    if (hasPreluActivation) {
      this.variableNames.push("preluActivationWeights");
    }

    this.userCode = "\n " + activationSnippet + "\n\n const ivec2 strides = ivec2(" + strideHeight + ", " + strideWidth + ");\n const ivec2 pads = ivec2(" + padTop + ", " + padLeft + ");\n\n void main() {\n ivec4 coords = getOutputCoords();\n int batch = coords.x;\n ivec2 xRCCorner = coords.yz * strides - pads;\n int d2 = coords.w;\n int d1 = d2 / " + channelMul + ";\n int q = d2 - d1 * " + channelMul + ";\n\n int xRCorner = xRCCorner.x;\n int xCCorner = xRCCorner.y;\n\n // Convolve x(?, ?, d1) with w(:, :, d1, q) to get y(yR, yC, d2).\n // ? = to be determined. : = across all values in that axis.\n float dotProd = 0.0;\n // TO DO(dsmilkov): Flatten the two for loops and vec4 the operations.\n for (int wR = 0; wR < " + filterHeight + "; wR++) {\n int xR = xRCorner + wR * " + dilationHeight + ";\n\n if (xR < 0 || xR >= " + xNumRows + ") {\n continue;\n }\n\n for (int wC = 0; wC < " + filterWidth + "; wC++) {\n int xC = xCCorner + wC * " + dilationWidth + ";\n\n if (xC < 0 || xC >= " + xNumCols + ") {\n continue;\n }\n\n float xVal = getX(batch, xR, xC, d1);\n float wVal = getW(wR, wC, d1, q);\n dotProd += xVal * wVal;\n }\n }\n\n float result = dotProd;\n " + addBiasSnippet + "\n " + applyActivationSnippet + "\n setOutput(result);\n }\n ";
  }

  return DepthwiseConv2DProgram2;
}();
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var DepthwiseConvPacked2DProgram = function() {
function DepthwiseConvPacked2DProgram2(convInfo, addBias, activation, hasPreluActivation) {
if (addBias === void 0) {
addBias = false;
}
if (activation === void 0) {
activation = null;
}
if (hasPreluActivation === void 0) {
hasPreluActivation = false;
}
this.variableNames = ["x", "W"];
this.packedInputs = true;
this.packedOutput = true;
this.outputShape = convInfo.outShape;
var xNumRows = convInfo.inHeight;
var xNumCols = convInfo.inWidth;
var padTop = convInfo.padInfo.top;
var padLeft = convInfo.padInfo.left;
var strideHeight = convInfo.strideHeight;
var strideWidth = convInfo.strideWidth;
var dilationHeight = convInfo.dilationHeight;
var dilationWidth = convInfo.dilationWidth;
var filterHeight = convInfo.filterHeight;
var filterWidth = convInfo.filterWidth;
var texelsAcross = filterWidth;
var mainLoop = "int xR; int xC; int xCOffset;";
for (var r = 0; r < filterHeight; r++) {
for (var c = 0; c < filterWidth; c++) {
mainLoop += "\n vec4 xTexelR" + r + "C" + c * 2 + " = vec4(0.);\n vec4 wR" + r + "C" + c + " = vec4(0.);\n vec4 xR" + r + "C" + c + " = vec4(0.);";
}
}
for (var r = 0; r < filterHeight; r++) {
for (var texelC = 0; texelC < texelsAcross; texelC++) {
var c = texelC * 2;
mainLoop += "\n xR = xRCorner + " + r * dilationHeight + ";\n xC = xCCorner + " + c * dilationWidth + ";\n ";
if (strideWidth === 1) {
if (c < filterWidth) {
if (padLeft % 2 === 1) {
mainLoop += "\n xCOffset = xC + 1;\n if(xR >= 0 && xR < " + xNumRows + " && xCOffset >= 0 && xCOffset < " + xNumCols + ") {\n xTexelR" + r + "C" + c + " = getX(batch, xR, xCOffset, d1);\n\n // Need to manually clear unused channels in case\n // we're reading from recycled texture.\n if(xCOffset + 1 >= " + xNumCols + ") {\n xTexelR" + r + "C" + c + ".zw = vec2(0.);\n }\n } else {\n xTexelR" + r + "C" + c + " = vec4(0.);\n }\n\n xCOffset = xC + 1 - 2;\n if(xR >= 0 && xR < " + xNumRows + " && xCOffset >= 0 && xCOffset < " + xNumCols + ") {\n vec4 previous = getX(batch, xR, xCOffset, d1);\n\n // Need to manually clear unused channels in case\n // we're reading from recycled texture.\n if(xCOffset + 1 >= " + xNumCols + ") {\n previous.zw = vec2(0.);\n }\n\n xR" + r + "C" + c + " = vec4(previous.zw, xTexelR" + r + "C" + c + ".xy);\n } else {\n xR" + r + "C" + c + " = vec4(0, 0, xTexelR" + r + "C" + c + ".xy);\n }\n ";
} else {
mainLoop += "\n if(xR >= 0 && xR < " + xNumRows + " && xC >= 0 && xC < " + xNumCols + ") {\n xTexelR" + r + "C" + c + " = getX(batch, xR, xC, d1);\n } else {\n xTexelR" + r + "C" + c + " = vec4(0.);\n }\n\n xR" + r + "C" + c + " = xTexelR" + r + "C" + c + ";\n ";
}
if (c + 1 < filterWidth) {
var nextTexelOffset = padLeft % 2 === 0 ? tf.util.nearestLargerEven(dilationWidth) : dilationWidth;
if (dilationWidth % 2 === 0 && padLeft % 2 === 1 || dilationWidth % 2 !== 0 && padLeft % 2 !== 1) {
mainLoop += "\n xCOffset = xC + " + padLeft % 2 + " + " + nextTexelOffset + ";\n\n if(xR >= 0 && xR < " + xNumRows + " &&\n xCOffset >= 0 && xCOffset < " + xNumCols + ") {\n xTexelR" + r + "C" + (c + 2) + " = getX(batch, xR, xCOffset, d1);\n }\n ";
if (dilationWidth > 1) {
mainLoop += "\n xCOffset -= 2;\n if(xR >= 0 && xR < " + xNumRows + " &&\n xCOffset >= 0 && xCOffset < " + xNumCols + ") {\n xTexelR" + r + "C" + c + " = getX(batch, xR, xCOffset, d1);\n } else {\n xTexelR" + r + "C" + c + " = vec4(0.);\n }\n ";
}
mainLoop += "\n xR" + r + "C" + (c + 1) + " = vec4(\n xTexelR" + r + "C" + c + ".zw, xTexelR" + r + "C" + (c + 2) + ".xy);\n ";
} else {
mainLoop += "\n xCOffset = xC + " + nextTexelOffset + ";\n\n if(xR >= 0 && xR < " + xNumRows + " &&\n xCOffset >= 0 && xCOffset < " + xNumCols + ") {\n xTexelR" + r + "C" + (c + 2) + " = getX(batch, xR, xCOffset, d1);\n }\n\n xR" + r + "C" + (c + 1) + " = xTexelR" + r + "C" + (c + 2) + ";\n ";
}
}
}
} else {
if (c < filterWidth) {
mainLoop += "\n if(xR >= 0 && xR < " + xNumRows + ") {\n ";
if (padLeft % 2 === 1) {
mainLoop += "\n xCOffset = xC + 1 - " + strideWidth + ";\n if(xCOffset >= 0 && xCOffset < " + xNumCols + ") {\n xTexelR" + r + "C" + c + " = getX(batch, xR, xCOffset, d1);\n } else {\n xTexelR" + r + "C" + c + " = vec4(0.);\n }\n\n if(xC + 1 >= 0 && xC + 1 < " + xNumCols + ") {\n xTexelR" + r + "C" + (c + 2) + " = getX(batch, xR, xC + 1, d1);\n } else {\n xTexelR" + r + "C" + (c + 2) + " = vec4(0.);\n }\n\n xR" + r + "C" + c + " = vec4(\n xTexelR" + r + "C" + c + ".zw, xTexelR" + r + "C" + (c + 2) + ".zw);\n ";
if (c + 1 < filterWidth) {
mainLoop += "\n vec4 final = vec4(0.);\n xCOffset = xC + 1 + " + strideWidth + ";\n if(xCOffset >= 0 && xCOffset < " + xNumCols + ") {\n final = getX(batch, xR, xCOffset, d1);\n }\n xR" + r + "C" + (c + 1) + " = vec4(xTexelR" + r + "C" + (c + 2) + ".xy, final.xy);\n ";
}
} else {
mainLoop += "\n if(xC >= 0 && xC < " + xNumCols + ") {\n xTexelR" + r + "C" + c + " = getX(batch, xR, xC, d1);\n } else {\n xTexelR" + r + "C" + c + " = vec4(0.);\n }\n\n xCOffset = xC + " + strideWidth + ";\n if(xCOffset >= 0 && xCOffset < " + xNumCols + ") {\n xTexelR" + r + "C" + (c + 2) + " = getX(batch, xR, xCOffset, d1);\n } else {\n xTexelR" + r + "C" + (c + 2) + " = vec4(0.);\n }\n\n xR" + r + "C" + c + " = vec4(\n xTexelR" + r + "C" + c + ".xy, xTexelR" + r + "C" + (c + 2) + ".xy);\n ";
if (c + 1 < filterWidth) {
mainLoop += "\n xR" + r + "C" + (c + 1) + " = vec4(\n xTexelR" + r + "C" + c + ".zw, xTexelR" + r + "C" + (c + 2) + ".zw);\n ";
}
}
mainLoop += "}";
}
}
if (c < filterWidth) {
mainLoop += "\n vec4 wTexelR" + r + "C" + c + " = getW(" + r + ", " + c + ", d1, q);\n wR" + r + "C" + c + " = vec4(wTexelR" + r + "C" + c + ".xz, wTexelR" + r + "C" + c + ".xz);\n ";
if (c + 1 < filterWidth) {
mainLoop += "\n vec4 wTexelR" + r + "C" + (c + 1) + " = getW(" + r + ", " + (c + 1) + ", d1, q);\n wR" + r + "C" + (c + 1) + " =\n vec4(wTexelR" + r + "C" + (c + 1) + ".xz, wTexelR" + r + "C" + (c + 1) + ".xz);";
}
}
}
}
for (var r = 0; r < filterHeight; r++) {
for (var c = 0; c < filterWidth; c++) {
mainLoop += "dotProd += xR" + r + "C" + c + " * wR" + r + "C" + c + ";";
}
}
var activationSnippet = "", applyActivationSnippet = "";
if (activation) {
if (hasPreluActivation) {
activationSnippet = "vec4 activation(vec4 a) {\n vec4 b = getPreluActivationWeightsAtOutCoords();\n " + activation + "\n }";
} else {
activationSnippet = "vec4 activation(vec4 x) {\n " + activation + "\n }";
}
applyActivationSnippet = "result = activation(result);";
}
var addBiasSnippet = addBias ? "result += getBiasAtOutCoords();" : "";
if (addBias) {
this.variableNames.push("bias");
}
if (hasPreluActivation) {
this.variableNames.push("preluActivationWeights");
}
this.userCode = "\n " + activationSnippet + "\n\n const ivec2 strides = ivec2(" + strideHeight + ", " + strideWidth + ");\n const ivec2 pads = ivec2(" + padTop + ", " + padLeft + ");\n\n void main() {\n\n ivec4 coords = getOutputCoords();\n int batch = coords.x;\n ivec2 xRCCorner = coords.yz * strides - pads;\n int d2 = coords.w;\n int d1 = d2;\n int q = 0;\n int xRCorner = xRCCorner.x;\n int xCCorner = xRCCorner.y;\n\n vec4 dotProd = vec4(0.);\n\n " + mainLoop + "\n\n vec4 result = dotProd;\n " + addBiasSnippet + "\n " + applyActivationSnippet + "\n setOutput(result);\n }\n ";
}
return DepthwiseConvPacked2DProgram2;
}();
/**
* @license
* Copyright 2017 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var CropAndResizeProgram = function() {
  /**
   * WebGL program that crops boxes out of the images in "Image" and resamples
   * each crop to cropSize. There is one output element per (box, y, x, channel).
   *
   * Each box is read from "Boxes" as (y1, x1, y2, x2). The box coordinates are
   * multiplied by (imageHeight - 1) / (imageWidth - 1) to get sample positions.
   * "BoxInd" selects the source image within the batch.
   * NOTE(review): when the box index is outside [0, batch) the shader returns
   * without calling setOutput.
   *
   * @param {number[]} imageShape - [batch, imageHeight, imageWidth, depth].
   * @param {number[]} boxShape - boxShape[0] is the number of boxes.
   * @param {number[]} cropSize - [cropHeight, cropWidth] of every crop.
   * @param {string} method - "bilinear" selects bilinear sampling; any other
   *     value selects nearest-neighbor sampling.
   * @param {number} extrapolationValue - written when the sample point falls
   *     outside the image.
   */
  function CropAndResizeProgram2(imageShape, boxShape, cropSize, method, extrapolationValue) {
    this.variableNames = ["Image", "Boxes", "BoxInd"];
    const batch = imageShape[0];
    const imageHeight = imageShape[1];
    const imageWidth = imageShape[2];
    const depth = imageShape[3];
    const numBoxes = boxShape[0];
    const cropHeight = cropSize[0];
    const cropWidth = cropSize[1];
    this.outputShape = [numBoxes, cropHeight, cropWidth, depth];
    const methodId = method === "bilinear" ? 1 : 0;
    // GLSL float literals for the last valid row / column position.
    const inputHeightFloat = imageHeight - 1 + ".0";
    const inputWidthFloat = imageWidth - 1 + ".0";

    // A crop that is a single row samples the box's vertical center, so the
    // ratio and scale terms are unused and are emitted as 0.0.
    let heightRatio = "0.0";
    let heightScale = "0.0";
    let inY = "0.5 * (y1+y2) * " + inputHeightFloat;
    if (cropHeight > 1) {
      heightRatio = "" + (imageHeight - 1) / (cropHeight - 1);
      heightScale = "(y2-y1) * height_ratio";
      inY = "y1*" + inputHeightFloat + " + float(y)*(height_scale)";
    }

    // The same rule applies horizontally for a crop that is a single column.
    let widthRatio = "0.0";
    let widthScale = "0.0";
    let inX = "0.5 * (x1+x2) * " + inputWidthFloat;
    if (cropWidth > 1) {
      widthRatio = "" + (imageWidth - 1) / (cropWidth - 1);
      widthScale = "(x2-x1) * width_ratio";
      inX = "x1*" + inputWidthFloat + " + float(x)*(width_scale)";
    }

    this.userCode = "\n const float height_ratio = float(" + heightRatio + ");\n const float width_ratio = float(" + widthRatio + ");\n void main() {\n ivec4 coords = getOutputCoords();\n int b = coords[0];\n int y = coords[1];\n int x = coords[2];\n int d = coords[3];\n\n // get box vals\n float y1 = getBoxes(b,0);\n float x1 = getBoxes(b,1);\n float y2 = getBoxes(b,2);\n float x2 = getBoxes(b,3);\n\n // get image in batch index\n int bInd = round(getBoxInd(b));\n if(bInd < 0 || bInd >= " + batch + ") {\n return;\n }\n\n float height_scale = " + heightScale + ";\n float width_scale = " + widthScale + ";\n\n float in_y = " + inY + ";\n if( in_y < 0.0 || in_y > " + inputHeightFloat + " ) {\n setOutput(float(" + extrapolationValue + "));\n return;\n }\n float in_x = " + inX + ";\n if( in_x < 0.0 || in_x > " + inputWidthFloat + " ) {\n setOutput(float(" + extrapolationValue + "));\n return;\n }\n\n vec2 sourceFracIndexCR = vec2(in_x,in_y);\n if(" + methodId + " == 1) {\n // Compute the four integer indices.\n ivec2 sourceFloorCR = ivec2(sourceFracIndexCR);\n ivec2 sourceCeilCR = ivec2(ceil(sourceFracIndexCR));\n\n float topLeft = getImage(b, sourceFloorCR.y, sourceFloorCR.x, d);\n float bottomLeft = getImage(b, sourceCeilCR.y, sourceFloorCR.x, d);\n float topRight = getImage(b, sourceFloorCR.y, sourceCeilCR.x, d);\n float bottomRight = getImage(b, sourceCeilCR.y, sourceCeilCR.x, d);\n\n vec2 fracCR = sourceFracIndexCR - vec2(sourceFloorCR);\n\n float top = topLeft + (topRight - topLeft) * fracCR.x;\n float bottom = bottomLeft + (bottomRight - bottomLeft) * fracCR.x;\n float newValue = top + (bottom - top) * fracCR.y;\n setOutput(newValue);\n } else {\n // Compute the coordinators of nearest neighbor point.\n ivec2 sourceNearestCR = ivec2(floor(\n sourceFracIndexCR + vec2(0.5,0.5)));\n float newValue = getImage(b, sourceNearestCR.y, sourceNearestCR.x, d);\n setOutput(newValue);\n }\n }\n ";
  }
  return CropAndResizeProgram2;
}();
var CumSumProgram = function() {
  /**
   * One step of a cumulative sum along the last axis of "x".
   *
   * Each output element starts from its own value (or 0.0 when `exclusive`).
   * When the partner position is in range, it adds the element at that partner:
   *   - inclusive: end - pow2 (or end + pow2 when `reverse`), with pow2 = 2^index;
   *   - exclusive: end - 1 (or end + 1 when `reverse`).
   * The step number comes from the "index" uniform set by getCustomSetupFunc.
   * Presumably the caller runs this program repeatedly with index = 0, 1, 2, ...
   * (TODO confirm at the call site).
   *
   * @param {number[]} shape - input/output shape; getCoords/getFinalCoord only
   *     accept ranks 1-4.
   * @param {boolean} exclusive - exclude each element from its own sum.
   * @param {boolean} reverse - accumulate from the end of the axis.
   */
  function CumSumProgram2(shape, exclusive, reverse) {
    this.variableNames = ["x"];
    this.outputShape = shape;
    const rank = shape.length;
    const val = exclusive ? "0.0" : "getX(" + getCoords(rank, "coords") + ")";
    const axisSize = shape[shape.length - 1];
    const lastIndex = axisSize - 1;

    let condition;
    let idxExpr;
    if (reverse) {
      condition = exclusive ? "end != " + lastIndex : "end + pow2 < " + axisSize;
      idxExpr = exclusive ? "end + 1" : "end + pow2";
    } else {
      condition = exclusive ? "end != 0" : "end >= pow2";
      idxExpr = exclusive ? "end - 1" : "end - pow2";
    }

    // Helpers are evaluated in the same order the shader text uses them.
    const coordsType = getCoordsDataType(rank);
    const finalCoord = getFinalCoord(rank, "coords");
    const coordList = getCoords(rank, "coords");
    this.userCode = "\n uniform float index;\n void main() {\n " + coordsType + " coords = getOutputCoords();\n int end = " + finalCoord + ";\n float val = " + val + ";\n int pow2 = int(pow(2.0, index));\n if (" + condition + ") {\n int idx = " + idxExpr + ";\n " + finalCoord + " = idx;\n val += getX(" + coordList + ");\n }\n setOutput(val);\n }\n ";
  }

  /**
   * Returns a setup callback that uploads `index` to the "index" uniform.
   * The uniform location is looked up on first use and cached on this program.
   */
  CumSumProgram2.prototype.getCustomSetupFunc = function(index) {
    return (gpgpu, webGLProgram) => {
      if (this.index == null) {
        this.index = gpgpu.getUniformLocation(webGLProgram, "index");
      }
      gpgpu.gl.uniform1f(this.index, index);
    };
  };
  return CumSumProgram2;
}();
/**
 * Builds the comma-separated GLSL component list for a coordinate variable.
 * Rank 1 returns the bare name; ranks 2-4 return "name.x, name.y, ...".
 *
 * @param {number} rank - tensor rank; only 1-4 are supported.
 * @param {string} name - GLSL variable holding the coordinates.
 * @returns {string} argument list suitable for getX(...).
 * @throws {Error} for any other rank.
 */
function getCoords(rank, name) {
  switch (rank) {
    case 1:
      return "" + name;
    case 2:
    case 3:
    case 4:
      return ["x", "y", "z", "w"].slice(0, rank).map(function(component) {
        return name + "." + component;
      }).join(", ");
    default:
      throw Error("Cumulative sum for rank " + rank + " is not yet supported");
  }
}
/**
 * Returns the GLSL expression for the last (innermost) coordinate of `name`.
 * Rank 1 returns the bare name; ranks 2/3/4 select .y/.z/.w.
 *
 * @param {number} rank - tensor rank; only 1-4 are supported.
 * @param {string} name - GLSL variable holding the coordinates.
 * @returns {string}
 * @throws {Error} for any other rank.
 */
function getFinalCoord(rank, name) {
  let component;
  switch (rank) {
    case 1:
      return "" + name;
    case 2:
      component = "y";
      break;
    case 3:
      component = "z";
      break;
    case 4:
      component = "w";
      break;
    default:
      throw Error("Cumulative sum for rank " + rank + " is not yet supported");
  }
  return name + "." + component;
}
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var DecodeMatrixProgram = function() {
  /**
   * Writes an unpacked input "A" into a DENSE packed output texture. Each
   * output texel holds the four elements with consecutive flat indices
   * 4k .. 4k+3, converted back to (r, c, d) coordinates before reading A.
   *
   * @param {number[]} outputShape - rank-3 logical shape, addressed as (r, c, d).
   */
  function DecodeMatrixProgram2(outputShape) {
    this.variableNames = ["A"];
    this.packedInputs = false;
    this.packedOutput = true;
    this.outPackingScheme = PackingScheme.DENSE;
    const texShape = getDenseTexShape(outputShape);
    const glsl = getGlslDifferences();
    this.outputShape = outputShape;
    const texRows = texShape[0];
    const texCols = texShape[1];
    const flatToCoords = getLogicalCoordinatesFromFlatIndex(["r", "c", "d"], outputShape);
    this.userCode = "\n ivec3 outCoordsFromFlatIndex(int index) {\n " + flatToCoords + "\n return ivec3(r, c, d);\n }\n\n void main() {\n ivec2 resTexRC = ivec2(resultUV.yx *\n vec2(" + texRows + ", " + texCols + "));\n int index = 4 * (resTexRC.x * " + texCols + " + resTexRC.y);\n\n vec4 result = vec4(0.);\n\n for (int i=0; i<4; i++) {\n int flatIndex = index + i;\n ivec3 rc = outCoordsFromFlatIndex(flatIndex);\n result[i] = getA(rc.x, rc.y, rc.z);\n }\n\n " + glsl.output + " = result;\n }\n ";
  }
  return DecodeMatrixProgram2;
}();
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var DecodeMatrixPackedProgram = function() {
  /**
   * Variant of the dense decode program for a packed input "A". Each of the
   * four elements per output texel (flat indices 4k .. 4k+3) is extracted from
   * its packed texel with getChannel(..., vec2(rc.y, rc.z)).
   *
   * @param {number[]} outputShape - rank-3 logical shape, addressed as (r, c, d).
   */
  function DecodeMatrixPackedProgram2(outputShape) {
    this.variableNames = ["A"];
    this.packedInputs = true;
    this.packedOutput = true;
    this.outPackingScheme = PackingScheme.DENSE;
    const texShape = getDenseTexShape(outputShape);
    const glsl = getGlslDifferences();
    this.outputShape = outputShape;
    const texRows = texShape[0];
    const texCols = texShape[1];
    const flatToCoords = getLogicalCoordinatesFromFlatIndex(["r", "c", "d"], outputShape);
    this.userCode = "\n ivec3 outCoordsFromFlatIndex(int index) {\n " + flatToCoords + "\n return ivec3(r, c, d);\n }\n\n void main() {\n ivec2 resTexRC = ivec2(resultUV.yx *\n vec2(" + texRows + ", " + texCols + "));\n int index = 4 * (resTexRC.x * " + texCols + " + resTexRC.y);\n\n vec4 result = vec4(0.);\n\n for (int i=0; i<4; i++) {\n int flatIndex = index + i;\n ivec3 rc = outCoordsFromFlatIndex(flatIndex);\n result[i] = getChannel(getA(rc.x, rc.y, rc.z), vec2(rc.y, rc.z));\n }\n\n " + glsl.output + " = result;\n }\n ";
  }
  return DecodeMatrixPackedProgram2;
}();
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var DepthToSpaceProgram = function() {
  /** True when the program was configured for channels-last ("NHWC") data. */
  function isChannelsLast(program) {
    return program.dataFormat === "NHWC";
  }

  /**
   * Depth-to-space rearrangement. Output element (b, h, w, d) is read from input
   * position (h / blockSize, w / blockSize) at depth
   * d + ((h mod blockSize) * blockSize + (w mod blockSize)) * outputDepth.
   *
   * @param {number[]} outputShape - rank-4 output shape in `dataFormat` order.
   * @param {number} blockSize - spatial block edge length.
   * @param {string} dataFormat - "NHWC"; any other value uses channels-first
   *     indexing (depth at coords[1]).
   */
  function DepthToSpaceProgram2(outputShape, blockSize, dataFormat) {
    this.variableNames = ["x"];
    this.outputShape = outputShape;
    this.blockSize = blockSize;
    this.dataFormat = dataFormat;
    const hCoord = this.getHeightCoordString();
    const wCoord = this.getWidthCoordString();
    const dCoord = this.getDepthCoordString();
    const outDepth = this.getOutputDepthSize();
    const sampleExpr = this.getInputSamplingString();
    this.userCode = "\n void main() {\n ivec4 coords = getOutputCoords();\n int b = coords[0];\n int h = " + hCoord + ";\n int w = " + wCoord + ";\n int d = " + dCoord + ";\n\n int in_h = h / " + blockSize + ";\n int offset_h = imod(h, " + blockSize + ");\n int in_w = w / " + blockSize + ";\n int offset_w = imod(w, " + blockSize + ");\n int offset_d = (offset_h * " + blockSize + " + offset_w) *\n " + outDepth + ";\n int in_d = d + offset_d;\n\n float result = " + sampleExpr + ";\n setOutput(result);\n }\n ";
  }

  /** GLSL expression for the output height coordinate. */
  DepthToSpaceProgram2.prototype.getHeightCoordString = function() {
    return isChannelsLast(this) ? "coords[1]" : "coords[2]";
  };

  /** GLSL expression for the output width coordinate. */
  DepthToSpaceProgram2.prototype.getWidthCoordString = function() {
    return isChannelsLast(this) ? "coords[2]" : "coords[3]";
  };

  /** GLSL expression for the output depth (channel) coordinate. */
  DepthToSpaceProgram2.prototype.getDepthCoordString = function() {
    return isChannelsLast(this) ? "coords[3]" : "coords[1]";
  };

  /** Number of output channels, taken from the layout-specific axis. */
  DepthToSpaceProgram2.prototype.getOutputDepthSize = function() {
    return isChannelsLast(this) ? this.outputShape[3] : this.outputShape[1];
  };

  /** GLSL call that reads the source element in the input's layout. */
  DepthToSpaceProgram2.prototype.getInputSamplingString = function() {
    return isChannelsLast(this) ? "getX(b, in_h, in_w, in_d)" : "getX(b, in_d, in_h, in_w)";
  };

  return DepthToSpaceProgram2;
}();
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var DiagProgram = function() {
  /**
   * Builds a size x size matrix whose diagonal is taken from "X" (element i
   * on row i, column i) and whose other entries are 0.0.
   *
   * @param {number} size - length of X and the edge of the output matrix.
   */
  function DiagProgram2(size) {
    this.variableNames = ["X"];
    this.outputShape = [size, size];
    this.userCode = [
      "",
      " void main() {",
      " ivec2 coords = getOutputCoords();",
      " float val = coords[0] == coords[1] ? getX(coords[0]) : 0.0;",
      " setOutput(val);",
      " }",
      " "
    ].join("\n");
  }
  return DiagProgram2;
}();
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var EncodeFloatProgram = function() {
  /**
   * Output program marked with TextureUsage.DOWNLOAD. It reads one float of
   * the unpacked input "A" per output coordinate and writes
   * encode_float(x), defined by ENCODE_FLOAT_SNIPPET.
   *
   * @param {number[]} outputShape - logical output shape.
   */
  function EncodeFloatProgram2(outputShape) {
    this.variableNames = ["A"];
    this.outTexUsage = TextureUsage.DOWNLOAD;
    const { output: fragOutput } = getGlslDifferences();
    this.outputShape = outputShape;
    const mainBody = "\n void main() {\n float x = getAAtOutCoords();\n " + fragOutput + " = encode_float(x);\n }\n ";
    this.userCode = "\n " + ENCODE_FLOAT_SNIPPET + "\n" + mainBody;
  }
  return EncodeFloatProgram2;
}();
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var EncodeFloatPackedProgram = function() {
  /**
   * Download-usage program for a packed input "A". For each rank-3 output
   * coordinate it picks the matching channel with
   * getChannel(..., vec2(coords.y, coords.z)) and writes encode_float(x),
   * defined by ENCODE_FLOAT_SNIPPET.
   *
   * @param {number[]} outputShape - logical output shape.
   */
  function EncodeFloatPackedProgram2(outputShape) {
    this.variableNames = ["A"];
    this.packedInputs = true;
    this.packedOutput = false;
    this.outTexUsage = TextureUsage.DOWNLOAD;
    const { output: fragOutput } = getGlslDifferences();
    this.outputShape = outputShape;
    const mainBody = "\n void main() {\n ivec3 coords = getOutputCoords();\n float x = getChannel(getAAtOutCoords(), vec2(coords.y, coords.z));\n " + fragOutput + " = encode_float(x);\n }\n ";
    this.userCode = "\n " + ENCODE_FLOAT_SNIPPET + "\n" + mainBody;
  }
  return EncodeFloatPackedProgram2;
}();
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var EncodeMatrixProgram = function() {
  /**
   * Unpacked-output program over a rank-3 shape. For each output coordinate
   * it computes the element's flat index and reads texel (flatIndex / 4) of
   * the RGBA texture "A". It keeps channel (flatIndex mod 4) and writes that
   * value to the red channel, with the other channels set to 0.
   *
   * @param {number[]} outputShape - rank-3 logical output shape.
   * @param {number[]} texShape - [height, width] of the input texture.
   * @param {boolean} [inputIsUnsignedByte] - when truthy, the selected value is
   *     emitted as floor(value * 255 + 0.5).
   */
  function EncodeMatrixProgram2(outputShape, texShape, inputIsUnsignedByte) {
    this.variableNames = ["A"];
    const glsl = getGlslDifferences();
    const texHeight = texShape[0];
    const texWidth = texShape[1];
    this.outputShape = outputShape;
    // An omitted flag is falsy and keeps the raw channel value.
    const output = inputIsUnsignedByte ? "floor(result * 255. + 0.5)" : "result";
    this.userCode = "\n " + getFlatIndexFrom3D(outputShape) + "\n\n void main() {\n ivec3 coords = getOutputCoords();\n\n int flatIndex = getFlatIndex(coords);\n int offset = imod(flatIndex, 4);\n\n flatIndex = idiv(flatIndex, 4, 1.);\n\n int r = flatIndex / " + texWidth + ";\n int c = imod(flatIndex, " + texWidth + ");\n vec2 uv = (vec2(c, r) + halfCR) / vec2(" + texWidth + ".0, " + texHeight + ".0);\n vec4 values = " + glsl.texture2D + "(A, uv);\n\n float result;\n\n if(offset == 0) {\n result = values[0];\n } else if(offset == 1) {\n result = values[1];\n } else if(offset == 2) {\n result = values[2];\n } else {\n result = values[3];\n }\n\n " + glsl.output + " = vec4(" + output + ", 0., 0., 0.);\n }\n ";
  }
  return EncodeMatrixProgram2;
}();
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var EncodeMatrixPackedProgram = function() {
  /**
   * Packed-output program over a rank-3 shape. Each output texel covers a 2x2
   * patch of the last two dimensions, with channel = row * 2 + col. Each
   * in-bounds element is fetched from the RGBA texture "A" through its flat
   * index: texel flatIndex / 4, channel flatIndex mod 4.
   *
   * @param {number[]} outputShape - rank-3 logical output shape.
   * @param {number[]} texShape - [height, width] of the input texture.
   * @param {boolean} [inputIsUnsignedByte] - when truthy, the packed result is
   *     emitted as floor(result * 255 + 0.5).
   */
  function EncodeMatrixPackedProgram2(outputShape, texShape, inputIsUnsignedByte) {
    this.variableNames = ["A"];
    this.packedInputs = false;
    this.packedOutput = true;
    const glsl = getGlslDifferences();
    const texHeight = texShape[0];
    const texWidth = texShape[1];
    this.outputShape = outputShape;
    // An omitted flag is falsy and keeps the raw values.
    const output = inputIsUnsignedByte ? "floor(result * 255. + 0.5)" : "result";
    // Channels 0..3 are emitted in order: (row, col) = (0,0), (0,1), (1,0), (1,1).
    const mainLoop = [0, 1, 2, 3].map(function(channel) {
      const row = Math.floor(channel / 2);
      const col = channel % 2;
      return "\n localCoords = coords;\n if(localCoords[2] + " + col + " < " + outputShape[2] + ") {\n localCoords[2] += " + col + ";\n if(localCoords[1] + " + row + " < " + outputShape[1] + ") {\n localCoords[1] += " + row + ";\n\n flatIndex = getFlatIndex(localCoords);\n offset = imod(flatIndex, 4);\n\n flatIndex = idiv(flatIndex, 4, 1.);\n\n r = flatIndex / " + texWidth + ";\n c = imod(flatIndex, " + texWidth + ");\n uv = (vec2(c, r) + halfCR) / vec2(" + texWidth + ".0, " + texHeight + ".0);\n values = " + glsl.texture2D + "(A, uv);\n\n if(offset == 0) {\n result[" + channel + "] = values[0];\n } else if(offset == 1) {\n result[" + channel + "] = values[1];\n } else if(offset == 2) {\n result[" + channel + "] = values[2];\n } else {\n result[" + channel + "] = values[3];\n }\n }\n }\n ";
    }).join("");
    this.userCode = "\n " + getFlatIndexFrom3D(outputShape) + "\n\n void main() {\n ivec3 coords = getOutputCoords();\n\n vec4 result = vec4(0.);\n int flatIndex, r, c, offset;\n ivec3 localCoords;\n vec2 uv;\n vec4 values;\n\n " + mainLoop + "\n\n " + glsl.output + " = " + output + ";\n }\n ";
  }
  return EncodeMatrixPackedProgram2;
}();
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var FillProgram = function() {
  /**
   * Writes one scalar to every element of `shape`. The shader source does not
   * depend on `value`; the scalar is bound to the "value" uniform by the
   * callback from getCustomSetupFunc.
   *
   * @param {number[]} shape - output shape.
   * @param {number} value - unused here; passed to getCustomSetupFunc instead.
   */
  function FillProgram2(shape, value) {
    this.outputShape = shape;
    this.variableNames = ["x"];
    this.userCode = [
      "",
      " uniform float value;",
      " void main() {",
      " // Input can be obtained from uniform value.",
      " setOutput(value);",
      " }",
      " "
    ].join("\n");
  }

  /**
   * Returns a setup callback that uploads `value` to the "value" uniform. The
   * location is cached on the program once the lookup returns a non-null value.
   */
  FillProgram2.prototype.getCustomSetupFunc = function(value) {
    return (gpgpu, webGLProgram) => {
      if (this.valueLoc == null) {
        this.valueLoc = gpgpu.getUniformLocationNoThrow(webGLProgram, "value");
      }
      gpgpu.gl.uniform1f(this.valueLoc, value);
    };
  };
  return FillProgram2;
}();
/**
* @license
* Copyright 2017 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var GatherProgram = function() {
  /**
   * Gathers slices of "A" along `axis` at the positions stored in "indices".
   * The output shape is aShape with dimension `axis` replaced by indicesLength.
   *
   * @param {number[]} aShape - shape of A (rank up to 4, per getSourceCoords$1).
   * @param {number} indicesLength - number of gathered positions.
   * @param {number} axis - dimension of A being gathered.
   */
  function GatherProgram2(aShape, indicesLength, axis) {
    this.variableNames = ["A", "indices"];
    // Work on a copy so the caller's shape array stays untouched.
    const gatheredShape = aShape.slice();
    gatheredShape[axis] = indicesLength;
    this.outputShape = gatheredShape;
    this.rank = gatheredShape.length;
    const coordsType = getCoordsDataType(this.rank);
    const sourceExpr = getSourceCoords$1(aShape, axis);
    const readStmt = "setOutput(getA(" + sourceExpr + "));";
    this.userCode = "\n void main() {\n " + coordsType + " resRC = getOutputCoords();\n " + readStmt + "\n }\n ";
  }
  return GatherProgram2;
}();
/**
 * Builds the getA(...) argument list used by the gather shader. Output
 * coordinate `resRC` is copied component-wise, except that the component on
 * `axis` is replaced by the gathered index int(getIndices(...)).
 *
 * @param {number[]} aShape - shape of the gathered tensor (rank <= 4).
 * @param {number} axis - gathered dimension (ignored for rank 1).
 * @returns {string} comma-joined GLSL argument list.
 * @throws {Error} when the rank exceeds 4.
 */
function getSourceCoords$1(aShape, axis) {
  const rank = aShape.length;
  if (rank > 4) {
    throw Error("Gather for rank " + rank + " is not yet supported");
  }
  if (rank === 1) {
    return "int(getIndices(resRC))";
  }
  return ["resRC.x", "resRC.y", "resRC.z", "resRC.w"]
    .slice(0, rank)
    .map(function(component, dim) {
      return dim === axis ? "int(getIndices(" + component + "))" : "" + component;
    })
    .join();
}
var GatherNDProgram = function() {
  /**
   * Gather-ND over "x", which is read as 2-D through getX(row, col). For output
   * (coords[0], coords[1]), the first sliceDim values of row coords[0] of
   * "indices" are rounded, weighted by `strides`, and summed into a flat row
   * index. Column coords[1] of that row is then emitted.
   *
   * @param {number} sliceDim - number of index components per lookup.
   * @param {number[]} strides - weights per index component (baked into GLSL).
   * @param {number[]} shape - output shape.
   */
  function GatherNDProgram2(sliceDim, strides, shape) {
    this.sliceDim = sliceDim;
    this.strides = strides;
    this.variableNames = ["x", "indices"];
    this.outputShape = shape;
    const stridesType = getCoordsDataType(strides.length);
    const coordsType = getCoordsDataType(shape.length);
    // More than one component means `strides` is a vector indexed per j.
    const strideExpr = this.sliceDim > 1 ? "strides[j]" : "strides";
    const stridesDecl = stridesType + " strides = " + stridesType + "(" + this.strides + ");";
    this.userCode = "\n " + stridesDecl + "\n void main() {\n " + coordsType + " coords = getOutputCoords();\n int flattenIndex = 0;\n for (int j = 0; j < " + this.sliceDim + "; j++) {\n int index = round(getIndices(coords[0], j));\n flattenIndex += index * " + strideExpr + ";\n }\n setOutput(getX(flattenIndex, coords[1]));\n }\n ";
  }
  return GatherNDProgram2;
}();
/**
* @license
* Copyright 2017 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Builds the pass-through vertex shader source and hands it to
 * createVertexShader. The shader outputs the clip-space position unchanged
 * and forwards the `uv` attribute to the `resultUV` varying.
 *
 * @param {WebGLRenderingContext} gl - target context.
 */
function createVertexShader$1(gl) {
  const glsl = getGlslDifferences();
  const source = `${glsl.version}\n precision highp float;\n ${glsl.attribute} vec3 clipSpacePos;\n ${glsl.attribute} vec2 uv;\n ${glsl.varyingVs} vec2 resultUV;\n\n void main() {\n gl_Position = vec4(clipSpacePos, 1);\n resultUV = uv;\n }`;
  return createVertexShader(gl, source);
}
/**
 * Creates the static vertex buffer for the four quad corners.
 * Each vertex is 5 interleaved floats: clip-space position (x, y, z) followed
 * by texture coordinate (u, v).
 */
function createVertexBuffer(gl) {
  const quadVertices = new Float32Array([
    -1, 1, 0, 0, 1,
    -1, -1, 0, 0, 0,
    1, 1, 0, 1, 1,
    1, -1, 0, 1, 0
  ]);
  return createStaticVertexBuffer(gl, quadVertices);
}
/**
 * Creates the static index buffer: two triangles, (0, 1, 2) and (2, 1, 3),
 * over the four quad vertices.
 */
function createIndexBuffer(gl) {
  const firstTriangle = [0, 1, 2];
  const secondTriangle = [2, 1, 3];
  return createStaticIndexBuffer(gl, new Uint16Array(firstTriangle.concat(secondTriangle)));
}
/**
 * Creates a TEXTURE_2D with the given size and formats. It uses
 * clamp-to-edge wrapping and NEAREST filtering in both directions, and
 * allocates storage without uploading pixels (null data). On return,
 * TEXTURE_2D is unbound.
 *
 * @returns {WebGLTexture} the new texture, as returned by createTexture.
 */
function createAndConfigureTexture(gl, width, height, internalFormat, textureFormat, textureType) {
  validateTextureSize(width, height);
  const texture = createTexture(gl);
  const target = gl.TEXTURE_2D;
  // [parameter, value] names on `gl`; resolved lazily inside each checked call.
  const samplingParams = [
    ["TEXTURE_WRAP_S", "CLAMP_TO_EDGE"],
    ["TEXTURE_WRAP_T", "CLAMP_TO_EDGE"],
    ["TEXTURE_MIN_FILTER", "NEAREST"],
    ["TEXTURE_MAG_FILTER", "NEAREST"]
  ];
  callAndCheck(gl, () => gl.bindTexture(target, texture));
  samplingParams.forEach(([pname, value]) => {
    callAndCheck(gl, () => gl.texParameteri(target, gl[pname], gl[value]));
  });
  callAndCheck(gl, () => gl.texImage2D(target, 0, internalFormat, width, height, 0, textureFormat, textureType, null));
  callAndCheck(gl, () => gl.bindTexture(gl.TEXTURE_2D, null));
  return texture;
}
/** Internal format used for unpacked float32 matrix textures. */
function getInternalFormatForFloat32MatrixTexture(textureConfig) {
  const { internalFormatFloat } = textureConfig;
  return internalFormatFloat;
}
/**
 * Allocates an unpacked float32 texture sized for a rows x columns matrix.
 * Texel data type is gl.FLOAT.
 */
function createFloat32MatrixTexture(gl, rows, columns, textureConfig) {
  const [width, height] = getUnpackedMatrixTextureShapeWidthHeight(rows, columns);
  const internalFormat = getInternalFormatForFloat32MatrixTexture(textureConfig);
  return createAndConfigureTexture(gl, width, height, internalFormat, textureConfig.textureFormatFloat, gl.FLOAT);
}
/** Internal format used for unpacked half-float matrix textures. */
function getInternalFormatForFloat16MatrixTexture(textureConfig) {
  const { internalFormatHalfFloat } = textureConfig;
  return internalFormatHalfFloat;
}
/**
 * Allocates an unpacked half-float texture sized for a rows x columns matrix.
 * It uses the config's float texture format and half-float texel type.
 */
function createFloat16MatrixTexture(gl, rows, columns, textureConfig) {
  const [width, height] = getUnpackedMatrixTextureShapeWidthHeight(rows, columns);
  const internalFormat = getInternalFormatForFloat16MatrixTexture(textureConfig);
  return createAndConfigureTexture(gl, width, height, internalFormat, textureConfig.textureFormatFloat, textureConfig.textureTypeHalfFloat);
}
/** Internal format used for unsigned-byte (download) matrix textures. */
function getInternalFormatForUnsignedBytesMatrixTexture(textureConfig) {
  const { downloadTextureFormat } = textureConfig;
  return downloadTextureFormat;
}
/**
 * Allocates an unpacked RGBA / UNSIGNED_BYTE texture sized for a
 * rows x columns matrix.
 */
function createUnsignedBytesMatrixTexture(gl, rows, columns, textureConfig) {
  const [width, height] = getUnpackedMatrixTextureShapeWidthHeight(rows, columns);
  const internalFormat = getInternalFormatForUnsignedBytesMatrixTexture(textureConfig);
  return createAndConfigureTexture(gl, width, height, internalFormat, gl.RGBA, gl.UNSIGNED_BYTE);
}
/** Internal format used for packed float32 matrix textures. */
function getInternalFormatForPackedMatrixTexture(textureConfig) {
  const { internalFormatPackedFloat } = textureConfig;
  return internalFormatPackedFloat;
}
/**
 * Allocates a packed RGBA / FLOAT texture. Its size comes from
 * getPackedMatrixTextureShapeWidthHeight(rows, columns).
 */
function createPackedMatrixTexture(gl, rows, columns, textureConfig) {
  const [width, height] = getPackedMatrixTextureShapeWidthHeight(rows, columns);
  const internalFormat = getInternalFormatForPackedMatrixTexture(textureConfig);
  return createAndConfigureTexture(gl, width, height, internalFormat, gl.RGBA, gl.FLOAT);
}
/** Internal format used for packed half-float matrix textures. */
function getInternalFormatForFloat16PackedMatrixTexture(textureConfig) {
  const { internalFormatPackedHalfFloat } = textureConfig;
  return internalFormatPackedHalfFloat;
}
/**
 * Allocates a packed RGBA half-float texture. Its size comes from
 * getPackedMatrixTextureShapeWidthHeight(rows, columns).
 */
function createFloat16PackedMatrixTexture(gl, rows, columns, textureConfig) {
  const [width, height] = getPackedMatrixTextureShapeWidthHeight(rows, columns);
  const internalFormat = getInternalFormatForFloat16PackedMatrixTexture(textureConfig);
  return createAndConfigureTexture(gl, width, height, internalFormat, gl.RGBA, textureConfig.textureTypeHalfFloat);
}
/**
 * Binds `vertexBuffer` as ARRAY_BUFFER and attaches its two interleaved
 * streams to `program`: "clipSpacePos" (3 floats at byte 0) and "uv"
 * (2 floats at byte 12). The stride is 20 bytes per vertex.
 *
 * @returns the first falsy binding result, or the "uv" result when both bind.
 */
function bindVertexProgramAttributeStreams(gl, program, vertexBuffer) {
  const bytesPerFloat = 4;
  const positionSize = 3;
  const uvSize = 2;
  const strideBytes = (positionSize + uvSize) * bytesPerFloat;
  callAndCheck(gl, () => gl.bindBuffer(gl.ARRAY_BUFFER, vertexBuffer));
  const positionBound = bindVertexBufferToProgramAttribute(gl, program, "clipSpacePos", vertexBuffer, positionSize, strideBytes, 0);
  if (!positionBound) {
    // Do not attempt the uv stream once the position stream has failed.
    return positionBound;
  }
  return bindVertexBufferToProgramAttribute(gl, program, "uv", vertexBuffer, uvSize, strideBytes, positionSize * bytesPerFloat);
}
/**
 * Uploads a dense matrix into `texture` as width x height RGBA texels.
 * A Uint8Array source is uploaded as UNSIGNED_BYTE with internal format RGBA.
 * Any other source is copied into a Float32Array and uploaded as FLOAT with
 * `textureConfig.internalFormatPackedFloat`.
 * The staging array holds width*height*4 elements and is zero-padded past
 * `data`. `set` throws a RangeError if `data` is longer than that.
 * TEXTURE_2D is unbound afterwards.
 * @param {WebGLRenderingContext} gl - Target context.
 * @param {WebGLTexture} texture - Destination texture.
 * @param {number} width - Texture width in texels.
 * @param {number} height - Texture height in texels.
 * @param {Uint8Array|ArrayLike<number>} data - Values to upload.
 * @param {Object} textureConfig - Supplies the packed float internal format.
 */
function uploadDenseMatrixToTexture(gl, texture, width, height, data, textureConfig) {
  callAndCheck(gl, () => gl.bindTexture(gl.TEXTURE_2D, texture));
  const elementCount = width * height * 4;
  const isByteData = data instanceof Uint8Array;
  const dataForUpload = isByteData ? new Uint8Array(elementCount) : new Float32Array(elementCount);
  const texelDataType = isByteData ? gl.UNSIGNED_BYTE : gl.FLOAT;
  const internalFormat = isByteData ? gl.RGBA : textureConfig.internalFormatPackedFloat;
  dataForUpload.set(data);
  callAndCheck(gl, () =>
    gl.texImage2D(gl.TEXTURE_2D, 0, internalFormat, width, height, 0, gl.RGBA, texelDataType, dataForUpload));
  callAndCheck(gl, () => gl.bindTexture(gl.TEXTURE_2D, null));
}
/**
 * Uploads pixel data into `texture` as RGBA / UNSIGNED_BYTE.
 * When `pixels.data` is a Uint8Array, the explicit width/height/data overload
 * of texImage2D is used. Anything else is passed whole to the 6-argument
 * source overload of texImage2D.
 * TEXTURE_2D is unbound afterwards.
 * @param {WebGLRenderingContext} gl - Target context.
 * @param {WebGLTexture} texture - Destination texture.
 * @param {Object} pixels - Pixel source; `width`, `height` and `data` are
 *     read in the byte-array case.
 */
function uploadPixelDataToTexture(gl, texture, pixels) {
  callAndCheck(gl, () => gl.bindTexture(gl.TEXTURE_2D, texture));
  const upload = pixels.data instanceof Uint8Array
    ? () => gl.texImage2D(gl.TEXTURE_2D, 0, gl.RGBA, pixels.width, pixels.height, 0, gl.RGBA, gl.UNSIGNED_BYTE, pixels.data)
    : () => gl.texImage2D(gl.TEXTURE_2D, 0, gl.RGBA, gl.RGBA, gl.UNSIGNED_BYTE, pixels);
  callAndCheck(gl, upload);
  callAndCheck(gl, () => gl.bindTexture(gl.TEXTURE_2D, null));
}
/**
 * Creates a PIXEL_PACK_BUFFER sized for rows x columns RGBA float texels.
 * It then issues an asynchronous-capable readPixels of the currently bound
 * framebuffer into that buffer, and unbinds the buffer.
 * With a PIXEL_PACK_BUFFER bound, the last readPixels argument is a byte
 * offset into that buffer, not a destination array.
 * @param {WebGL2RenderingContext} gl2 - WebGL2 context.
 * @param {number} rows - Rows to read (viewport height).
 * @param {number} columns - Columns to read (viewport width).
 * @param {Object} textureConfig - Not read by this function.
 * @returns {WebGLBuffer} The buffer that receives the pixels.
 */
function createBufferFromOutputTexture(gl2, rows, columns, textureConfig) {
  const buffer = gl2.createBuffer();
  callAndCheck(gl2, () => gl2.bindBuffer(gl2.PIXEL_PACK_BUFFER, buffer));
  const BYTES_PER_FLOAT = 4;
  const CHANNELS_PER_TEXEL = 4;
  const byteLength = BYTES_PER_FLOAT * CHANNELS_PER_TEXEL * rows * columns;
  callAndCheck(gl2, () => gl2.bufferData(gl2.PIXEL_PACK_BUFFER, byteLength, gl2.STREAM_READ));
  callAndCheck(gl2, () => gl2.readPixels(0, 0, columns, rows, gl2.RGBA, gl2.FLOAT, 0));
  callAndCheck(gl2, () => gl2.bindBuffer(gl2.PIXEL_PACK_BUFFER, null));
  return buffer;
}
/**
 * Copies `size` float32 values out of a PIXEL_PACK_BUFFER into a new array.
 * @param {WebGL2RenderingContext} gl - Context (must support getBufferSubData).
 * @param {WebGLBuffer} buffer - Source pixel-pack buffer.
 * @param {number} size - Number of floats to read.
 * @returns {Float32Array} The downloaded values.
 */
function downloadFloat32MatrixFromBuffer(gl, buffer, size) {
  const result = new Float32Array(size);
  const { PIXEL_PACK_BUFFER } = gl;
  gl.bindBuffer(PIXEL_PACK_BUFFER, buffer);
  gl.getBufferSubData(PIXEL_PACK_BUFFER, 0, result);
  gl.bindBuffer(PIXEL_PACK_BUFFER, null);
  return result;
}
/**
 * Reads the bound framebuffer as UNSIGNED_BYTE RGBA and reinterprets the raw
 * bytes as float32 values (a view over the same ArrayBuffer; no copy).
 * @param {WebGLRenderingContext} gl - Source context.
 * @param {number} rows - Logical matrix rows.
 * @param {number} columns - Logical matrix columns.
 * @param {Object} textureConfig - Supplies `downloadTextureFormat`.
 * @returns {Float32Array} Float view over the downloaded bytes.
 */
function downloadByteEncodedFloatMatrixFromOutputTexture(gl, rows, columns, textureConfig) {
  const [width, height] = getUnpackedMatrixTextureShapeWidthHeight(rows, columns);
  const CHANNELS_PER_TEXEL = 4;
  const byteLength = getUnpackedArraySizeFromMatrixSize(rows * columns, CHANNELS_PER_TEXEL);
  const bytes = new Uint8Array(byteLength);
  callAndCheck(gl, () =>
    gl.readPixels(0, 0, width, height, textureConfig.downloadTextureFormat, gl.UNSIGNED_BYTE, bytes));
  return new Float32Array(bytes.buffer);
}
/**
 * Copies a packed RGBA float matrix out of a PIXEL_PACK_BUFFER.
 * Only `physicalRows` and `physicalCols` decide how many floats are read;
 * `batch`, `rows`, `cols` and `textureConfig` are not read here.
 * @param {WebGL2RenderingContext} gl - Context (must support getBufferSubData).
 * @param {WebGLBuffer} buffer - Source pixel-pack buffer.
 * @returns {Float32Array} The packed RGBA values.
 */
function downloadPackedMatrixFromBuffer(gl, buffer, batch, rows, cols, physicalRows, physicalCols, textureConfig) {
  const packed = new Float32Array(getPackedRGBAArraySizeFromMatrixShape(physicalRows, physicalCols));
  const { PIXEL_PACK_BUFFER } = gl;
  gl.bindBuffer(PIXEL_PACK_BUFFER, buffer);
  gl.getBufferSubData(PIXEL_PACK_BUFFER, 0, packed);
  gl.bindBuffer(PIXEL_PACK_BUFFER, null);
  return packed;
}
/**
 * Synchronously reads physicalRows x physicalCols RGBA float texels from the
 * bound framebuffer.
 * @param {WebGLRenderingContext} gl - Source context.
 * @param {number} physicalRows - Texture height in texels.
 * @param {number} physicalCols - Texture width in texels.
 * @returns {Float32Array} physicalRows * physicalCols * 4 values.
 */
function downloadMatrixFromPackedOutputTexture(gl, physicalRows, physicalCols) {
  const CHANNELS_PER_TEXEL = 4;
  const pixels = new Float32Array(physicalRows * physicalCols * CHANNELS_PER_TEXEL);
  callAndCheck(gl, () => gl.readPixels(0, 0, physicalCols, physicalRows, gl.RGBA, gl.FLOAT, pixels));
  return pixels;
}
/**
 * Namespace-style table exposing the GPGPU texture/buffer helpers.
 * It is built on a null prototype, so it carries no inherited Object members.
 * Keys are listed in the same order as the helpers are declared.
 */
var gpgpu_util = Object.assign(Object.create(null), {
  createVertexShader: createVertexShader$1,
  createVertexBuffer: createVertexBuffer,
  createIndexBuffer: createIndexBuffer,
  getInternalFormatForFloat32MatrixTexture: getInternalFormatForFloat32MatrixTexture,
  createFloat32MatrixTexture: createFloat32MatrixTexture,
  getInternalFormatForFloat16MatrixTexture: getInternalFormatForFloat16MatrixTexture,
  createFloat16MatrixTexture: createFloat16MatrixTexture,
  getInternalFormatForUnsignedBytesMatrixTexture: getInternalFormatForUnsignedBytesMatrixTexture,
  createUnsignedBytesMatrixTexture: createUnsignedBytesMatrixTexture,
  getInternalFormatForPackedMatrixTexture: getInternalFormatForPackedMatrixTexture,
  createPackedMatrixTexture: createPackedMatrixTexture,
  getInternalFormatForFloat16PackedMatrixTexture: getInternalFormatForFloat16PackedMatrixTexture,
  createFloat16PackedMatrixTexture: createFloat16PackedMatrixTexture,
  bindVertexProgramAttributeStreams: bindVertexProgramAttributeStreams,
  uploadDenseMatrixToTexture: uploadDenseMatrixToTexture,
  uploadPixelDataToTexture: uploadPixelDataToTexture,
  createBufferFromOutputTexture: createBufferFromOutputTexture,
  downloadFloat32MatrixFromBuffer: downloadFloat32MatrixFromBuffer,
  downloadByteEncodedFloatMatrixFromOutputTexture: downloadByteEncodedFloatMatrixFromOutputTexture,
  downloadPackedMatrixFromBuffer: downloadPackedMatrixFromBuffer,
  downloadMatrixFromPackedOutputTexture: downloadMatrixFromPackedOutputTexture
});
/**
* @license
* Copyright 2017 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var GPGPUContext = function() {
  /**
   * Wraps a WebGL context with the state needed to run GPGPU programs.
   * - Enables the float / half-float extensions the context needs.
   * - Creates a shared vertex buffer, index buffer and framebuffer.
   * - Tracks the currently bound program and output texture.
   * @param {WebGLRenderingContext|WebGL2RenderingContext} [gl] - Context to
   *     adopt. When null/undefined, one is obtained from getWebGLContext() for
   *     the configured WEBGL_VERSION.
   * @throws {Error} When a required float extension is missing, or when
   *     WEBGL_FORCE_F16_TEXTURES is set but half floats are unsupported.
   */
  function GPGPUContext2(gl) {
    this.outputTexture = null;
    this.program = null;
    this.disposed = false;
    this.vertexAttrsAreBound = false;
    // Pending {isDoneFn, resolveFn} entries drained in FIFO order by pollItems().
    this.itemsToPoll = [];
    var glVersion = tf.env().getNumber("WEBGL_VERSION");
    if (gl != null) {
      this.gl = gl;
      setWebGLContext(glVersion, gl);
    } else {
      this.gl = getWebGLContext(glVersion);
    }
    var COLOR_BUFFER_FLOAT = "WEBGL_color_buffer_float";
    var COLOR_BUFFER_HALF_FLOAT = "EXT_color_buffer_half_float";
    if (tf.env().getNumber("WEBGL_VERSION") === 1) {
      // WebGL 1: OES_texture_float is required.
      // Half-float support is optional unless WEBGL_FORCE_F16_TEXTURES is set.
      var TEXTURE_FLOAT = "OES_texture_float";
      var TEXTURE_HALF_FLOAT = "OES_texture_half_float";
      this.textureFloatExtension = getExtensionOrThrow(this.gl, TEXTURE_FLOAT);
      if (hasExtension(this.gl, TEXTURE_HALF_FLOAT)) {
        this.textureHalfFloatExtension = getExtensionOrThrow(this.gl, TEXTURE_HALF_FLOAT);
      } else if (tf.env().get("WEBGL_FORCE_F16_TEXTURES")) {
        throw new Error("GL context does not support half float textures, yet the environment flag WEBGL_FORCE_F16_TEXTURES is set to true.");
      }
      this.colorBufferFloatExtension = this.gl.getExtension(COLOR_BUFFER_FLOAT);
      if (hasExtension(this.gl, COLOR_BUFFER_HALF_FLOAT)) {
        this.colorBufferHalfFloatExtension = getExtensionOrThrow(this.gl, COLOR_BUFFER_HALF_FLOAT);
      } else if (tf.env().get("WEBGL_FORCE_F16_TEXTURES")) {
        throw new Error("GL context does not support color renderable half floats, yet the environment flag WEBGL_FORCE_F16_TEXTURES is set to true.");
      }
    } else {
      // WebGL 2: at least one of the float / half-float color-buffer extensions must exist.
      COLOR_BUFFER_FLOAT = "EXT_color_buffer_float";
      if (hasExtension(this.gl, COLOR_BUFFER_FLOAT)) {
        this.colorBufferFloatExtension = this.gl.getExtension(COLOR_BUFFER_FLOAT);
      } else if (hasExtension(this.gl, COLOR_BUFFER_HALF_FLOAT)) {
        this.colorBufferHalfFloatExtension = this.gl.getExtension(COLOR_BUFFER_HALF_FLOAT);
      } else {
        throw new Error("GL context does not support color renderable floats");
      }
    }
    this.vertexBuffer = createVertexBuffer(this.gl);
    this.indexBuffer = createIndexBuffer(this.gl);
    this.framebuffer = createFramebuffer(this.gl);
    this.textureConfig = getTextureConfig(this.gl, this.textureHalfFloatExtension);
  }
  /** Whether the DEBUG env flag is set; enables program/framebuffer validation. */
  Object.defineProperty(GPGPUContext2.prototype, "debug", {
    get: function() {
      return tf.env().getBool("DEBUG");
    },
    enumerable: true,
    configurable: true
  });
  /**
   * Releases the GL objects created by the constructor: framebuffer, index
   * buffer and vertex buffer. Afterwards the context is marked disposed.
   * Calling dispose() again is a no-op. A still-bound program or output
   * texture only triggers a console warning; neither is deleted here.
   */
  GPGPUContext2.prototype.dispose = function() {
    var _this = this;
    if (this.disposed) {
      return;
    }
    if (this.program != null) {
      console.warn("Disposing a GPGPUContext that still has a bound WebGLProgram. This is probably a resource leak, delete the program with GPGPUContext.deleteProgram before disposing.");
    }
    if (this.outputTexture != null) {
      console.warn("Disposing a GPGPUContext that still has a bound output matrix texture. This is probably a resource leak, delete the output matrix texture with GPGPUContext.deleteMatrixTexture before disposing.");
    }
    var gl = this.gl;
    callAndCheck(gl, function() {
      return gl.finish();
    });
    callAndCheck(gl, function() {
      return gl.bindFramebuffer(gl.FRAMEBUFFER, null);
    });
    callAndCheck(gl, function() {
      return gl.deleteFramebuffer(_this.framebuffer);
    });
    callAndCheck(gl, function() {
      return gl.bindBuffer(gl.ARRAY_BUFFER, null);
    });
    callAndCheck(gl, function() {
      return gl.bindBuffer(gl.ELEMENT_ARRAY_BUFFER, null);
    });
    callAndCheck(gl, function() {
      return gl.deleteBuffer(_this.indexBuffer);
    });
    // The constructor also created a vertex buffer. Delete it too, otherwise
    // every disposed context leaks one GL buffer.
    callAndCheck(gl, function() {
      return gl.deleteBuffer(_this.vertexBuffer);
    });
    this.disposed = true;
  };
  /** Allocates an unpacked float32 texture for a rows x columns matrix. */
  GPGPUContext2.prototype.createFloat32MatrixTexture = function(rows, columns) {
    this.throwIfDisposed();
    return createFloat32MatrixTexture(this.gl, rows, columns, this.textureConfig);
  };
  /** Allocates an unpacked half-float texture for a rows x columns matrix. */
  GPGPUContext2.prototype.createFloat16MatrixTexture = function(rows, columns) {
    this.throwIfDisposed();
    return createFloat16MatrixTexture(this.gl, rows, columns, this.textureConfig);
  };
  /** Allocates an unpacked RGBA unsigned-byte texture for a rows x columns matrix. */
  GPGPUContext2.prototype.createUnsignedBytesMatrixTexture = function(rows, columns) {
    this.throwIfDisposed();
    return createUnsignedBytesMatrixTexture(this.gl, rows, columns, this.textureConfig);
  };
  /** Uploads pixel data (byte array or texImage2D source) into `texture`. */
  GPGPUContext2.prototype.uploadPixelDataToTexture = function(texture, pixels) {
    this.throwIfDisposed();
    uploadPixelDataToTexture(this.gl, texture, pixels);
  };
  /** Uploads a dense matrix into `texture` as width x height RGBA texels. */
  GPGPUContext2.prototype.uploadDenseMatrixToTexture = function(texture, width, height, data) {
    this.throwIfDisposed();
    uploadDenseMatrixToTexture(this.gl, texture, width, height, data, this.textureConfig);
  };
  /** Allocates a packed half-float texture for a rows x columns matrix. */
  GPGPUContext2.prototype.createFloat16PackedMatrixTexture = function(rows, columns) {
    this.throwIfDisposed();
    return createFloat16PackedMatrixTexture(this.gl, rows, columns, this.textureConfig);
  };
  /** Allocates a packed float texture for a rows x columns matrix. */
  GPGPUContext2.prototype.createPackedMatrixTexture = function(rows, columns) {
    this.throwIfDisposed();
    return createPackedMatrixTexture(this.gl, rows, columns, this.textureConfig);
  };
  /**
   * Deletes `texture`. If it is the current output texture, it is first
   * detached from the framebuffer and the output slot is cleared.
   */
  GPGPUContext2.prototype.deleteMatrixTexture = function(texture) {
    var _this = this;
    this.throwIfDisposed();
    if (this.outputTexture === texture) {
      unbindColorTextureFromFramebuffer(this.gl, this.framebuffer);
      this.outputTexture = null;
    }
    callAndCheck(this.gl, function() {
      return _this.gl.deleteTexture(texture);
    });
  };
  /** Reads `texture` back as byte-encoded floats (see downloadMatrixDriver). */
  GPGPUContext2.prototype.downloadByteEncodedFloatMatrixFromOutputTexture = function(texture, rows, columns) {
    var _this = this;
    return this.downloadMatrixDriver(texture, function() {
      return downloadByteEncodedFloatMatrixFromOutputTexture(_this.gl, rows, columns, _this.textureConfig);
    });
  };
  /** Copies a packed RGBA float matrix out of a pixel-pack buffer. */
  GPGPUContext2.prototype.downloadPackedMatrixFromBuffer = function(buffer, batch, rows, columns, physicalRows, physicalCols) {
    return downloadPackedMatrixFromBuffer(this.gl, buffer, batch, rows, columns, physicalRows, physicalCols, this.textureConfig);
  };
  /** Copies `size` float32 values out of a pixel-pack buffer. */
  GPGPUContext2.prototype.downloadFloat32MatrixFromBuffer = function(buffer, size) {
    return downloadFloat32MatrixFromBuffer(this.gl, buffer, size);
  };
  /**
   * Temporarily attaches `texture` to the framebuffer and reads it into a new
   * PIXEL_PACK_BUFFER, then restores the previous framebuffer attachment.
   */
  GPGPUContext2.prototype.createBufferFromTexture = function(texture, rows, columns) {
    this.bindTextureToFrameBuffer(texture);
    var result = createBufferFromOutputTexture(this.gl, rows, columns, this.textureConfig);
    this.unbindTextureToFrameBuffer();
    return result;
  };
  /** Inserts a fence and returns a promise that resolves once it has passed. */
  GPGPUContext2.prototype.createAndWaitForFence = function() {
    var fenceContext = this.createFence(this.gl);
    return this.pollFence(fenceContext);
  };
  /**
   * Builds a fence context `{query, isFencePassed}`. The strategy is chosen
   * in this order:
   * - WEBGL_FENCE_API_ENABLED: a WebGL2 fenceSync polled with clientWaitSync.
   * - Otherwise, if a disjoint timer query extension is available: an empty
   *   timer query whose availability marks completion.
   * - Otherwise: `query` is undefined and the fence always reports passed.
   */
  GPGPUContext2.prototype.createFence = function(gl) {
    var _this = this;
    var query;
    var isFencePassed;
    if (tf.env().getBool("WEBGL_FENCE_API_ENABLED")) {
      var gl2_1 = gl;
      var sync_1 = gl2_1.fenceSync(gl2_1.SYNC_GPU_COMMANDS_COMPLETE, 0);
      gl.flush();
      isFencePassed = function() {
        var status = gl2_1.clientWaitSync(sync_1, 0, 0);
        return status === gl2_1.ALREADY_SIGNALED || status === gl2_1.CONDITION_SATISFIED;
      };
      query = sync_1;
    } else if (tf.env().getNumber("WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION") > 0) {
      query = this.beginQuery();
      this.endQuery();
      isFencePassed = function() {
        return _this.isQueryAvailable(query, tf.env().getNumber("WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION"));
      };
    } else {
      isFencePassed = function() {
        return true;
      };
    }
    return {query, isFencePassed};
  };
  /** Synchronously reads a packed texture's RGBA floats (see downloadMatrixDriver). */
  GPGPUContext2.prototype.downloadMatrixFromPackedTexture = function(texture, physicalRows, physicalCols) {
    var _this = this;
    return this.downloadMatrixDriver(texture, function() {
      return downloadMatrixFromPackedOutputTexture(_this.gl, physicalRows, physicalCols);
    });
  };
  /**
   * Compiles `fragmentShaderSource` against the shared vertex shader and links
   * the result. Programs are validated only when `debug` is true.
   * The first program created becomes the current program and binds the
   * shared vertex attribute streams; later programs leave the current program
   * unchanged.
   * @returns {WebGLProgram} The linked program.
   */
  GPGPUContext2.prototype.createProgram = function(fragmentShaderSource) {
    this.throwIfDisposed();
    var gl = this.gl;
    var fragmentShader = createFragmentShader(gl, fragmentShaderSource);
    var vertexShader = createVertexShader$1(gl);
    var program = createProgram(gl);
    callAndCheck(gl, function() {
      return gl.attachShader(program, vertexShader);
    });
    callAndCheck(gl, function() {
      return gl.attachShader(program, fragmentShader);
    });
    linkProgram(gl, program);
    if (this.debug) {
      validateProgram(gl, program);
    }
    if (!this.vertexAttrsAreBound) {
      this.setProgram(program);
      this.vertexAttrsAreBound = bindVertexProgramAttributeStreams(gl, this.program, this.vertexBuffer);
    }
    return program;
  };
  /** Deletes `program`, clearing the current-program slot if it was current. */
  GPGPUContext2.prototype.deleteProgram = function(program) {
    var _this = this;
    this.throwIfDisposed();
    if (program === this.program) {
      this.program = null;
    }
    if (program != null) {
      callAndCheck(this.gl, function() {
        return _this.gl.deleteProgram(program);
      });
    }
  };
  /** Makes `program` current (gl.useProgram), validating it in debug mode. */
  GPGPUContext2.prototype.setProgram = function(program) {
    var _this = this;
    this.throwIfDisposed();
    this.program = program;
    if (this.program != null && this.debug) {
      validateProgram(this.gl, this.program);
    }
    callAndCheck(this.gl, function() {
      return _this.gl.useProgram(program);
    });
  };
  /**
   * Looks up a uniform location.
   * @param {boolean} [shouldThrow=true] - When true a missing uniform throws;
   *     otherwise the non-throwing lookup result is returned.
   */
  GPGPUContext2.prototype.getUniformLocation = function(program, uniformName, shouldThrow) {
    if (shouldThrow === void 0) {
      shouldThrow = true;
    }
    this.throwIfDisposed();
    if (shouldThrow) {
      return getProgramUniformLocationOrThrow(this.gl, program, uniformName);
    } else {
      return getProgramUniformLocation(this.gl, program, uniformName);
    }
  };
  /** Returns gl.getAttribLocation(program, attribute), wrapped in callAndCheck. */
  GPGPUContext2.prototype.getAttributeLocation = function(program, attribute) {
    var _this = this;
    this.throwIfDisposed();
    return callAndCheck(this.gl, function() {
      return _this.gl.getAttribLocation(program, attribute);
    });
  };
  /** Returns gl.getUniformLocation directly, without callAndCheck. */
  GPGPUContext2.prototype.getUniformLocationNoThrow = function(program, uniformName) {
    this.throwIfDisposed();
    return this.gl.getUniformLocation(program, uniformName);
  };
  /** Binds `inputMatrixTexture` to `textureUnit` and points the sampler uniform at it. */
  GPGPUContext2.prototype.setInputMatrixTexture = function(inputMatrixTexture, uniformLocation, textureUnit) {
    this.throwIfDisposed();
    this.throwIfNoProgram();
    bindTextureToProgramUniformSampler(this.gl, inputMatrixTexture, uniformLocation, textureUnit);
  };
  /**
   * Makes an unpacked texture the render target.
   * Note the swap: the texture is treated as `columns` wide and `rows` high.
   */
  GPGPUContext2.prototype.setOutputMatrixTexture = function(outputMatrixTexture, rows, columns) {
    this.setOutputMatrixTextureDriver(outputMatrixTexture, columns, rows);
  };
  /** Makes a packed texture the render target, sized by its packed width/height. */
  GPGPUContext2.prototype.setOutputPackedMatrixTexture = function(outputPackedMatrixTexture, rows, columns) {
    this.throwIfDisposed();
    var _a = getPackedMatrixTextureShapeWidthHeight(rows, columns), width = _a[0], height = _a[1];
    this.setOutputMatrixTextureDriver(outputPackedMatrixTexture, width, height);
  };
  /** Restricts writes to a scissor rectangle given in (row, column) terms. */
  GPGPUContext2.prototype.setOutputMatrixWriteRegion = function(startRow, numRows, startColumn, numColumns) {
    this.setOutputMatrixWriteRegionDriver(startColumn, startRow, numColumns, numRows);
  };
  /** Not implemented; always throws. */
  GPGPUContext2.prototype.setOutputPackedMatrixWriteRegion = function(startRow, numRows, startColumn, numColumns) {
    throw new Error("setOutputPackedMatrixWriteRegion not implemented.");
  };
  /** Validates the current program (if any) and the framebuffer. */
  GPGPUContext2.prototype.debugValidate = function() {
    if (this.program != null) {
      validateProgram(this.gl, this.program);
    }
    validateFramebuffer(this.gl);
  };
  /** Draws the shared quad (6 indices, UNSIGNED_SHORT) with the current program. */
  GPGPUContext2.prototype.executeProgram = function() {
    this.throwIfDisposed();
    this.throwIfNoProgram();
    var gl = this.gl;
    if (this.debug) {
      this.debugValidate();
    }
    callAndCheck(gl, function() {
      return gl.drawElements(gl.TRIANGLES, 6, gl.UNSIGNED_SHORT, 0);
    });
  };
  /** Blocks on gl.finish(). */
  GPGPUContext2.prototype.blockUntilAllProgramsCompleted = function() {
    var _this = this;
    this.throwIfDisposed();
    callAndCheck(this.gl, function() {
      return _this.gl.finish();
    });
  };
  /**
   * Lazily fetches and caches the disjoint timer query extension matching
   * WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION (2 selects the webgl2 variant).
   */
  GPGPUContext2.prototype.getQueryTimerExtension = function() {
    if (this.disjointQueryTimerExtension == null) {
      this.disjointQueryTimerExtension = getExtensionOrThrow(this.gl, tf.env().getNumber("WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION") === 2 ? "EXT_disjoint_timer_query_webgl2" : "EXT_disjoint_timer_query");
    }
    return this.disjointQueryTimerExtension;
  };
  GPGPUContext2.prototype.getQueryTimerExtensionWebGL2 = function() {
    return this.getQueryTimerExtension();
  };
  GPGPUContext2.prototype.getQueryTimerExtensionWebGL1 = function() {
    return this.getQueryTimerExtension();
  };
  /** Starts a TIME_ELAPSED_EXT query and returns the query object. */
  GPGPUContext2.prototype.beginQuery = function() {
    if (tf.env().getNumber("WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION") === 2) {
      var gl2 = this.gl;
      var ext_1 = this.getQueryTimerExtensionWebGL2();
      var query_1 = gl2.createQuery();
      gl2.beginQuery(ext_1.TIME_ELAPSED_EXT, query_1);
      return query_1;
    }
    var ext = this.getQueryTimerExtensionWebGL1();
    var query = ext.createQueryEXT();
    ext.beginQueryEXT(ext.TIME_ELAPSED_EXT, query);
    return query;
  };
  /** Ends the active TIME_ELAPSED_EXT query. */
  GPGPUContext2.prototype.endQuery = function() {
    if (tf.env().getNumber("WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION") === 2) {
      var gl2 = this.gl;
      var ext_2 = this.getQueryTimerExtensionWebGL2();
      gl2.endQuery(ext_2.TIME_ELAPSED_EXT);
      return;
    }
    var ext = this.getQueryTimerExtensionWebGL1();
    ext.endQueryEXT(ext.TIME_ELAPSED_EXT);
  };
  /**
   * Polls (via tf.util.repeatedTry) until `query` is available or the context
   * is disposed, then resolves with getQueryTime(query).
   */
  GPGPUContext2.prototype.waitForQueryAndGetTime = function(query) {
    return __awaiter(this, void 0, void 0, function() {
      var _this = this;
      return __generator(this, function(_a) {
        switch (_a.label) {
          case 0:
            return [4, tf.util.repeatedTry(function() {
              return _this.disposed || _this.isQueryAvailable(query, tf.env().getNumber("WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION"));
            })];
          case 1:
            _a.sent();
            return [2, this.getQueryTime(query, tf.env().getNumber("WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION"))];
        }
      });
    });
  };
  /**
   * Returns the elapsed time of `query` divided by 1e6: the raw result is in
   * nanoseconds per EXT_disjoint_timer_query, so this yields milliseconds.
   * Returns null when `queryTimerVersion` is 0.
   */
  GPGPUContext2.prototype.getQueryTime = function(query, queryTimerVersion) {
    if (queryTimerVersion === 0) {
      return null;
    }
    if (queryTimerVersion === 2) {
      var gl2 = this.gl;
      var timeElapsedNanos = gl2.getQueryParameter(query, gl2.QUERY_RESULT);
      return timeElapsedNanos / 1e6;
    } else {
      var ext = this.getQueryTimerExtensionWebGL1();
      var timeElapsedNanos = ext.getQueryObjectEXT(query, ext.QUERY_RESULT_EXT);
      return timeElapsedNanos / 1e6;
    }
  };
  /**
   * True when the query result is available and no GPU disjoint event was seen.
   * Returns true immediately when `queryTimerVersion` is 0.
   * NOTE(review): `this.disjoint` is read once and cached for the lifetime of
   * the context; once it is true, every later query reports unavailable.
   */
  GPGPUContext2.prototype.isQueryAvailable = function(query, queryTimerVersion) {
    if (queryTimerVersion === 0) {
      return true;
    }
    if (queryTimerVersion === 2) {
      var gl2 = this.gl;
      var ext = this.getQueryTimerExtensionWebGL2();
      var available = gl2.getQueryParameter(query, gl2.QUERY_RESULT_AVAILABLE);
      if (this.disjoint == null) {
        this.disjoint = this.gl.getParameter(ext.GPU_DISJOINT_EXT);
      }
      return available && !this.disjoint;
    } else {
      var ext = this.getQueryTimerExtensionWebGL1();
      var available = ext.getQueryObjectEXT(query, ext.QUERY_RESULT_AVAILABLE_EXT);
      if (this.disjoint == null) {
        this.disjoint = this.gl.getParameter(ext.GPU_DISJOINT_EXT);
      }
      return available && !this.disjoint;
    }
  };
  /** Returns a promise that resolves once `fenceContext.isFencePassed()` is truthy. */
  GPGPUContext2.prototype.pollFence = function(fenceContext) {
    var _this = this;
    return new Promise(function(resolve) {
      _this.addItemToPoll(function() {
        return fenceContext.isFencePassed();
      }, function() {
        return resolve();
      });
    });
  };
  /**
   * Resolves the longest prefix of itemsToPoll whose isDoneFn returns truthy,
   * in FIFO order, and drops those entries. Items after the first not-done
   * one are neither checked nor resolved.
   */
  GPGPUContext2.prototype.pollItems = function() {
    var index = linearSearchLastTrue(this.itemsToPoll.map(function(x) {
      return x.isDoneFn;
    }));
    for (var i = 0; i <= index; ++i) {
      var resolveFn = this.itemsToPoll[i].resolveFn;
      resolveFn();
    }
    this.itemsToPoll = this.itemsToPoll.slice(index + 1);
  };
  /**
   * Queues an item. Only the first item in an empty queue starts a
   * tf.util.repeatedTry polling loop; later items are drained by that loop.
   */
  GPGPUContext2.prototype.addItemToPoll = function(isDoneFn, resolveFn) {
    var _this = this;
    this.itemsToPoll.push({isDoneFn, resolveFn});
    if (this.itemsToPoll.length > 1) {
      return;
    }
    tf.util.repeatedTry(function() {
      _this.pollItems();
      return _this.itemsToPoll.length === 0;
    });
  };
  /** Attaches `texture` as the framebuffer's color target (validated in debug mode). */
  GPGPUContext2.prototype.bindTextureToFrameBuffer = function(texture) {
    this.throwIfDisposed();
    bindColorTextureToFramebuffer(this.gl, texture, this.framebuffer);
    if (this.debug) {
      validateFramebuffer(this.gl);
    }
  };
  /** Restores the framebuffer attachment to the current output texture, or detaches it. */
  GPGPUContext2.prototype.unbindTextureToFrameBuffer = function() {
    if (this.outputTexture != null) {
      bindColorTextureToFramebuffer(this.gl, this.outputTexture, this.framebuffer);
      if (this.debug) {
        validateFramebuffer(this.gl);
      }
    } else {
      unbindColorTextureFromFramebuffer(this.gl, this.framebuffer);
    }
  };
  /** Runs `downloadAndDecode` with `texture` temporarily attached to the framebuffer. */
  GPGPUContext2.prototype.downloadMatrixDriver = function(texture, downloadAndDecode) {
    this.bindTextureToFrameBuffer(texture);
    var result = downloadAndDecode();
    this.unbindTextureToFrameBuffer();
    return result;
  };
  /** Attaches the output texture and sets viewport and scissor to width x height. */
  GPGPUContext2.prototype.setOutputMatrixTextureDriver = function(outputMatrixTextureMaybePacked, width, height) {
    this.throwIfDisposed();
    var gl = this.gl;
    bindColorTextureToFramebuffer(gl, outputMatrixTextureMaybePacked, this.framebuffer);
    if (this.debug) {
      validateFramebuffer(gl);
    }
    this.outputTexture = outputMatrixTextureMaybePacked;
    callAndCheck(gl, function() {
      return gl.viewport(0, 0, width, height);
    });
    callAndCheck(gl, function() {
      return gl.scissor(0, 0, width, height);
    });
  };
  /** Sets the scissor rectangle (x, y, width, height). */
  GPGPUContext2.prototype.setOutputMatrixWriteRegionDriver = function(x, y, width, height) {
    var _this = this;
    this.throwIfDisposed();
    callAndCheck(this.gl, function() {
      return _this.gl.scissor(x, y, width, height);
    });
  };
  /** @throws {Error} If dispose() has already run. */
  GPGPUContext2.prototype.throwIfDisposed = function() {
    if (this.disposed) {
      throw new Error("Attempted to use disposed GPGPUContext.");
    }
  };
  /** @throws {Error} If no program is currently set. */
  GPGPUContext2.prototype.throwIfNoProgram = function() {
    if (this.program == null) {
      throw new Error("No GPU program is currently set.");
    }
  };
  return GPGPUContext2;
}();
/**
 * Calls the predicates in `arr` in order and stops at the first one that
 * returns a falsy value. Predicates after that one are never invoked.
 * @param {Array<Function>} arr - Zero-argument predicates.
 * @returns {number} Index of the last predicate in the leading run of truthy
 *     results, or -1 if the first predicate is falsy or `arr` is empty.
 */
function linearSearchLastTrue(arr) {
  let doneCount = 0;
  while (doneCount < arr.length && arr[doneCount]()) {
    doneCount++;
  }
  return doneCount - 1;
}
/**
* @license
* Copyright 2017 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Generates the shader source for `program` from the input/output shape
 * infos, compiles it on `gpgpu`, and collects its uniform locations.
 * Uniform lookups do not throw when a uniform is missing.
 * INFINITY is looked up only on WebGL 1; otherwise infLoc is null.
 * @param {Object} gpgpu - GPGPUContext used to compile and query the program.
 * @param {Object} program - Provides `variableNames`, `userCode` and `packedInputs`.
 * @param {Array<Object>} inputs - Input tensor infos, parallel to `variableNames`.
 * @param {Object} output - Output tensor info.
 * @returns {Object} {program, source, webGLProgram, uniformLocations,
 *     inShapeInfos, outShapeInfo, infLoc, nanLoc}.
 */
function compileProgram(gpgpu, program, inputs, output) {
  const inputInfos = inputs.map((input, index) => {
    const uniform = input.isUniform;
    const shapeInfo = {
      logicalShape: input.shape,
      texShape: uniform ? null : input.texData.texShape,
      isUniform: uniform,
      isPacked: uniform ? false : input.texData.isPacked,
      flatOffset: null
    };
    const slice = input.texData == null ? null : input.texData.slice;
    // Only a strictly positive flat offset is recorded in the shape info.
    if (slice != null && slice.flatOffset > 0) {
      shapeInfo.flatOffset = slice.flatOffset;
    }
    return { name: program.variableNames[index], shapeInfo };
  });
  const inShapeInfos = inputInfos.map(({ shapeInfo }) => shapeInfo);
  const outShapeInfo = {
    logicalShape: output.shape,
    texShape: output.texData.texShape,
    isUniform: false,
    isPacked: output.texData.isPacked,
    flatOffset: null
  };
  const source = makeShader(inputInfos, outShapeInfo, program.userCode, program.packedInputs);
  const webGLProgram = gpgpu.createProgram(source);
  const nanLoc = gpgpu.getUniformLocation(webGLProgram, "NAN", false);
  const infLoc = tf.env().getNumber("WEBGL_VERSION") === 1
    ? gpgpu.getUniformLocation(webGLProgram, "INFINITY", false)
    : null;
  const uniformLocations = {};
  for (const varName of program.variableNames) {
    const offsetName = "offset" + varName;
    uniformLocations[varName] = gpgpu.getUniformLocation(webGLProgram, varName, false);
    uniformLocations[offsetName] = gpgpu.getUniformLocation(webGLProgram, offsetName, false);
  }
  return {
    program,
    source,
    webGLProgram,
    uniformLocations,
    inShapeInfos,
    outShapeInfo,
    infLoc,
    nanLoc
  };
}
/**
 * Checks that the runtime inputs match what a compiled binary was built for.
 * The input count must match, and every input must have the same logical
 * shape. The texture shape must also match, unless both the compiled info
 * and the runtime input are uniforms.
 * @param {Array<Object>} shapeInfos - Shape infos recorded at compile time.
 * @param {Array<Object>} inputs - Runtime tensor infos.
 * @throws {Error} Describing the first mismatch found.
 */
function validateBinaryAndProgram(shapeInfos, inputs) {
  const compiledCount = shapeInfos.length;
  const runtimeCount = inputs.length;
  if (compiledCount !== runtimeCount) {
    throw new Error(`Binary was compiled with ${compiledCount} inputs, but was executed with ${runtimeCount} inputs`);
  }
  shapeInfos.forEach((compiled, index) => {
    const compiledShape = compiled.logicalShape;
    const input = inputs[index];
    const runtimeShape = input.shape;
    if (!tf.util.arraysEqual(compiledShape, runtimeShape)) {
      throw new Error(`Binary was compiled with different shapes than the current args. Shapes ${compiledShape} and ${runtimeShape} must match`);
    }
    const bothUniform = compiled.isUniform && input.isUniform;
    if (!bothUniform) {
      const compiledTexShape = compiled.texShape;
      const runtimeTexShape = input.isUniform ? null : input.texData.texShape;
      if (!tf.util.arraysEqual(compiledTexShape, runtimeTexShape)) {
        throw new Error(`Binary was compiled with different texture shapes than the current args. Shape ${compiledTexShape} and ${runtimeTexShape} must match`);
      }
    }
  });
}
/**
 * Binds the output texture, program, uniforms and input textures for a
 * compiled binary, then draws it. Steps:
 * - Validates the inputs and the output against the binary's shape infos.
 * - Sets INFINITY (WebGL 1 only) and NAN when those uniforms exist.
 * - Uploads uniform inputs with uniform1f for size < 2, else uniform1fv.
 * - Binds texture inputs to a texture unit equal to their input index, also
 *   setting the offset uniform when the input has a slice.
 * - Inputs whose uniform location is null/undefined are skipped.
 * @param {Object} gpgpu - GPGPUContext to draw with.
 * @param {Object} binary - Result of compileProgram().
 * @param {Array<Object>} inputs - Runtime input tensor infos.
 * @param {Object} output - Runtime output tensor info.
 * @param {Function} [customSetup] - Called as customSetup(gpgpu, webGLProgram)
 *     just before drawing.
 */
function runProgram(gpgpu, binary, inputs, output, customSetup) {
  validateBinaryAndProgram(binary.inShapeInfos, inputs);
  validateBinaryAndProgram([binary.outShapeInfo], [output]);
  const { texture: outTex, texShape: outTexShape, isPacked } = output.texData;
  if (isPacked) {
    gpgpu.setOutputPackedMatrixTexture(outTex, outTexShape[0], outTexShape[1]);
  } else {
    gpgpu.setOutputMatrixTexture(outTex, outTexShape[0], outTexShape[1]);
  }
  gpgpu.setProgram(binary.webGLProgram);
  const isWebGL1 = tf.env().getNumber("WEBGL_VERSION") === 1;
  if (isWebGL1 && binary.infLoc !== null) {
    gpgpu.gl.uniform1f(binary.infLoc, Infinity);
  }
  if (binary.nanLoc !== null) {
    gpgpu.gl.uniform1f(binary.nanLoc, NaN);
  }
  inputs.forEach((input, textureUnit) => {
    const varName = binary.program.variableNames[textureUnit];
    const varLoc = binary.uniformLocations[varName];
    const varOffsetLoc = binary.uniformLocations[`offset${varName}`];
    if (varLoc == null) {
      return;
    }
    if (!input.isUniform) {
      if (input.texData.slice != null && varOffsetLoc != null) {
        gpgpu.gl.uniform1i(varOffsetLoc, input.texData.slice.flatOffset);
      }
      gpgpu.setInputMatrixTexture(input.texData.texture, varLoc, textureUnit);
      return;
    }
    if (tf.util.sizeFromShape(input.shape) < 2) {
      gpgpu.gl.uniform1f(varLoc, input.uniformValues[0]);
      return;
    }
    const values = input.uniformValues instanceof Float32Array
      ? input.uniformValues
      : new Float32Array(input.uniformValues);
    gpgpu.gl.uniform1fv(varLoc, values);
  });
  if (customSetup != null) {
    customSetup(gpgpu, binary.webGLProgram);
  }
  gpgpu.executeProgram();
}
/**
 * Builds a cache key for a compiled shader.
 * The key is the program's constructor name, then one
 * "<shape>_<texShape|uniform>_<hasOffset>" segment per input followed by the
 * output, then the program's userCode. The segments are concatenated with
 * no separator between them.
 * @param {Object} program - Program instance; `constructor.name` and `userCode` are read.
 * @param {Array<Object>} inputs - Input tensor infos.
 * @param {Object} output - Output tensor info.
 * @returns {string} The cache key.
 */
function makeShaderKey(program, inputs, output) {
  const describe = (tensor) => {
    const slice = tensor.texData != null ? tensor.texData.slice : null;
    const hasOffset = slice != null && slice.flatOffset > 0;
    const texShape = tensor.isUniform ? "uniform" : tensor.texData.texShape;
    return `${tensor.shape}_${texShape}_${hasOffset}`;
  };
  const keyInputs = inputs.concat(output).map(describe).join("");
  return `${program.constructor.name}_${keyInputs}_${program.userCode}`;
}
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var Im2ColPackedProgram = function() {
/**
 * Packed shader program that gathers input values into an unrolled
 * (im2col-style) matrix for convolution.
 * - Output coordinate rc.x is used as `pos` and rc.y as `blockIndex`.
 * - Each output texel holds a 2x2 neighbourhood (row/col offsets 0..1),
 *   written to result[row * 2 + col].
 * - Values are read from input "A" via getA().
 * - Positions outside the input (padding) stay 0 because `result` starts as vec4(0).
 * @param {Array<number>} outputShape - Shape of the unrolled output; only
 *     indices 0 and 1 are read for bounds. Presumably [rows, cols] of the
 *     im2col matrix (TODO confirm against callers).
 * @param {Array<number>} inputShape - Input tensor shape. For channelsLast,
 *     indices 0/1 bound d0/d1; otherwise indices 1/2.
 * @param {Object} convInfo - Convolution parameters (filter width, channels,
 *     strides, padding, output width, dilations, data format).
 */
function Im2ColPackedProgram2(outputShape, inputShape, convInfo) {
this.variableNames = ["A"];
this.packedInputs = true;
this.packedOutput = true;
this.outputShape = outputShape;
var filterWidth = convInfo.filterWidth, inChannels = convInfo.inChannels, strideWidth = convInfo.strideWidth, strideHeight = convInfo.strideHeight, padInfo = convInfo.padInfo, outWidth = convInfo.outWidth, dilationWidth = convInfo.dilationWidth, dilationHeight = convInfo.dilationHeight, dataFormat = convInfo.dataFormat;
var left = padInfo.left, top = padInfo.top;
// Number of entries in one block row: every channel of every filter column.
var itemsPerBlockRow = inChannels * filterWidth;
var glsl = getGlslDifferences();
var isChannelsLast = dataFormat === "channelsLast";
// Input dimensions that bound d0/d1: [0,1] for channelsLast, [1,2] otherwise.
var rowDim = isChannelsLast ? 0 : 1;
var colDim = isChannelsLast ? 1 : 2;
var unrolled = "";
// Unroll the 2x2 packed texel in JS.
// `isChannelsLast` is interpolated as the literal `true`/`false` into the
// GLSL `if`, and the getA() argument order depends on it:
// (d0, d1, ch) for channelsLast, (ch, d0, d1) otherwise.
for (var row = 0; row <= 1; row++) {
for (var col = 0; col <= 1; col++) {
unrolled += "\n blockIndex = rc.y + " + col + ";\n pos = rc.x + " + row + ";\n\n if(blockIndex < " + outputShape[1] + " && pos < " + outputShape[0] + ") {\n offsetY = int(blockIndex / (" + outWidth + ")) * " + strideHeight + " - " + top + ";\n d0 = offsetY + " + dilationHeight + " * (pos / " + itemsPerBlockRow + ");\n\n if(d0 < " + inputShape[rowDim] + " && d0 >= 0) {\n\n offsetX = int(mod(float(blockIndex), " + outWidth + ".) * " + strideWidth + ". - " + left + ".);\n d1 = offsetX + " + dilationWidth + " * (int(mod(float(pos), " + itemsPerBlockRow + ".) / " + inChannels + ".));\n\n if(d1 < " + inputShape[colDim] + " && d1 >= 0) {\n\n ch = int(mod(float(pos), " + inChannels + ".));\n\n if (" + isChannelsLast + ") {\n innerDims = vec2(d1, ch);\n result[" + (row * 2 + col) + "] = getChannel(\n getA(d0, int(innerDims.x),\n int(innerDims.y)), innerDims);\n } else {\n innerDims = vec2(d0, d1);\n result[" + (row * 2 + col) + "] = getChannel(\n getA(ch, int(innerDims.x),\n int(innerDims.y)), innerDims);\n }\n }\n }\n }\n ";
}
}
this.userCode = "\n void main() {\n ivec2 rc = getOutputCoords();\n\n vec4 result = vec4(0);\n\n int blockIndex, pos, offsetY, d0, offsetX, d1, ch;\n vec2 innerDims;\n\n " + unrolled + "\n\n " + glsl.output + " = result;\n }\n ";
}
return Im2ColPackedProgram2;
}();
/**
* @license
* Copyright 2017 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var LRNProgram = function() {
  /**
   * Shader program for local response normalization along the last axis.
   * For each element x at depth d:
   *   sum = sum of z*z over depths d-radius .. d+radius, clamped to [0, maxD]
   *   out = x * (bias + alpha * sum)^(-beta)
   * beta === 0.5 uses inversesqrt(); beta === 1 uses a reciprocal; any other
   * beta uses exp(log(.) * -beta).
   * @param {Array<number>} xShape - Input shape; index 3 is the depth axis
   *     (the shader reads coords[0..3]).
   * @param {number} radius - Half-width of the normalization window.
   * @param {number} bias - Additive bias.
   * @param {number} alpha - Scale applied to the squared sum.
   * @param {number} beta - Exponent.
   */
  function LRNProgram2(xShape, radius, bias, alpha, beta) {
    this.variableNames = ["x"];
    this.outputShape = xShape;
    var rad = radius;
    var maxD = xShape[3] - 1;
    var powOperator;
    var basis = "float(" + bias + ") + float(" + alpha + ") * sum";
    if (beta === 0.5) {
      powOperator = "inversesqrt(" + basis + ")";
    } else if (beta === 1) {
      powOperator = "1.0/(" + basis + ")";
    } else {
      // Negate beta in JS instead of prefixing "-". For a negative beta,
      // "-" + beta would emit "--0.75", a GLSL decrement on a literal, which
      // fails to compile. For positive beta the emitted text is unchanged.
      // No trailing ";" here: the statement terminator is added in userCode.
      powOperator = "exp(log(" + basis + ") * float(" + (-beta) + "))";
    }
    this.userCode = "\n void main() {\n ivec4 coords = getOutputCoords();\n int b = coords[0];\n int r = coords[1];\n int c = coords[2];\n int d = coords[3];\n float x = getX(b, r, c, d);\n float sum = 0.0;\n for (int j = -" + rad + "; j <= " + rad + "; j++) {\n int idx = d + j;\n if (idx >= 0 && idx <= " + maxD + ") {\n float z = getX(b, r, c, idx);\n sum += z * z;\n }\n }\n float val = x * " + powOperator + ";\n setOutput(val);\n }\n ";
  }
  return LRNProgram2;
}();
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * WebGL program for the gradient of local response normalization (LRN)
 * with respect to its input.
 *
 * Reads three inputs (see variableNames): the forward input "inputImage",
 * the forward output "outputImage", and the upstream gradient "dy", all
 * addressed as (b, r, c, depth). For each output coordinate the shader loops
 * over every depth d, rebuilds
 *   norm = alpha * sum(inputImage^2 over [d - depthRadius, d + depthRadius]) + bias
 * (window clipped to [0, depth)), and accumulates into `result` the term for
 * k == coords[3], scaled by dy at depth d.
 */
var LRNGradProgram = function() {
/**
 * @param {number[]} inputShape - forward input shape; inputShape[3] is used as the depth extent.
 * @param {number} depthRadius - half-width of the normalization window.
 * @param {number} bias - additive term of the normalizer.
 * @param {number} alpha - scale applied to the sum of squares.
 * @param {number} beta - normalization exponent.
 */
function LRNGradProgram2(inputShape, depthRadius, bias, alpha, beta) {
this.variableNames = ["inputImage", "outputImage", "dy"];
this.outputShape = [];
this.outputShape = inputShape;
this.depth = inputShape[3];
this.depthRadius = depthRadius;
this.bias = bias;
this.alpha = alpha;
this.beta = beta;
// Both inner shader loops run over the constant range
// [MIN_DEPTH_BEGIN, MAX_DEPTH_END) and use continue/break to restrict work
// to [depthBegin, depthEnd) — presumably because GLSL ES 1.0 only accepts
// loop bounds that are compile-time constants.
this.userCode = "\n void main() {\n ivec4 coords = getOutputCoords();\n int b = coords[0];\n int r = coords[1];\n int c = coords[2];\n\n float result = 0.0;\n for (int d = 0; d < " + this.depth + "; ++d) {\n int depthBegin = int(max(0.0, float(d - " + depthRadius + ")));\n int depthEnd = int(min(float(" + this.depth + "),\n float(d + " + depthRadius + " + 1)));\n\n const int MIN_DEPTH_BEGIN = 0;\n const int MAX_DEPTH_END = " + this.depth + ";\n\n float norm = 0.0;\n for (int k = MIN_DEPTH_BEGIN; k < MAX_DEPTH_END; ++k) {\n if (k < depthBegin){\n continue;\n }\n else if (k >= depthBegin && k < depthEnd) {\n norm += getInputImage(b, r, c, k) * getInputImage(b, r, c, k);\n }\n else {\n break;\n }\n }\n\n norm = float(" + alpha + ") * norm + float(" + bias + ");\n\n for(int k = MIN_DEPTH_BEGIN; k < MAX_DEPTH_END; ++k){\n if (k < depthBegin){\n continue;\n }\n else if (k >= depthBegin && k < depthEnd){\n float dyi = -2.0 * float(" + alpha + ")\n * float(" + beta + ")\n * getInputImage(b ,r ,c, k) * getOutputImage(b, r, c, d)\n / norm;\n if (k == d) {\n dyi += pow(norm, -1.0 * " + beta + ");\n }\n if (k == coords[3]) {\n dyi *= getDy(b, r, c, d);\n result += dyi;\n }\n }\n else {\n break;\n }\n }\n }\n setOutput(result);\n }\n ";
}
return LRNGradProgram2;
}();
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Packed-texture variant of the LRN program: computes
 * x * (bias + alpha * sum)^(-beta) over the depth axis (xShape[3]), where
 * each output texel carries the values at (c, d), (c, d + 1), (c + 1, d) and
 * (c + 1, d + 1). A small `cache` vec2 carries squared values between
 * iterations of the depth window.
 */
var LRNPackedProgram = function() {
  /**
   * @param {number[]} xShape - input shape; xShape[3] is the depth extent.
   * @param {number} radius - half-width of the depth window.
   * @param {number} bias - additive term inside the power.
   * @param {number} alpha - scale applied to the sum of squares.
   * @param {number} beta - normalization exponent.
   */
  function LRNPackedProgram2(xShape, radius, bias, alpha, beta) {
    this.variableNames = ["x"];
    this.outputShape = [];
    this.packedInputs = true;
    this.packedOutput = true;
    const rad = radius;
    const maxD = xShape[3] - 1;
    this.outputShape = xShape;
    const basis = "float(" + bias + ") + float(" + alpha + ") * sum";
    let powOperator;
    if (beta === 0.5) {
      powOperator = "inversesqrt(" + basis + ")";
    } else if (beta === 1) {
      powOperator = "1.0/(" + basis + ")";
    } else {
      // Negate in JS instead of prepending "-" to the printed number: a
      // negative beta used to produce "float(--0.5)", which does not compile.
      // No trailing ";" here: the statement below already terminates it.
      powOperator = "exp(log(" + basis + ") * float(" + (-beta) + "))";
    }
    this.userCode = "\n void main() {\n ivec4 coords = getOutputCoords();\n int b = coords.x;\n int r = coords.y;\n int c = coords.z;\n int d = coords.w;\n\n bool hasNextCol = d < " + this.outputShape[3] + ";\n bool hasNextRow = c < " + this.outputShape[2] + ";\n\n vec4 sum = vec4(0.);\n vec4 xFragAtOutputCoords = getX(b, r, c, d);\n\n vec4 xAtOutputCoords = vec4(\n getChannel(xFragAtOutputCoords, vec2(c, d)),\n hasNextCol ?\n getChannel(xFragAtOutputCoords, vec2(c, d + 1)) : 0.0,\n hasNextRow ?\n getChannel(xFragAtOutputCoords , vec2(c + 1, d)) : 0.0,\n (hasNextRow && hasNextCol) ?\n getChannel(xFragAtOutputCoords, vec2(c + 1, d + 1)) : 0.0\n );\n\n int firstChannel = d - " + rad + ";\n vec2 cache = vec2(0.);\n if(firstChannel >= 0){\n vec4 firstChannelFrag = getX(b, r, c, firstChannel);\n cache.x = getChannel(firstChannelFrag, vec2(c, firstChannel));\n if(hasNextRow){\n cache.y = getChannel(firstChannelFrag, vec2(c + 1, firstChannel));\n }\n }\n\n ivec2 depth = ivec2(d, d + 1);\n for (int j = - " + rad + "; j <= " + rad + "; j++) {\n ivec2 idx = depth + j;\n bvec2 aboveLowerBound = greaterThanEqual(idx, ivec2(0));\n bvec2 belowUpperBound = lessThanEqual(idx, ivec2(" + maxD + "));\n\n bool depthInRange = aboveLowerBound.x && belowUpperBound.x;\n bool depthPlusOneInRange = aboveLowerBound.y && belowUpperBound.y;\n\n if(depthInRange || depthPlusOneInRange){\n vec4 z = vec4(0.);\n vec4 xFragAtCurrentDepth;\n z.xz = cache.xy;\n if(depthPlusOneInRange && hasNextCol){\n xFragAtCurrentDepth = idx.y != d ?\n getX(b, r, c, idx.y) : xFragAtOutputCoords;\n z.y = getChannel(xFragAtCurrentDepth, vec2(c, idx.y));\n if(hasNextRow){\n z.w = getChannel(xFragAtCurrentDepth, vec2(c + 1, idx.y));\n }\n }\n cache.xy = z.yw;\n sum += z * z;\n }\n }\n vec4 result = xAtOutputCoords * " + powOperator + ";\n setOutput(result);\n }\n ";
  }
  return LRNPackedProgram2;
}();
/**
* @license
* Copyright 2017 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * WebGL program for the gradient of 2-D max pooling with respect to its input.
 *
 * Inputs: "dy" (gradient of the pooled output) and "maxPos" (for every pooled
 * output element, the window position of the selected maximum). The output
 * has convInfo.inShape; each input element (b, xR, xC, d) sums every dy value
 * whose recorded max position maps back onto it, found by comparing the
 * mirrored stored position against wR * effectiveFilterWidth + wC.
 *
 * NOTE(review): the shader's column loop steps by 1 while the row loop steps
 * by dilationHeight, and convInfo.dilationWidth is never read here — confirm
 * this asymmetry is intended.
 */
var MaxPool2DBackpropProgram = function() {
/**
 * @param {object} convInfo - pooling geometry; reads inShape, strideHeight,
 *     strideWidth, dilationHeight, effectiveFilterHeight, effectiveFilterWidth,
 *     padInfo.top, padInfo.left, outHeight and outWidth.
 */
function MaxPool2DBackpropProgram2(convInfo) {
this.variableNames = ["dy", "maxPos"];
this.outputShape = convInfo.inShape;
var strideHeight = convInfo.strideHeight;
var strideWidth = convInfo.strideWidth;
var dilationHeight = convInfo.dilationHeight;
var effectiveFilterHeight = convInfo.effectiveFilterHeight;
var effectiveFilterWidth = convInfo.effectiveFilterWidth;
// The backward pass walks each window from the opposite corner, so the
// forward padding is mirrored: (effective filter size - 1 - forward pad).
var padTop = effectiveFilterHeight - 1 - convInfo.padInfo.top;
var padLeft = effectiveFilterWidth - 1 - convInfo.padInfo.left;
// Last window-local index; the shader compares (lastIndex - maxPos) against
// the current index to undo the reversed walk.
var lastIndex = effectiveFilterHeight * effectiveFilterWidth - 1;
this.userCode = "\n const ivec2 pads = ivec2(" + padTop + ", " + padLeft + ");\n\n void main() {\n ivec4 coords = getOutputCoords();\n int b = coords[0];\n int d = coords[3];\n\n ivec2 dyRCCorner = coords.yz - pads;\n int dyRCorner = dyRCCorner.x;\n int dyCCorner = dyRCCorner.y;\n\n // Convolve dy(?, ?, d) with pos mask(:, :, d) to get dx(xR, xC, d).\n // ? = to be determined. : = across all values in that axis.\n float dotProd = 0.0;\n for (int wR = 0; wR < " + effectiveFilterHeight + ";\n wR += " + dilationHeight + ") {\n float dyR = float(dyRCorner + wR) / " + strideHeight + ".0;\n\n if (dyR < 0.0 || dyR >= " + convInfo.outHeight + ".0 || fract(dyR) > 0.0) {\n continue;\n }\n int idyR = int(dyR);\n\n for (int wC = 0; wC < " + effectiveFilterWidth + "; wC++) {\n float dyC = float(dyCCorner + wC) / " + strideWidth + ".0;\n\n if (dyC < 0.0 || dyC >= " + convInfo.outWidth + ".0 ||\n fract(dyC) > 0.0) {\n continue;\n }\n int idyC = int(dyC);\n\n float dyValue = getDy(b, idyR, idyC, d);\n int maxPosValue = " + lastIndex + " - int(getMaxPos(b, idyR, idyC, d));\n\n // Get the current value, check it against the value from the\n // position matrix.\n int curPosValue = wR * " + effectiveFilterWidth + " + wC;\n float mask = float(maxPosValue == curPosValue ? 1.0 : 0.0);\n\n dotProd += dyValue * mask;\n }\n }\n setOutput(dotProd);\n }\n ";
}
return MaxPool2DBackpropProgram2;
}();
/**
 * WebGL program for the gradient of 3-D max pooling with respect to its input.
 *
 * Inputs: "dy" (gradient of the pooled output) and "maxPos" (for every pooled
 * output element, the window position of the selected maximum). The output
 * has convInfo.inShape and is addressed as (batch, xD, xR, xC, ch) via an
 * ivec5 (.x = batch, .y/.z/.w = depth/row/col, .u = channel). Each input
 * element sums every dy value whose mirrored max position equals the current
 * window index wD * H * W + wR * W + wC.
 */
var MaxPool3DBackpropProgram = function() {
/**
 * @param {object} convInfo - pooling geometry; reads inShape, the stride,
 *     dilation and effective filter sizes for depth/height/width,
 *     padInfo.front/top/left, and outDepth/outHeight/outWidth.
 */
function MaxPool3DBackpropProgram2(convInfo) {
this.variableNames = ["dy", "maxPos"];
this.outputShape = convInfo.inShape;
var strideDepth = convInfo.strideDepth;
var strideHeight = convInfo.strideHeight;
var strideWidth = convInfo.strideWidth;
var dilationDepth = convInfo.dilationDepth;
var dilationHeight = convInfo.dilationHeight;
var dilationWidth = convInfo.dilationWidth;
var effectiveFilterDepth = convInfo.effectiveFilterDepth;
var effectiveFilterHeight = convInfo.effectiveFilterHeight;
var effectiveFilterWidth = convInfo.effectiveFilterWidth;
// The backward pass walks each window from the opposite corner, so the
// forward padding is mirrored: (effective filter size - 1 - forward pad).
var padFront = effectiveFilterDepth - 1 - convInfo.padInfo.front;
var padTop = effectiveFilterHeight - 1 - convInfo.padInfo.top;
var padLeft = effectiveFilterWidth - 1 - convInfo.padInfo.left;
// Last window-local index; the shader compares (lastIndex - maxPos) against
// the current index to undo the reversed walk.
var lastIndex = effectiveFilterDepth * effectiveFilterHeight * effectiveFilterWidth - 1;
this.userCode = "\n const ivec3 pads = ivec3(" + padFront + ", " + padTop + ", " + padLeft + ");\n\n void main() {\n ivec5 coords = getOutputCoords();\n int batch = coords.x;\n int ch = coords.u;\n\n ivec3 dyCorner = ivec3(coords.y, coords.z, coords.w) - pads;\n int dyDCorner = dyCorner.x;\n int dyRCorner = dyCorner.y;\n int dyCCorner = dyCorner.z;\n\n // Convolve dy(?, ?, ?, ch) with pos mask(:, :, :, d) to get\n // dx(xD, xR, xC, ch).\n // ? = to be determined. : = across all values in that axis.\n float dotProd = 0.0;\n\n for (int wD = 0; wD < " + effectiveFilterDepth + ";\n wD += " + dilationDepth + ") {\n float dyD = float(dyDCorner + wD) / " + strideDepth + ".0;\n\n if (dyD < 0.0 || dyD >= " + convInfo.outDepth + ".0 || fract(dyD) > 0.0) {\n continue;\n }\n int idyD = int(dyD);\n\n for (int wR = 0; wR < " + effectiveFilterHeight + ";\n wR += " + dilationHeight + ") {\n float dyR = float(dyRCorner + wR) / " + strideHeight + ".0;\n\n if (dyR < 0.0 || dyR >= " + convInfo.outHeight + ".0 ||\n fract(dyR) > 0.0) {\n continue;\n }\n int idyR = int(dyR);\n\n for (int wC = 0; wC < " + effectiveFilterWidth + ";\n wC += " + dilationWidth + ") {\n float dyC = float(dyCCorner + wC) / " + strideWidth + ".0;\n\n if (dyC < 0.0 || dyC >= " + convInfo.outWidth + ".0 ||\n fract(dyC) > 0.0) {\n continue;\n }\n int idyC = int(dyC);\n\n float dyValue = getDy(batch, idyD, idyR, idyC, ch);\n int maxPosValue = " + lastIndex + " -\n int(getMaxPos(batch, idyD, idyR, idyC, ch));\n\n // Get the current value, check it against the value from the\n // position matrix.\n int curPosValue =\n wD * " + effectiveFilterHeight + " * " + effectiveFilterWidth + " +\n wR * " + effectiveFilterWidth + " + wC;\n float mask = float(maxPosValue == curPosValue ? 1.0 : 0.0);\n\n dotProd += dyValue * mask;\n }\n }\n }\n setOutput(dotProd);\n }\n ";
}
return MaxPool3DBackpropProgram2;
}();
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * WebGL program for batched matrix multiplication on packed (vec4) textures,
 * with an optional bias add and a fused activation.
 *
 * The shader's dot2x2ARowBCol accumulates one vec4 per output texel, walking
 * the shared dimension two elements per iteration (ceil(sharedDim / 2)
 * iterations) and summing two swizzled vec4 products each time. When the
 * batch sizes of A and B differ, the smaller one's batch index is clamped to
 * its last entry.
 */
var MatMulPackedProgram = function() {
  /**
   * @param {number[]} aShape - shape of A; aShape[0] is its batch size.
   * @param {number[]} bShape - shape of B; bShape[0] is its batch size.
   * @param {number[]} outputShape - shape of the result.
   * @param {boolean} [transposeA=false] - read A transposed.
   * @param {boolean} [transposeB=false] - read B transposed.
   * @param {boolean} [addBias=false] - add the "bias" input at the output coords.
   * @param {?string} [activation=null] - GLSL statements forming the activation body.
   * @param {boolean} [hasPreluActivation=false] - activation also reads "preluActivationWeights".
   */
  function MatMulPackedProgram2(aShape, bShape, outputShape, transposeA, transposeB, addBias, activation, hasPreluActivation) {
    // The optional flags are only ever tested for truthiness, so coercing
    // them here is equivalent to defaulting undefined to false.
    const transA = Boolean(transposeA);
    const transB = Boolean(transposeB);
    const withBias = Boolean(addBias);
    const withPrelu = Boolean(hasPreluActivation);

    const inputNames = ["matrixA", "matrixB"];
    if (withBias) {
      inputNames.push("bias");
    }
    if (withPrelu) {
      inputNames.push("preluActivationWeights");
    }
    this.variableNames = inputNames;
    this.packedInputs = true;
    this.packedOutput = true;
    this.outputShape = outputShape;

    const sharedDimensionPacked = Math.ceil((transA ? aShape[1] : aShape[2]) / 2);

    let aSample;
    let aSwizzle;
    if (transA) {
      aSample = "i * 2, rc.y";
      aSwizzle = ["a.xxyy", "a.zzww"];
    } else {
      aSample = "rc.y, i * 2";
      aSwizzle = ["a.xxzz", "a.yyww"];
    }
    let bSample;
    let bSwizzle;
    if (transB) {
      bSample = "rc.z, i * 2";
      bSwizzle = ["b.xzxz", "b.ywyw"];
    } else {
      bSample = "i * 2, rc.z";
      bSwizzle = ["b.xyxy", "b.zwzw"];
    }

    let activationSnippet = "";
    let applyActivationSnippet = "";
    if (activation) {
      const header = withPrelu
        ? "vec4 activation(vec4 a) {\n vec4 b = getPreluActivationWeightsAtOutCoords();\n "
        : "vec4 activation(vec4 x) {\n ";
      activationSnippet = header + activation + "\n }";
      applyActivationSnippet = "result = activation(result);";
    }
    const addBiasSnippet = withBias ? "result += getBiasAtOutCoords();" : "";

    // Broadcast the smaller batch by clamping its index to the last entry.
    const clampBatch = (size) => "int(min(float(rc.x), " + (size - 1) + ".))";
    const batchASnippet = aShape[0] < bShape[0] ? clampBatch(aShape[0]) : "rc.x";
    const batchBSnippet = bShape[0] < aShape[0] ? clampBatch(bShape[0]) : "rc.x";

    this.userCode =
      "\n " + activationSnippet +
      "\n\n const float sharedDimension = " + sharedDimensionPacked + ".0;\n\n" +
      " vec4 dot2x2ARowBCol(ivec3 rc) {\n" +
      " vec4 result = vec4(0);\n" +
      " for (int i = 0; i < " + sharedDimensionPacked + "; i++) {\n" +
      " int batchA = " + batchASnippet + ";\n" +
      " int batchB = " + batchBSnippet + ";\n" +
      " vec4 a = getMatrixA(batchA, " + aSample + ");\n" +
      " vec4 b = getMatrixB(batchB, " + bSample + ");\n\n" +
      " // These swizzled products need to be separately added.\n" +
      " // See: https://github.com/tensorflow/tfjs/issues/1735\n" +
      " result += (" + aSwizzle[0] + " * " + bSwizzle[0] + ");\n" +
      " result += (" + aSwizzle[1] + " * " + bSwizzle[1] + ");\n" +
      " }\n" +
      " return result;\n" +
      " }\n\n" +
      " void main() {\n" +
      " ivec3 rc = getOutputCoords();\n" +
      " vec4 result = dot2x2ARowBCol(rc);\n\n" +
      " " + addBiasSnippet + "\n\n" +
      " " + applyActivationSnippet + "\n\n" +
      " setOutput(result);\n" +
      " }\n" +
      " ";
  }
  return MatMulPackedProgram2;
}();
/**
* @license
* Copyright 2017 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * WebGL program drawing categorical samples from the "probs" input.
 *
 * The output has shape [batchSize, numSamples]. Each element draws
 * r = random(seed) and returns the first outcome index i whose running sum
 * of probs(batch, 0..i) exceeds r; if none of the first numOutcomes - 1
 * outcomes qualifies, the last outcome index is returned.
 */
var MultinomialProgram = function() {
  /**
   * @param {number} batchSize - number of probability rows.
   * @param {number} numOutcomes - number of categories per row.
   * @param {number} numSamples - samples drawn per row.
   */
  function MultinomialProgram2(batchSize, numOutcomes, numSamples) {
    const lastOutcome = numOutcomes - 1;
    this.variableNames = ["probs"];
    this.outputShape = [batchSize, numSamples];
    this.userCode =
      "\n uniform float seed;\n\n void main() {\n ivec2 coords = getOutputCoords();\n int batch = coords[0];\n\n" +
      " float r = random(seed);\n float cdf = 0.0;\n\n" +
      " for (int i = 0; i < " + lastOutcome + "; i++) {\n" +
      " cdf += getProbs(batch, i);\n\n" +
      " if (r < cdf) {\n setOutput(float(i));\n return;\n }\n }\n\n" +
      " // If no other event happened, last event happened.\n" +
      " setOutput(float(" + lastOutcome + "));\n }\n ";
  }

  /**
   * Returns a setup callback that uploads `seed` into the shader's "seed"
   * uniform. The uniform location is looked up on first use and cached on
   * the program instance as `seedLoc`.
   *
   * @param {number} seed - value written with uniform1f.
   * @returns {function(object, WebGLProgram): void}
   */
  MultinomialProgram2.prototype.getCustomSetupFunc = function(seed) {
    // Arrow function: `this` stays bound to the program instance.
    return (gpgpu, webGLProgram) => {
      if (this.seedLoc === null || this.seedLoc === undefined) {
        this.seedLoc = gpgpu.getUniformLocation(webGLProgram, "seed");
      }
      gpgpu.gl.uniform1f(this.seedLoc, seed);
    };
  };
  return MultinomialProgram2;
}();
/**
* @license
* Copyright 2017 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * WebGL program expanding integer indices into one-hot rows.
 *
 * The output has shape [numIndices, depth]; element (i, j) is onValue when
 * round(indices[i]) == j and offValue otherwise (selected with GLSL mix()).
 */
var OneHotProgram = function() {
  /**
   * @param {number} numIndices - number of indices (output rows).
   * @param {number} depth - one-hot width (output columns).
   * @param {number} onValue - value written at the hot position.
   * @param {number} offValue - value written everywhere else.
   */
  function OneHotProgram2(numIndices, depth, onValue, offValue) {
    this.variableNames = ["indices"];
    this.outputShape = [numIndices, depth];
    const mixArgs = "float(" + offValue + "), float(" + onValue + "),\n float(index == coords.y)";
    this.userCode =
      "\n void main() {\n ivec2 coords = getOutputCoords();\n" +
      " int index = round(getIndices(coords.x));\n" +
      " setOutput(mix(" + mixArgs + "));\n }\n ";
  }
  return OneHotProgram2;
}();
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * WebGL program converting an unpacked tensor "A" into a packed output
 * texture (packedInputs = false, packedOutput = true).
 *
 * A scalar (rank 0) output places the value in the first channel. For higher
 * ranks the shader writes vec4(0) when the output coordinate is out of bounds
 * and otherwise gathers the texel's values via the getSetup/getOutput
 * snippets.
 */
var PackProgram = function() {
  /**
   * @param {number[]} outputShape - logical shape of the tensor being packed.
   */
  function PackProgram2(outputShape) {
    this.variableNames = ["A"];
    this.packedInputs = false;
    this.packedOutput = true;
    this.outputShape = outputShape;
    const rank = outputShape.length;
    if (rank === 0) {
      this.userCode = "\n void main() {\n setOutput(vec4(getA(), 0., 0., 0.));\n }\n ";
      return;
    }
    const rc = getChannels("rc", rank);
    const coordsType = getCoordsDataType(rank);
    const outOfBounds = getOutOfBoundsCondition(rank, outputShape, rc);
    const innerSetup = getSetup(rank, outputShape[rank - 1], outputShape[rank - 2], rc);
    const texelValues = getOutput(outputShape, rc);
    this.userCode =
      "\n void main() {\n " + coordsType + " rc = getOutputCoords();\n\n" +
      " if(" + outOfBounds + ") {\n" +
      " setOutput(vec4(0));\n" +
      " } else {\n" +
      " " + innerSetup + "\n\n" +
      " setOutput(vec4(" + texelValues + "));\n" +
      " }\n" +
      " }\n" +
      " ";
  }
  return PackProgram2;
}();
/**
 * Builds the getA() argument lists for the four values of a packed texel:
 * (r, c), (r, c + 1), (r + 1, c), (r + 1, c + 1), in that order.
 *
 * For rank > 2 the outer coordinates dims[dims.length - rank] ..
 * dims[dims.length - 3] are prepended, comma-separated, in front of the
 * row/col pair.
 *
 * @param {number} rank - tensor rank.
 * @param {string[]} dims - GLSL expressions for each coordinate.
 * @returns {string[]} four argument-list strings.
 */
function getSourceCoordsArr(rank, dims) {
  // The outer (non row/col) coordinates are the same for all four values,
  // so build that prefix once.
  let prefix = "";
  for (let axis = 2; axis < rank; axis++) {
    prefix = dims[dims.length - 1 - axis] + "," + prefix;
  }
  const argLists = [];
  for (const rowExpr of ["r", "rp1"]) {
    for (const colExpr of ["c", "cp1"]) {
      argLists.push(prefix + rowExpr + ", " + colExpr);
    }
  }
  return argLists;
}
/**
 * Builds the GLSL condition that is true when a packed output coordinate
 * lies outside the tensor. Only the two innermost dimensions are checked for
 * rank >= 2; rank 1 checks the single coordinate `rc`.
 *
 * @param {number} rank - tensor rank.
 * @param {number[]} shape - logical tensor shape.
 * @param {string[]} dims - GLSL expressions for each coordinate.
 * @returns {string} a GLSL boolean expression.
 */
function getOutOfBoundsCondition(rank, shape, dims) {
  if (rank === 1) {
    // Valid indices are 0 .. shape[0] - 1, so shape[0] itself is already out
    // of bounds. Use >=, matching the rank >= 2 branch below; ">" let a
    // padding texel starting exactly at shape[0] read past the end.
    return "rc >= " + shape[0];
  }
  const conditions = [];
  for (let i = rank - 2; i < rank; i++) {
    conditions.push(dims[i] + " >= " + shape[i]);
  }
  return conditions.join("||");
}
/**
 * Emits GLSL that names the two innermost output coordinates (r, c), their
 * successors (rp1, cp1), and edge flags telling whether the successor column
 * or row falls outside the tensor. Rank 1 needs no setup.
 *
 * @param {number} rank - tensor rank.
 * @param {number} cols - size of the innermost dimension.
 * @param {number} rows - size of the second-innermost dimension.
 * @param {string[]} dims - GLSL expressions for each coordinate.
 * @returns {string} GLSL declarations, or "" for rank 1.
 */
function getSetup(rank, cols, rows, dims) {
  if (rank !== 1) {
    const [rowDim, colDim] = dims.slice(-2);
    return "\n int r = " + rowDim + ";\n int c = " + colDim + ";\n" +
      " int rp1 = r + 1;\n int cp1 = c + 1;\n\n" +
      " bool cEdge = cp1 >= " + cols + ";\n" +
      " bool rEdge = rp1 >= " + rows + ";\n ";
  }
  return "";
}
/**
 * Emits the comma-separated GLSL expressions for the four channels of a
 * packed output texel. Values whose row or column successor is past the
 * edge (cEdge / rEdge from getSetup) are replaced by 0.
 *
 * For rank 1 only two values are meaningful: getA(rc) and getA(rc + 1)
 * (0 past the end); the last two channels are 0.
 *
 * @param {number[]} shape - logical tensor shape.
 * @param {string[]} dims - GLSL expressions for each coordinate.
 * @returns {string} four GLSL expressions joined by ",\n".
 */
function getOutput(shape, dims) {
  const rank = shape.length;
  if (rank === 1) {
    return "getA(rc),\n rc + 1 >= " + shape[0] + " ? 0. : getA(rc + 1),\n 0, 0";
  }
  const [here, nextCol, nextRow, nextBoth] = getSourceCoordsArr(rank, dims);
  return "getA(" + here + "),\n" +
    " cEdge ? 0. : getA(" + nextCol + "),\n" +
    " rEdge ? 0. : getA(" + nextRow + "),\n" +
    " rEdge || cEdge ? 0. : getA(" + nextBoth + ")";
}
/**
* @license
* Copyright 2017 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * WebGL program padding an unpacked tensor with a constant value.
 *
 * outputShape[i] = paddings[i][0] + xShape[i] + paddings[i][1]. Output
 * coordinates in [start, end) (start = before-padding, end = start + xShape)
 * read x at (outC - start); every other coordinate gets constantValue.
 */
var PadProgram = function() {
  /**
   * @param {number[]} xShape - input shape.
   * @param {Array<number[]>} paddings - [before, after] amounts per axis.
   * @param {number} constantValue - fill value for the padded region.
   */
  function PadProgram2(xShape, paddings, constantValue) {
    this.variableNames = ["x"];
    this.outputShape = paddings.map((pad, axis) => pad[0] + xShape[axis] + pad[1]);
    const rank = xShape.length;
    const coordsType = getCoordsDataType(rank);
    const startList = paddings.map((pad) => pad[0]).join(",");
    const endList = paddings.map((pad, axis) => pad[0] + xShape[axis]).join(",");
    const fill = "setOutput(float(" + constantValue + "));";

    if (rank === 1) {
      this.userCode =
        "\n int start = " + startList + ";\n int end = " + endList + ";\n\n" +
        " void main() {\n int outC = getOutputCoords();\n" +
        " if (outC < start || outC >= end) {\n " + fill + "\n" +
        " } else {\n setOutput(getX(outC - start));\n }\n }\n ";
      return;
    }

    const coordArgs = ["coords[0]", "coords[1]", "coords[2]", "coords[3]"].slice(0, rank).join(",");
    this.userCode =
      "\n " + coordsType + " start = " + coordsType + "(" + startList + ");\n" +
      " " + coordsType + " end = " + coordsType + "(" + endList + ");\n\n" +
      " void main() {\n " + coordsType + " outC = getOutputCoords();\n" +
      " if (any(lessThan(outC, start)) || any(greaterThanEqual(outC, end))) {\n " + fill + "\n" +
      " } else {\n " + coordsType + " coords = outC - start;\n" +
      " setOutput(getX(" + coordArgs + "));\n }\n }\n ";
  }
  return PadProgram2;
}();
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * WebGL program padding a tensor with a constant value on packed inputs and
 * outputs (packedInputs = packedOutput = true).
 *
 * outputShape[i] = paddings[i][0] + xShape[i] + paddings[i][1]. Each output
 * texel fills up to four result components (two for rank 1): the value at
 * outputLoc, at outputLoc with the last coordinate + 1, at outputLoc with the
 * second-to-last coordinate + 1, and with both + 1. Components whose
 * coordinate lies outside [start, end) get constantValue; the rest are read
 * from x at (rc - start).
 */
var PadPackedProgram = function() {
/**
 * @param {number[]} xShape - input shape.
 * @param {Array<number[]>} paddings - [before, after] amounts per axis.
 * @param {number} constantValue - fill value for the padded region.
 */
function PadPackedProgram2(xShape, paddings, constantValue) {
this.variableNames = ["x"];
this.packedInputs = true;
this.packedOutput = true;
this.outputShape = paddings.map(function(p, i2) {
return p[0] + xShape[i2] + p[1];
});
var rank = xShape.length;
var dtype = getCoordsDataType(rank);
var start = paddings.map(function(p) {
return p[0];
}).join(",");
var end = paddings.map(function(p, i2) {
return p[0] + xShape[i2];
}).join(",");
// Per-axis GLSL accessors for the "rc" and "source" variables (getChannels is
// defined elsewhere; presumably yields e.g. "rc.x", "rc.y", ... — confirm).
var coords2 = getChannels("rc", rank);
var source = getChannels("source", rank);
var cLimit = coords2[rank - 1] + " < " + this.outputShape[rank - 1];
var innerDims = rank === 1 ? "source" : "vec2(" + source.slice(-2).join() + ")";
// Coordinate setup for result components 0..3. The strings intentionally
// leave `if (...) {` blocks open: component 1 opens one (next column in
// range), component 2 closes it and opens one for the next row, component 3
// opens one more for next row + next column. The open braces are closed by
// the "} " / "}}" appended after the loop below.
var componentSetup = [
dtype + " rc = outputLoc;",
coords2[rank - 1] + " += 1;\n if(" + cLimit + ") {\n ",
rank === 1 ? "" : "}\n rc = outputLoc;\n " + coords2[rank - 2] + " += 1;\n if(" + coords2[rank - 2] + " < " + this.outputShape[rank - 2] + ") {",
rank === 1 ? "" : " " + coords2[rank - 1] + " += 1;\n if(" + cLimit + ") {"
];
var paddingArea = rank === 1 ? "rc < start || rc >= end" : "any(lessThan(rc, start)) || any(greaterThanEqual(rc, end))";
var mainLoop = "";
// Rank 1 fills two components per texel; higher ranks fill four.
for (var i = 0, j = rank === 1 ? 2 : 4; i < j; i++) {
mainLoop += "\n " + componentSetup[i] + "\n if (" + paddingArea + ") {\n result[" + i + "] = float(" + constantValue + ");\n } else {\n " + dtype + " source = rc - start;\n result[" + i + "] = getChannel(getX(" + source.join() + "), " + innerDims + ");\n }\n ";
}
// Close the `if` blocks left open by componentSetup (one for rank 1, two otherwise).
mainLoop += rank === 1 ? "} " : "}}";
this.userCode = "\n const " + dtype + " start = " + dtype + "(" + start + ");\n const " + dtype + " end = " + dtype + "(" + end + ");\n\n void main() {\n " + dtype + " outputLoc = getOutputCoords();\n vec4 result = vec4(0.);\n " + mainLoop + "\n setOutput(result);\n }\n ";
}
return PadPackedProgram2;
}();
/**
* @license
* Copyright 2017 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * WebGL program for 2-D max/avg pooling of an unpacked tensor addressed as
 * (batch, row, col, channel).
 *
 * With `computePositions` the shader instead outputs, for each output element,
 * the position of the window maximum (max pooling only). The position is
 * flattened over the input (optionally including the batch) when
 * `flattenPositions` is set; otherwise it is the window-local index
 * wR * effectiveFilterWidth + wC.
 *
 * The value path reads each window row four columns at a time into a vec4,
 * then handles the 1-3 leftover filter columns in a tail branch.
 */
var Pool2DProgram = function() {
  /**
   * @param {object} convInfo - pooling geometry: filterWidth, strides,
   *     dilations, effective filter sizes, padInfo, in/out sizes, outShape.
   * @param {string} poolType - "avg" or "max" (used as the GLSL reducer name for non-avg).
   * @param {boolean} computePositions - emit argmax positions instead of values.
   * @param {boolean} [flattenPositions=false] - flatten positions over (xR, xC, d).
   * @param {boolean} [includeBatchInIndex=false] - also fold the batch into flattened positions.
   * @throws {Error} when positions are requested for average pooling.
   */
  function Pool2DProgram2(convInfo, poolType, computePositions, flattenPositions, includeBatchInIndex) {
    if (flattenPositions === void 0) {
      flattenPositions = false;
    }
    if (includeBatchInIndex === void 0) {
      includeBatchInIndex = false;
    }
    this.variableNames = ["x"];
    if (poolType === "avg" && computePositions) {
      throw new Error("Cannot compute positions for average pool.");
    }
    var filterWidth = convInfo.filterWidth;
    var strideHeight = convInfo.strideHeight;
    var strideWidth = convInfo.strideWidth;
    var dilationHeight = convInfo.dilationHeight;
    var dilationWidth = convInfo.dilationWidth;
    var effectiveFilterHeight = convInfo.effectiveFilterHeight;
    var effectiveFilterWidth = convInfo.effectiveFilterWidth;
    var padTop = convInfo.padInfo.top;
    var padLeft = convInfo.padInfo.left;
    this.outputShape = convInfo.outShape;
    var isAvgPool = poolType === "avg";
    var batchFlattenPositionStr = "((batch * " + convInfo.inHeight + " + xR) * " + convInfo.inWidth + " + xC) * " + convInfo.inChannels + " + d";
    var flattenPositionStr = "(xR * " + convInfo.inWidth + " + xC) * " + convInfo.inChannels + " + d";
    // -1.0 / 1e-20 evaluates to -1e20, so any real input wins the first max().
    var initializationValue = "0.0";
    if (!isAvgPool) {
      initializationValue = "-1.0 / 1e-20";
    }
    if (computePositions) {
      var compareOp_1 = ">=";
      this.userCode = "\n const ivec2 strides = ivec2(" + strideHeight + ", " + strideWidth + ");\n const ivec2 pads = ivec2(" + padTop + ", " + padLeft + ");\n\n void main() {\n ivec4 coords = getOutputCoords();\n int batch = coords[0];\n int d = coords[3];\n\n ivec2 xRCCorner = coords.yz * strides - pads;\n int xRCorner = xRCCorner.x;\n int xCCorner = xRCCorner.y;\n\n // max/min x(?, ?, d) to get y(yR, yC, d).\n // ? = to be determined\n float minMaxValue = 0.0;\n float minMaxValueFound = 0.0;\n int minMaxPosition = 0;\n float avgValue = 0.0;\n\n for (int wR = 0; wR < " + effectiveFilterHeight + ";\n wR += " + dilationHeight + ") {\n int xR = xRCorner + wR;\n\n if (xR < 0 || xR >= " + convInfo.inHeight + ") {\n continue;\n }\n\n for (int wC = 0; wC < " + effectiveFilterWidth + ";\n wC += " + dilationWidth + ") {\n int xC = xCCorner + wC;\n\n if (xC < 0 || xC >= " + convInfo.inWidth + ") {\n continue;\n }\n\n float value = getX(batch, xR, xC, d);\n\n // If a min / max value has already been found, use it. If not,\n // use the current value.\n float currMinMaxValue = mix(\n value, minMaxValue, minMaxValueFound);\n if (value " + compareOp_1 + " currMinMaxValue) {\n minMaxValue = value;\n minMaxValueFound = 1.0;\n minMaxPosition = " + (flattenPositions ? includeBatchInIndex ? batchFlattenPositionStr : flattenPositionStr : "wR * " + effectiveFilterWidth + " + wC") + ";\n }\n }\n }\n setOutput(float(minMaxPosition));\n }\n ";
      return;
    }
    var compareOp = "max";
    var returnValue = poolType + "(" + poolType + "(" + poolType + "(minMaxValue[0], minMaxValue[1]), minMaxValue[2]), minMaxValue[3])";
    if (poolType === "avg") {
      returnValue = "avgValue / count";
    }
    var filterWidthNearestVec4 = Math.floor(filterWidth / 4) * 4;
    var filterWidthVec4Remainder = filterWidth % 4;
    // The vec4 loop reads filter column k at xCCorner + k * dilationWidth, so
    // the tail (first leftover column k = filterWidthNearestVec4) must start
    // at the same scaled offset. It previously started at
    // xCCorner + filterWidthNearestVec4, which re-read pooled columns and
    // skipped the real tail whenever dilationWidth > 1.
    var remainderColumnOffset = filterWidthNearestVec4 * dilationWidth;
    var updateSnippet = "\n if (" + isAvgPool + ") {\n avgValue += dot(values, ones);\n } else {\n minMaxValue = " + compareOp + "(values, minMaxValue);\n }\n ";
    // In the shader, getValue() returns initializationValue for columns
    // outside [0, inWidth) and only increments `count` for in-range reads,
    // so the avg divisor excludes padding.
    this.userCode = "\n const ivec2 strides = ivec2(" + strideHeight + ", " + strideWidth + ");\n const ivec2 pads = ivec2(" + padTop + ", " + padLeft + ");\n const float initializationValue = " + initializationValue + ";\n const vec4 ones = vec4(1.0, 1.0, 1.0, 1.0);\n\n float count = 0.0;\n\n float getValue(int batch, int xR, int xC, int d) {\n if (xC < 0 || xC >= " + convInfo.inWidth + ") {\n return initializationValue;\n }\n count += 1.0;\n return getX(batch, xR, xC, d);\n }\n\n void main() {\n ivec4 coords = getOutputCoords();\n int batch = coords[0];\n int d = coords[3];\n\n ivec2 xRCCorner = coords.yz * strides - pads;\n int xRCorner = xRCCorner.x;\n int xCCorner = xRCCorner.y;\n\n // max/min x(?, ?, d) to get y(yR, yC, d).\n // ? = to be determined\n vec4 minMaxValue = vec4(" + initializationValue + ");\n float avgValue = 0.0;\n count = 0.0;\n\n for (int wR = 0; wR < " + effectiveFilterHeight + ";\n wR += " + dilationHeight + ") {\n int xR = xRCorner + wR;\n\n if (xR < 0 || xR >= " + convInfo.inHeight + ") {\n continue;\n }\n\n for (int wC = 0; wC < " + filterWidthNearestVec4 + "; wC += 4) {\n int xC = xCCorner + wC * " + dilationWidth + ";\n\n vec4 values = vec4(\n getValue(batch, xR, xC, d),\n getValue(batch, xR, xC + " + dilationWidth + ", d),\n getValue(batch, xR, xC + 2 * " + dilationWidth + ", d),\n getValue(batch, xR, xC + 3 * " + dilationWidth + ", d)\n );\n\n " + updateSnippet + "\n }\n\n int xC = xCCorner + " + remainderColumnOffset + ";\n if (" + (filterWidthVec4Remainder === 1) + ") {\n vec4 values = vec4(\n getValue(batch, xR, xC, d),\n initializationValue,\n initializationValue,\n initializationValue\n );\n\n " + updateSnippet + "\n } else if (" + (filterWidthVec4Remainder === 2) + ") {\n vec4 values = vec4(\n getValue(batch, xR, xC, d),\n getValue(batch, xR, xC + " + dilationWidth + ", d),\n initializationValue,\n initializationValue\n );\n\n " + updateSnippet + "\n } else if (" + (filterWidthVec4Remainder === 3) + ") {\n vec4 values = vec4(\n getValue(batch, xR, xC, d),\n getValue(batch, xR, xC + " + dilationWidth + ", d),\n getValue(batch, xR, xC + 2 * " + dilationWidth + ", d),\n initializationValue\n );\n\n " + updateSnippet + "\n }\n }\n setOutput(" + returnValue + ");\n }\n ";
  }
  return Pool2DProgram2;
}();
var Pool3DProgram = function() {
  /**
   * WebGL program that pools a 5-D input over its depth, height and width
   * axes. The shader reads the input as getX(batch, xD, xR, xC, ch) and the
   * output coordinates as ivec5 (x = batch, y/z/w = depth/row/col, u = ch).
   *
   * @param {Object} convInfo - pooling geometry: filterWidth, stride*,
   *   dilation*, effectiveFilter*, padInfo {front, top, left}, inDepth,
   *   inHeight, inWidth, inChannels and outShape.
   * @param {string} poolType - "avg" averages the in-bounds window values.
   *   Any other value scans the window with max() and combines the four vec4
   *   lanes with a GLSL function named `poolType`, so only "max" is coherent.
   * @param {boolean} computePositions - emit the position of the maximum
   *   instead of the pooled value.
   * @param {boolean} [flattenPositions=false] - with computePositions, emit a
   *   flat index into the input instead of an index within the window.
   * @param {boolean} [includeBatchInIndex=false] - with flattenPositions, also
   *   fold the batch index into the flat index.
   * @throws {Error} when poolType is "avg" and computePositions is set.
   */
  function Pool3DProgram2(convInfo, poolType, computePositions, flattenPositions = false, includeBatchInIndex = false) {
    this.variableNames = ["x"];
    if (poolType === "avg" && computePositions) {
      throw new Error("Cannot compute positions for average pool.");
    }
    const {
      filterWidth,
      strideDepth,
      strideHeight,
      strideWidth,
      dilationDepth,
      dilationHeight,
      dilationWidth,
      effectiveFilterDepth,
      effectiveFilterHeight,
      effectiveFilterWidth,
      inDepth,
      inHeight,
      inWidth,
      inChannels
    } = convInfo;
    const { front: padFront, top: padTop, left: padLeft } = convInfo.padInfo;
    this.outputShape = convInfo.outShape;
    const isAvgPool = poolType === "avg";
    // -1.0 / 1e-20 evaluates to -1e20: a large negative finite starting value
    // for max pooling. Averaging starts from zero.
    const initializationValue = isAvgPool ? "0.0" : "-1.0 / 1e-20";

    if (computePositions) {
      // "avg" was rejected above. ">=" tracks the maximum; on ties the later
      // element's position replaces the earlier one.
      const compareOp = ">=";
      let positionExpr;
      if (!flattenPositions) {
        // Index of the winning tap inside the (dilated) filter window.
        positionExpr = `wD * ${effectiveFilterHeight} * ${effectiveFilterWidth} +
                  wR * ${effectiveFilterWidth} + wC`;
      } else if (includeBatchInIndex) {
        positionExpr = `(((batch * ${inDepth} + xD) * ${inHeight} + xR) * ${inWidth} + xC) * ${inChannels} + ch`;
      } else {
        positionExpr = `((xD * ${inHeight} + xR) * ${inWidth} + xC) * ${inChannels} + ch`;
      }
      this.userCode = `
      const ivec3 strides =
          ivec3(${strideDepth}, ${strideHeight}, ${strideWidth});
      const ivec3 pads = ivec3(${padFront}, ${padTop}, ${padLeft});

      void main() {
        ivec5 coords = getOutputCoords();
        int batch = coords.x;
        int ch = coords.u;

        ivec3 xCorner = ivec3(coords.y, coords.z, coords.w) * strides - pads;
        int xDCorner = xCorner.x;
        int xRCorner = xCorner.y;
        int xCCorner = xCorner.z;

        // max/min x(?, ?, ?, ch) to get y(yD, yR, yC, ch).
        // ? = to be determined
        float minMaxValue = 0.0;
        float minMaxValueFound = 0.0;
        int minMaxPosition = 0;

        for (int wD = 0; wD < ${effectiveFilterDepth};
            wD += ${dilationDepth}) {
          int xD = xDCorner + wD;

          if (xD < 0 || xD >= ${inDepth}) {
            continue;
          }

          for (int wR = 0; wR < ${effectiveFilterHeight};
              wR += ${dilationHeight}) {
            int xR = xRCorner + wR;

            if (xR < 0 || xR >= ${inHeight}) {
              continue;
            }

            for (int wC = 0; wC < ${effectiveFilterWidth};
                wC += ${dilationWidth}) {
              int xC = xCCorner + wC;

              if (xC < 0 || xC >= ${inWidth}) {
                continue;
              }

              float value = getX(batch, xD, xR, xC, ch);

              // If a min / max value has already been found, use it. If not,
              // use the current value.
              float currMinMaxValue = mix(
                  value, minMaxValue, minMaxValueFound);
              if (value ${compareOp} currMinMaxValue) {
                minMaxValue = value;
                minMaxValueFound = 1.0;
                minMaxPosition = ${positionExpr};
              }
            }
          }
        }
        setOutput(float(minMaxPosition));
      }
    `;
      return;
    }

    const returnValue = isAvgPool
      ? "avgValue / count"
      : `${poolType}(${poolType}(${poolType}(minMaxValue[0], minMaxValue[1]), minMaxValue[2]), minMaxValue[3])`;
    // The width axis is consumed four taps at a time; leftover taps (0-3 of
    // them) are handled after the vectorized loop.
    const filterWidthNearestVec4 = Math.floor(filterWidth / 4) * 4;
    const filterWidthVec4Remainder = filterWidth % 4;
    // Column offset of the first leftover tap. Taps are dilationWidth columns
    // apart, exactly as in the vectorized loop (xC = xCCorner + wC * dilation).
    const remainderColOffset = filterWidthNearestVec4 * dilationWidth;
    const updateSnippet = `
              if (${isAvgPool}) {
                avgValue += dot(values, ones);
              } else {
                minMaxValue = max(values, minMaxValue);
              }
            `;
    this.userCode = `
      const ivec3 strides =
          ivec3(${strideDepth}, ${strideHeight}, ${strideWidth});
      const ivec3 pads = ivec3(${padFront}, ${padTop}, ${padLeft});
      const float initializationValue = ${initializationValue};
      const vec4 ones = vec4(1.0, 1.0, 1.0, 1.0);

      float count = 0.0;

      float getValue(int batch, int xD, int xR, int xC, int ch) {
        if (xC < 0 || xC >= ${inWidth}) {
          return initializationValue;
        }
        count += 1.0;
        return getX(batch, xD, xR, xC, ch);
      }

      void main() {
        ivec5 coords = getOutputCoords();
        int batch = coords.x;
        int ch = coords.u;

        ivec3 xCorner = ivec3(coords.y, coords.z, coords.w) * strides - pads;
        int xDCorner = xCorner.x;
        int xRCorner = xCorner.y;
        int xCCorner = xCorner.z;

        // max/min x(?, ?, ?, d) to get y(yD, yR, yC, ch).
        // ? = to be determined
        vec4 minMaxValue = vec4(${initializationValue});
        float avgValue = 0.0;
        count = 0.0;

        for (int wD = 0; wD < ${effectiveFilterDepth};
            wD += ${dilationDepth}) {
          int xD = xDCorner + wD;

          if (xD < 0 || xD >= ${inDepth}) {
            continue;
          }

          for (int wR = 0; wR < ${effectiveFilterHeight};
              wR += ${dilationHeight}) {
            int xR = xRCorner + wR;

            if (xR < 0 || xR >= ${inHeight}) {
              continue;
            }

            for (int wC = 0; wC < ${filterWidthNearestVec4}; wC += 4) {
              int xC = xCCorner + wC * ${dilationWidth};

              vec4 values = vec4(
                getValue(batch, xD, xR, xC, ch),
                getValue(batch, xD, xR, xC + ${dilationWidth}, ch),
                getValue(batch, xD, xR, xC + 2 * ${dilationWidth}, ch),
                getValue(batch, xD, xR, xC + 3 * ${dilationWidth}, ch)
              );

              ${updateSnippet}
            }

            int xC = xCCorner + ${remainderColOffset};
            if (${filterWidthVec4Remainder === 1}) {
              vec4 values = vec4(
                getValue(batch, xD, xR, xC, ch),
                initializationValue,
                initializationValue,
                initializationValue
              );

              ${updateSnippet}
            } else if (${filterWidthVec4Remainder === 2}) {
              vec4 values = vec4(
                getValue(batch, xD, xR, xC, ch),
                getValue(batch, xD, xR, xC + ${dilationWidth}, ch),
                initializationValue,
                initializationValue
              );

              ${updateSnippet}
            } else if (${filterWidthVec4Remainder === 3}) {
              vec4 values = vec4(
                getValue(batch, xD, xR, xC, ch),
                getValue(batch, xD, xR, xC + ${dilationWidth}, ch),
                getValue(batch, xD, xR, xC + 2 * ${dilationWidth}, ch),
                initializationValue
              );

              ${updateSnippet}
            }
          }
        }
        // Written once, after every depth/row/column step has been folded in.
        setOutput(${returnValue});
      }
    `;
  }
  return Pool3DProgram2;
}();
/**
* @license
* Copyright 2017 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var ReduceProgram = function() {
  /**
   * WebGL program that reduces a [batchSize, inSize] input window by window.
   * Output element (batch, outIdx) combines the `windowSize` inputs starting
   * at outIdx * windowSize, read four at a time as a vec4 (bvec4 for
   * "all" / "any").
   *
   * Unused lanes of the final partial group are filled with the type's
   * initialization value. Reads at or past inSize are filled the same way;
   * that bounds check is only emitted when inSize is not a multiple of
   * windowSize.
   *
   * @param {{windowSize: number, batchSize: number, inSize: number,
   *   outSize: number}} reduceInfo - reduction geometry.
   * @param {string} reduceType - "sum", "prod", "min", "max", "all" or "any".
   *   Any other value is folded with a GLSL function of that name.
   */
  function ReduceProgram2(reduceInfo, reduceType) {
    this.variableNames = ["x"];
    const { windowSize, batchSize, inSize, outSize } = reduceInfo;
    this.outputShape = [batchSize, outSize];

    // Per-type starting value, vec4 combiner, result variable and lane type.
    let initializationValue = "0.0";
    let compareOp = "";
    let returnValue = null;
    let vecType = "vec4";
    switch (reduceType) {
      case "sum":
        returnValue = "sumValue";
        break;
      case "prod":
        initializationValue = "1.0";
        returnValue = "prodValue";
        break;
      case "min":
        // Evaluates to 1e20: a large finite starting value for min.
        initializationValue = "1.0 / 1e-20";
        compareOp = "min";
        break;
      case "max":
        initializationValue = "-1.0 / 1e-20";
        compareOp = "max";
        break;
      case "all":
        initializationValue = "1.0";
        returnValue = "allValue";
        vecType = "bvec4";
        break;
      case "any":
        returnValue = "anyValue";
        vecType = "bvec4";
        break;
      default:
        break;
    }
    if (returnValue === null) {
      // min / max (and unrecognized types): fold the four accumulator lanes.
      returnValue = reduceType + "(" + reduceType + "(" + reduceType + "(minMaxValue[0], minMaxValue[1]), minMaxValue[2]), minMaxValue[3])";
    }

    let updateSnippet;
    if (reduceType === "all") {
      updateSnippet = "\n bool reducedAllValue = all(values);\n float floatedReducedAllValue = float(reducedAllValue);\n allValue = float(allValue >= 1.0 && floatedReducedAllValue >= 1.0);\n ";
    } else if (reduceType === "any") {
      updateSnippet = "\n bool reducedAnyValue = any(values);\n float floatedReducedAnyValue = float(reducedAnyValue);\n anyValue = float(anyValue >= 1.0 || floatedReducedAnyValue >= 1.0);\n ";
    } else {
      updateSnippet = "\n if (" + (reduceType === "sum") + ") {\n sumValue += dot(values, ones);\n } else if (" + (reduceType === "prod") + ") {\n vec2 tmp = vec2(values[0], values[1]) * vec2(values[2], values[3]);\n prodValue *= tmp[0] * tmp[1];\n } else {\n minMaxValue = " + compareOp + "(values, minMaxValue);\n }\n ";
    }

    const windowSizeNearestVec4 = Math.floor(windowSize / 4) * 4;
    const windowSizeVec4Remainder = windowSize % 4;
    const checkOutOfBounds = inSize % windowSize > 0
      ? "\n if (inIdx < 0 || inIdx >= " + inSize + ") {\n return initializationValue;\n }\n "
      : "";
    this.userCode = "\n const float initializationValue = " + initializationValue +
      ";\n const vec4 ones = vec4(1.0, 1.0, 1.0, 1.0);\n\n float getValue(int batch, int inIdx) {\n " + checkOutOfBounds +
      "\n return getX(batch, inIdx);\n }\n\n void main() {\n ivec2 coords = getOutputCoords();\n int batch = coords[0];\n int outIdx = coords[1];\n int inOffset = outIdx * " + windowSize +
      ";\n\n vec4 minMaxValue = vec4(" + initializationValue +
      ");\n float prodValue = 1.0;\n float sumValue = 0.0;\n float allValue = 1.0;\n float anyValue = 0.0;\n\n for (int i = 0; i < " + windowSizeNearestVec4 +
      "; i += 4) {\n int inIdx = inOffset + i;\n " + vecType + " values = " + vecType +
      "(\n getValue(batch, inIdx),\n getValue(batch, inIdx + 1),\n getValue(batch, inIdx + 2),\n getValue(batch, inIdx + 3)\n );\n\n " + updateSnippet +
      "\n }\n\n int inIdx = inOffset + " + windowSizeNearestVec4 +
      ";\n if (" + (windowSizeVec4Remainder === 1) +
      ") {\n " + vecType + " values = " + vecType +
      "(\n getValue(batch, inIdx),\n initializationValue,\n initializationValue,\n initializationValue\n );\n\n " + updateSnippet +
      "\n } else if (" + (windowSizeVec4Remainder === 2) +
      ") {\n " + vecType + " values = " + vecType +
      "(\n getValue(batch, inIdx),\n getValue(batch, inIdx + 1),\n initializationValue,\n initializationValue\n );\n\n " + updateSnippet +
      "\n } else if (" + (windowSizeVec4Remainder === 3) +
      ") {\n " + vecType + " values = " + vecType +
      "(\n getValue(batch, inIdx),\n getValue(batch, inIdx + 1),\n getValue(batch, inIdx + 2),\n initializationValue\n );\n\n " + updateSnippet +
      "\n }\n setOutput(" + returnValue + ");\n }\n ";
  }
  return ReduceProgram2;
}();
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var ReshapePackedProgram = function() {
  /**
   * WebGL program that reshapes a packed texture into a packed 3-D output.
   * Each output texel covers the 2x2 block rc, rc + (0,0,1), rc + (0,1,0) and
   * rc + (0,1,1) (result lanes 0..3). For each in-bounds element it computes
   * the element's flat index in the output and reads the input element with
   * the same flat index. Lanes past `rows` / `cols` are left at 0.
   *
   * @param {number[]} outputShape - output shape; entries [1] and [2] are the
   *   row and column counts used for the bounds checks.
   * @param {number[]} inputShape - shape handed to getReshapedInputCoords.
   */
  function ReshapePackedProgram2(outputShape, inputShape) {
    this.variableNames = ["A"];
    this.packedInputs = true;
    this.packedOutput = true;
    this.outputShape = outputShape;
    const mainLoop = [0, 1, 2, 3].map((lane) => {
      const shiftsColumn = lane % 2 === 1;
      const shiftsRow = lane > 1;
      const thisRC = "thisRC = rc;" + (shiftsColumn ? "thisRC.z += 1;" : "") + (shiftsRow ? "thisRC.y += 1;" : "");
      // Lane 0 is always in bounds; the others may fall off the edge.
      const guardOpen = lane > 0 ? "if(thisRC.y < rows && thisRC.z < cols){" : "";
      const guardClose = lane > 0 ? "}" : "";
      return "\n " + thisRC + "\n " + guardOpen + "\n int flatIndex = getFlatIndex(thisRC);\n\n ivec3 inputRC = inputCoordsFromReshapedOutCoords(flatIndex);\n vec2 inputRCInnerDims = vec2(float(inputRC.y),float(inputRC.z));\n\n result[" + lane + "] =\n getChannel(getA(inputRC.x, inputRC.y, inputRC.z), inputRCInnerDims);\n " + guardClose + "\n ";
    }).join("");
    this.userCode = "\n " + getReshapedInputCoords(inputShape) + "\n " + getFlatIndexFrom3D(outputShape) + "\n\n void main() {\n ivec3 rc = getOutputCoords();\n\n vec4 result = vec4(0.);\n\n ivec3 thisRC;\n int rows = " + outputShape[1] + ";\n int cols = " + outputShape[2] + ";\n\n " + mainLoop + "\n\n setOutput(result);\n }\n ";
  }
  return ReshapePackedProgram2;
}();
/**
 * Emits the GLSL helper `ivec3 inputCoordsFromReshapedOutCoords(int index)`.
 * The helper's body is the snippet from getLogicalCoordinatesFromFlatIndex
 * for names r, c, d, followed by `return ivec3(r, c, d);`. That snippet is
 * expected to declare r, c and d from `index` (not visible here).
 *
 * @param {number[]} shape - shape passed through to the snippet generator.
 * @returns {string} GLSL source for the helper.
 */
function getReshapedInputCoords(shape) {
  const body = getLogicalCoordinatesFromFlatIndex(["r", "c", "d"], shape);
  return `\n ivec3 inputCoordsFromReshapedOutCoords(int index) {\n ${body}\n return ivec3(r, c, d);\n }\n `;
}
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var ResizeBilinearBackpropProgram = function() {
  /**
   * WebGL program for the gradient of bilinear resize with respect to its
   * input. The output has x.shape. Output element (b, r, c, d) accumulates
   * every dy element whose bilinear sample used input pixel (r, c) as one of
   * its four corners, weighted by that corner's interpolation weight.
   *
   * @param dy - upstream gradient; dy.shape[1] / dy.shape[2] are read as its
   *   height / width.
   * @param x - forward-pass input; x.shape[1] / x.shape[2] are read as its
   *   height / width, and x.shape becomes the output shape.
   * @param alignCorners - when truthy, corner pixels of x and dy coincide.
   *   Sizes are then measured as (size - 1) on axes where dy has more than
   *   one pixel.
   */
  function ResizeBilinearBackpropProgram2(dy, x, alignCorners) {
    this.variableNames = ["dy"];
    this.outputShape = x.shape;
    const xHeight = x.shape[1];
    const xWidth = x.shape[2];
    const yHeight = dy.shape[1];
    const yWidth = dy.shape[2];
    const alignRows = alignCorners && yHeight > 1;
    const alignCols = alignCorners && yWidth > 1;
    const effectiveXSize = [alignRows ? xHeight - 1 : xHeight, alignCols ? xWidth - 1 : xWidth];
    const effectiveYSize = [alignRows ? yHeight - 1 : yHeight, alignCols ? yWidth - 1 : yWidth];
    // x pixels per dy pixel, and the inverse (dy pixels per x pixel).
    const heightScale = effectiveXSize[0] / effectiveYSize[0];
    const widthScale = effectiveXSize[1] / effectiveYSize[1];
    const invHeightScale = 1 / heightScale;
    const invWidthScale = 1 / widthScale;
    // Extent of the dy window scanned around each x pixel; the shader skips
    // window positions that fall outside dy.
    const winHeight = Math.ceil(invHeightScale) * 2 + 2;
    const winWidth = Math.ceil(invWidthScale) * 2 + 2;
    // Template literal: the shader spans many lines, which a double-quoted
    // string literal cannot do.
    this.userCode = `
      void main() {
        ivec4 coords = getOutputCoords();
        int b = coords[0];
        int d = coords[3];
        int r = coords[1];
        int c = coords[2];

        float accumulator = 0.0;

        const float heightScale = float(${heightScale});
        const float widthScale = float(${widthScale});

        const float invHeightScale = float(${invHeightScale});
        const float invWidthScale = float(${invWidthScale});

        const int winHeight = int(${winHeight});
        const int winWidth = int(${winWidth});

        // Compute bounds for where in dy we will look
        float startRLerp = floor(float(r) * invHeightScale);
        int startDyR = int(startRLerp - float(winHeight / 2));

        float startCLerp = floor(float(c) * invWidthScale);
        int startDyC = int(startCLerp - float(winWidth / 2));

        // Loop over dy
        for (int dyROffset = 0; dyROffset < winHeight; dyROffset++) {
          int dyR = dyROffset + startDyR;

          // Guard against the window exceeding the bounds of dy
          if (dyR < 0 || dyR >= ${yHeight}) {
            continue;
          }

          for (int dyCOffset = 0; dyCOffset < winWidth; dyCOffset++) {
            int dyC = dyCOffset + startDyC;

            // Guard against the window exceeding the bounds of dy
            if (dyC < 0 || dyC >= ${yWidth}) {
              continue;
            }

            float dxR = float(dyR) * heightScale;
            int topDxRIndex = int(floor(dxR));
            int bottomDxRIndex = int(min(ceil(dxR), ${xHeight - 1}.0));
            float dxRLerp = dxR - float(topDxRIndex);
            float inverseDxRLerp = 1.0 - dxRLerp;

            float dxC = float(dyC) * widthScale;
            int leftDxCIndex = int(floor(dxC));
            int rightDxCIndex = int(min(ceil(dxC), ${xWidth - 1}.0));
            float dxCLerp = dxC - float(leftDxCIndex);
            float inverseDxCLerp = 1.0 - dxCLerp;

            if (r == topDxRIndex && c == leftDxCIndex) {
              // topLeft
              accumulator +=
                getDy(b, dyR, dyC, d) * inverseDxRLerp * inverseDxCLerp;
            }

            if (r == topDxRIndex && c == rightDxCIndex) {
              // topRight
              accumulator += getDy(b, dyR, dyC, d) * inverseDxRLerp * dxCLerp;
            }

            if (r == bottomDxRIndex && c == leftDxCIndex) {
              // bottomLeft
              accumulator += getDy(b, dyR, dyC, d) * dxRLerp * inverseDxCLerp;
            }

            if (r == bottomDxRIndex && c == rightDxCIndex) {
              // bottomRight
              accumulator += getDy(b, dyR, dyC, d) * dxRLerp * dxCLerp;
            }
          }
        }
        // End loop over dy

        setOutput(accumulator);
      }
    `;
  }
  return ResizeBilinearBackpropProgram2;
}();
/**
* @license
* Copyright 2017 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var ResizeBilinearProgram = function() {
  /**
   * WebGL program for bilinear resize of a 4-D input read as
   * (batch, height, width, depth) to (batch, newHeight, newWidth, depth).
   * Each output pixel is mapped to a fractional source position and
   * interpolated from the four surrounding source pixels. The ceil index is
   * clamped to the last source row / column.
   *
   * @param {number[]} inputShape - [batch, oldHeight, oldWidth, depth].
   * @param {number} newHeight - output height.
   * @param {number} newWidth - output width.
   * @param alignCorners - when truthy, corner pixels of input and output
   *   coincide on axes whose output has more than one pixel.
   */
  function ResizeBilinearProgram2(inputShape, newHeight, newWidth, alignCorners) {
    this.variableNames = ["A"];
    const [batch, oldHeight, oldWidth, depth] = inputShape;
    this.outputShape = [batch, newHeight, newWidth, depth];
    // With aligned corners both sides are measured as (size - 1).
    const alignRows = alignCorners && newHeight > 1;
    const alignCols = alignCorners && newWidth > 1;
    const effectiveInSize = [alignRows ? oldHeight - 1 : oldHeight, alignCols ? oldWidth - 1 : oldWidth];
    const effectiveOutSize = [alignRows ? newHeight - 1 : newHeight, alignCols ? newWidth - 1 : newWidth];
    this.userCode = "\n const vec2 effectiveInputOverOutputRatioRC = vec2(\n " + effectiveInSize[0] / effectiveOutSize[0] + ",\n " + effectiveInSize[1] / effectiveOutSize[1] + ");\n const vec2 inputShapeRC = vec2(" + oldHeight + ".0, " + oldWidth + ".0);\n\n void main() {\n ivec4 coords = getOutputCoords();\n int b = coords[0];\n int d = coords[3];\n ivec2 yRC = coords.yz;\n\n // Fractional source index.\n vec2 sourceFracIndexRC = vec2(yRC) * effectiveInputOverOutputRatioRC;\n\n // Compute the four integer indices.\n ivec2 sourceFloorRC = ivec2(sourceFracIndexRC);\n ivec2 sourceCeilRC = ivec2(\n min(inputShapeRC - 1.0, ceil(sourceFracIndexRC)));\n\n float topLeft = getA(b, sourceFloorRC.x, sourceFloorRC.y, d);\n float bottomLeft = getA(b, sourceCeilRC.x, sourceFloorRC.y, d);\n float topRight = getA(b, sourceFloorRC.x, sourceCeilRC.y, d);\n float bottomRight = getA(b, sourceCeilRC.x, sourceCeilRC.y, d);\n\n vec2 fracRC = sourceFracIndexRC - vec2(sourceFloorRC);\n\n float top = topLeft + (topRight - topLeft) * fracRC.y;\n float bottom = bottomLeft + (bottomRight - bottomLeft) * fracRC.y;\n float newValue = top + (bottom - top) * fracRC.x;\n\n setOutput(newValue);\n }\n ";
  }
  return ResizeBilinearProgram2;
}();
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var ResizeBilinearPackedProgram = function() {
  /**
   * Packed-texture variant of bilinear resize for a 4-D input read as
   * (batch, height, width, depth). Each output texel holds a 2x2 cell: depth
   * channels d and d + 1 at output columns c and c + 1. The shader builds the
   * next-column coordinate with `coords.yzz + ivec3(0, 0, 1)`.
   *
   * In the shader, `hasNextCol` guards d + 1 < depth and `hasNextRow` guards
   * c + 1 < newWidth. Lanes failing those checks are 0.
   *
   * @param {number[]} inputShape - [batch, oldHeight, oldWidth, depth].
   * @param {number} newHeight - output height.
   * @param {number} newWidth - output width.
   * @param alignCorners - when truthy, corner pixels of input and output
   *   coincide on axes whose output has more than one pixel.
   */
  function ResizeBilinearPackedProgram2(inputShape, newHeight, newWidth, alignCorners) {
    this.variableNames = ["A"];
    this.packedInputs = true;
    this.packedOutput = true;
    const [batch, oldHeight, oldWidth, depth] = inputShape;
    this.outputShape = [batch, newHeight, newWidth, depth];
    const alignRows = alignCorners && newHeight > 1;
    const alignCols = alignCorners && newWidth > 1;
    const heightRatio = (alignRows ? oldHeight - 1 : oldHeight) / (alignRows ? newHeight - 1 : newHeight);
    const widthRatio = (alignCols ? oldWidth - 1 : oldWidth) / (alignCols ? newWidth - 1 : newWidth);
    // Template literal: the shader spans many lines, which a double-quoted
    // string literal cannot do. The width ratio and width bound appear twice
    // because the third vec3 lane tracks the next output column.
    this.userCode = `
      const vec3 effectiveInputOverOutputRatioRC = vec3(
          ${heightRatio},
          ${widthRatio},
          ${widthRatio});
      const vec3 inputShapeRC = vec3(${oldHeight}.0, ${oldWidth}.0,
                                     ${oldWidth}.0);

      float getAValue(int b, int r, int c, int d) {
        return getChannel(getA(b, r, c, d), vec2(c, d));
      }

      void main() {
        ivec4 coords = getOutputCoords();
        int b = coords[0];
        int d = coords[3];
        // Calculate values for next column in yRC.z.
        ivec3 yRC = coords.yzz + ivec3(0, 0, 1);

        // Fractional source index.
        vec3 sourceFracIndexRC = vec3(yRC) * effectiveInputOverOutputRatioRC;

        // Compute the four integer indices.
        ivec3 sourceFloorRC = ivec3(sourceFracIndexRC);
        ivec3 sourceCeilRC = ivec3(
          min(inputShapeRC - 1.0, ceil(sourceFracIndexRC)));

        // Should we calculate next column and row elements in 2x2 packed cell.
        bool hasNextCol = d < ${depth - 1};
        bool hasNextRow = coords.z < ${newWidth - 1};

        // In parallel, construct four corners for all four components in
        // packed 2x2 cell.
        vec4 topLeft = vec4(
          getAValue(b, sourceFloorRC.x, sourceFloorRC.y, d),
          hasNextCol ? getAValue(b, sourceFloorRC.x, sourceFloorRC.y, d + 1)
                     : 0.0,
          hasNextRow ? getAValue(b, sourceFloorRC.x, sourceFloorRC.z, d)
                     : 0.0,
          (hasNextRow && hasNextCol) ?
            getAValue(b, sourceFloorRC.x, sourceFloorRC.z, d + 1) : 0.0);

        vec4 bottomLeft = vec4(
          getAValue(b, sourceCeilRC.x, sourceFloorRC.y, d),
          hasNextCol ? getAValue(b, sourceCeilRC.x, sourceFloorRC.y, d + 1)
                     : 0.0,
          hasNextRow ? getAValue(b, sourceCeilRC.x, sourceFloorRC.z, d)
                     : 0.0,
          (hasNextRow && hasNextCol) ?
            getAValue(b, sourceCeilRC.x, sourceFloorRC.z, d + 1) : 0.0);

        vec4 topRight = vec4(
          getAValue(b, sourceFloorRC.x, sourceCeilRC.y, d),
          hasNextCol ? getAValue(b, sourceFloorRC.x, sourceCeilRC.y, d + 1)
                     : 0.0,
          hasNextRow ? getAValue(b, sourceFloorRC.x, sourceCeilRC.z, d)
                     : 0.0,
          (hasNextRow && hasNextCol) ?
            getAValue(b, sourceFloorRC.x, sourceCeilRC.z, d + 1) : 0.0);

        vec4 bottomRight = vec4(
          getAValue(b, sourceCeilRC.x, sourceCeilRC.y, d),
          hasNextCol ? getAValue(b, sourceCeilRC.x, sourceCeilRC.y, d + 1)
                     : 0.0,
          hasNextRow ? getAValue(b, sourceCeilRC.x, sourceCeilRC.z, d)
                     : 0.0,
          (hasNextRow && hasNextCol) ?
            getAValue(b, sourceCeilRC.x, sourceCeilRC.z, d + 1) : 0.0);

        vec3 fracRC = sourceFracIndexRC - vec3(sourceFloorRC);

        vec4 top = mix(topLeft, topRight, fracRC.yyzz);
        vec4 bottom = mix(bottomLeft, bottomRight, fracRC.yyzz);
        vec4 newValue = mix(top, bottom, fracRC.x);

        setOutput(newValue);
      }
    `;
  }
  return ResizeBilinearPackedProgram2;
}();
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var ResizeNearestNeigborBackpropProgram = function() {
  /**
   * WebGL program for the gradient of nearest-neighbor resize with respect
   * to its input. The output has x.shape. Output element (b, r, c, d) sums
   * dy[b, dyR, dyC, d] over the dy positions, within a window around
   * (r, c) scaled into dy space, whose nearest source pixel is (r, c). The
   * nearest pixel is round() with alignCorners, floor() otherwise, and is
   * clamped to the last row / column.
   *
   * @param dy - upstream gradient; dy.shape[1] / dy.shape[2] are read as its
   *   height / width.
   * @param x - forward-pass input; x.shape[1] / x.shape[2] are read as its
   *   height / width, and x.shape becomes the output shape.
   * @param alignCorners - interpolated into the shader as a GLSL boolean.
   */
  function ResizeNearestNeigborBackpropProgram2(dy, x, alignCorners) {
    this.variableNames = ["dy"];
    this.outputShape = x.shape;
    const [, xHeight, xWidth] = x.shape;
    const [, yHeight, yWidth] = dy.shape;
    // Measured size of one axis: with alignCorners, (size - 1) whenever dy
    // has more than one pixel along that axis.
    const measure = (size, dySize) => (alignCorners && dySize > 1 ? size - 1 : size);
    const effectiveXSize = [measure(xHeight, yHeight), measure(xWidth, yWidth)];
    const effectiveYSize = [measure(yHeight, yHeight), measure(yWidth, yWidth)];
    // x pixels per dy pixel, and the inverse (dy pixels per x pixel).
    const heightScale = effectiveXSize[0] / effectiveYSize[0];
    const widthScale = effectiveXSize[1] / effectiveYSize[1];
    const invHeightScale = 1 / heightScale;
    const invWidthScale = 1 / widthScale;
    // Extent of the dy window scanned around each x pixel.
    const winHeight = Math.ceil(invHeightScale) * 2 + 2;
    const winWidth = Math.ceil(invWidthScale) * 2 + 2;
    this.userCode = "\n void main() {\n ivec4 coords = getOutputCoords();\n int b = coords[0];\n int d = coords[3];\n int r = coords[1];\n int c = coords[2];\n\n float accumulator = 0.0;\n\n const float heightScale = float(" + heightScale + ");\n const float widthScale = float(" + widthScale + ");\n\n const float invHeightScale = float(" + invHeightScale + ");\n const float invWidthScale = float(" + invWidthScale + ");\n\n const int winHeight = int(" + winHeight + ");\n const int winWidth = int(" + winWidth + ");\n\n // Compute bounds for where in dy we will look\n float startRLerp = floor(float(r) * invHeightScale);\n int startDyR = int(floor(startRLerp - float(winHeight / 2)));\n\n float startCLerp = floor(float(c) * invWidthScale);\n int startDyC = int(floor(startCLerp - float(winWidth / 2)));\n\n // Loop over dy\n for (int dyROffset = 0; dyROffset < winHeight; dyROffset++) {\n int dyR = dyROffset + startDyR;\n\n // Guard against the window exceeding the bounds of dy\n if (dyR < 0 || dyR >= " + yHeight + ") {\n continue;\n }\n\n for (int dyCOffset = 0; dyCOffset < winWidth; dyCOffset++) {\n int dyC = dyCOffset + startDyC;\n\n // Guard against the window exceeding the bounds of dy\n if (dyC < 0 || dyC >= " + yWidth + ") {\n continue;\n }\n\n float sourceFracRow =\n float(" + effectiveXSize[0] + ") *\n (float(dyR) / float(" + effectiveYSize[0] + "));\n\n float sourceFracCol =\n float(" + effectiveXSize[1] + ") *\n (float(dyC) / float(" + effectiveYSize[1] + "));\n\n int sourceNearestRow = int(min(\n float(int(" + xHeight + ") - 1),\n " + alignCorners + " ? float(round(sourceFracRow)) :\n float(floor(sourceFracRow))));\n\n int sourceNearestCol = int(min(\n float(int(" + xWidth + ") - 1),\n " + alignCorners + " ? float(round(sourceFracCol)) :\n float(floor(sourceFracCol))));\n\n if (r == sourceNearestRow && c == sourceNearestCol) {\n accumulator += getDy(b, dyR, dyC, d);\n }\n }\n }\n // End loop over dy\n\n setOutput(accumulator);\n }\n ";
  }
  return ResizeNearestNeigborBackpropProgram2;
}();
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var ResizeNearestNeighborProgram = function() {
  /**
   * WebGL program for nearest-neighbor resize of a 4-D input read as
   * (batch, height, width, depth) to (batch, newHeight, newWidth, depth).
   * Each output pixel reads the source pixel at
   * floor(fractionalSourceIndex + roundBase), clamped to the last row /
   * column. roundBase is 0.5 with alignCorners (round to nearest) and 0.0
   * otherwise (floor).
   *
   * @param {number[]} inputShape - [batch, oldHeight, oldWidth, depth].
   * @param {number} newHeight - output height.
   * @param {number} newWidth - output width.
   * @param alignCorners - when truthy, corner pixels of input and output
   *   coincide on axes whose output has more than one pixel.
   */
  function ResizeNearestNeighborProgram2(inputShape, newHeight, newWidth, alignCorners) {
    this.variableNames = ["A"];
    const [batch, oldHeight, oldWidth, depth] = inputShape;
    this.outputShape = [batch, newHeight, newWidth, depth];
    const alignRows = alignCorners && newHeight > 1;
    const alignCols = alignCorners && newWidth > 1;
    const effectiveInSize = [alignRows ? oldHeight - 1 : oldHeight, alignCols ? oldWidth - 1 : oldWidth];
    const effectiveOutSize = [alignRows ? newHeight - 1 : newHeight, alignCols ? newWidth - 1 : newWidth];
    const roundBase = alignCorners ? "0.5" : "0.0";
    this.userCode = "\n const vec2 effectiveInputOverOutputRatioRC = vec2(\n " + effectiveInSize[0] / effectiveOutSize[0] + ",\n " + effectiveInSize[1] / effectiveOutSize[1] + ");\n const vec2 inputShapeRC = vec2(" + oldHeight + ".0, " + oldWidth + ".0);\n\n void main() {\n ivec4 coords = getOutputCoords();\n int b = coords[0];\n int d = coords[3];\n ivec2 yRC = coords.yz;\n\n // Fractional source index.\n vec2 sourceFracIndexRC = vec2(yRC) * effectiveInputOverOutputRatioRC;\n\n // Compute the coordinators of nearest neighbor point.\n ivec2 sourceNearestRC = ivec2(\n min(inputShapeRC - 1.0, floor(sourceFracIndexRC + " + roundBase + ")));\n\n float newValue = getA(b, sourceNearestRC.x, sourceNearestRC.y, d);\n\n setOutput(newValue);\n }\n ";
  }
  return ResizeNearestNeighborProgram2;
}();
/**
* @license
* Copyright 2017 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var ReverseProgram = function() {
  /**
   * WebGL program that reverses a tensor of rank <= 4 along the given axes.
   * Output coordinate i on a reversed axis of size n reads input coordinate
   * n - i - 1. Axes of size 1 and non-reversed axes read straight through.
   *
   * @param {number[]} xShape - input shape, also used as the output shape.
   * @param {number[]} axis - axes to reverse (ignored for rank 1, which is
   *   always reversed).
   * @throws {Error} when the rank exceeds 4.
   */
  function ReverseProgram2(xShape, axis) {
    this.variableNames = ["x"];
    const rank = xShape.length;
    if (rank > 4) {
      throw new Error("WebGL backend: Reverse of rank-" + rank + " tensor is not yet supported");
    }
    this.outputShape = xShape;
    if (rank === 1) {
      this.userCode = "\n void main() {\n int coord = getOutputCoords();\n setOutput(getX(" + xShape[0] + " - coord - 1));\n }\n ";
      return;
    }
    const inCoords = xShape
      .map((size, dim) => {
        const mirrored = axis.indexOf(dim) !== -1 && size !== 1;
        return mirrored ? size + " - coords[" + dim + "] - 1" : "coords[" + dim + "]";
      })
      .join(",");
    const type = getCoordsDataType(rank);
    this.userCode = "\n void main() {\n " + type + " coords = getOutputCoords();\n setOutput(getX(" + inCoords + "));\n }\n ";
  }
  return ReverseProgram2;
}();
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var ReversePackedProgram = function() {
  /**
   * Packed-texture variant of the reverse program, for rank <= 4.
   * For rank 1, each texel fills lanes r and g from mirrored positions
   * rc and rc + 1. Otherwise each texel fills r / g / b / a from the element
   * at rc, one column over, one row down, and one row down and one column
   * over. Lanes past the output edge are left at 0. An axis is mirrored when
   * it is listed in `axis` and has size > 1.
   *
   * @param {number[]} xShape - input shape, also used as the output shape.
   * @param {number[]} axis - axes to reverse (rank 1 is always reversed).
   * @throws {Error} when the rank exceeds 4.
   */
  function ReversePackedProgram2(xShape, axis) {
    this.variableNames = ["x"];
    this.packedInputs = true;
    this.packedOutput = true;
    const rank = xShape.length;
    if (rank > 4) {
      throw new Error("WebGL backend: Reverse of rank-" + rank + " tensor is not yet supported");
    }
    this.outputShape = xShape;
    const channels = getChannels("rc", rank);
    const colDim = rank - 1;
    const rowDim = rank - 2;
    // GLSL conditions: does the texel's next column / next row exist?
    const nextColumn = channels[colDim] + " + 1 < " + this.outputShape[colDim];
    const nextRow = channels[rowDim] + " + 1 < " + this.outputShape[rowDim];
    const type = getCoordsDataType(rank);

    if (rank === 1) {
      const size = xShape[0];
      this.userCode = "\n void main(){\n int rc = getOutputCoords();\n vec4 result = vec4(0.);\n result.r = getChannel(getX(" + size + " - rc - 1),\n " + size + " - rc - 1);\n if(" + nextColumn + "){\n result.g = getChannel(getX(" + size + " - (rc + 1) - 1),\n " + size + " - (rc + 1) - 1);\n }\n setOutput(result);\n }\n ";
      return;
    }

    // Input coordinate expression for `dim`, mirrored on reversed axes.
    const sourceCoord = (dim, names) =>
      axis.indexOf(dim) !== -1 && xShape[dim] !== 1
        ? xShape[dim] + " - " + names[dim] + " - 1"
        : "" + names[dim];
    // getChannel(...) expression for the lane whose coordinates are
    // `channels` with every dimension in `bumped` advanced by one.
    const lane = (bumped) => {
      const names = channels.slice();
      for (const dim of bumped) {
        names[dim] = "(" + names[dim] + " + 1)";
      }
      const coords = xShape.map((_, dim) => sourceCoord(dim, names));
      return "getChannel(getX(" + coords.join(",") + "), vec2(" + coords.slice(-2).join(",") + "))";
    };
    this.userCode = "\n void main() {\n " + type + " rc = getOutputCoords();\n vec4 result = vec4(0.);\n result.r = " + lane([]) + ";\n if(" + nextColumn + "){\n result.g = " + lane([colDim]) + ";\n }\n if(" + nextRow + ") {\n result.b = " + lane([rowDim]) + ";\n if(" + nextColumn + ") {\n result.a = " + lane([colDim, rowDim]) + ";\n }\n }\n setOutput(result);\n }\n ";
  }
  return ReversePackedProgram2;
}();
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Unpacked WebGL scatter program. For every output element the shader:
 *   1. loops over all `updateSize` update rows;
 *   2. flattens each row's index (sliceDim components read from `indices`)
 *      with `strides`;
 *   3. sums the updates whose flattened index equals the element's first
 *      output coordinate (coords[0]).
 * Elements that no update hits take `defaultValue`. Duplicate indices are
 * always summed.
 */
var ScatterProgram = function() {
/**
 * @param updateSize number of update rows iterated by the shader loop.
 * @param sliceDim number of index components per update row (inner loop bound).
 * @param indicesRank rank of `indices`. Only 1 and 2 produce arguments for
 *     getIndices(); NOTE(review): confirm callers never pass other ranks.
 * @param updatesRank rank of `updates`. Only 1 and 2 produce arguments for
 *     getUpdates().
 * @param strides strides used to flatten an index row; inlined as a constant.
 * @param shape output shape.
 * @param summingDupeIndex not read by this constructor.
 */
function ScatterProgram2(updateSize, sliceDim, indicesRank, updatesRank, strides, shape, summingDupeIndex) {
this.variableNames = ["updates", "indices", "defaultValue"];
this.outputShape = shape;
var stridesType = getCoordsDataType(strides.length);
var dtype = getCoordsDataType(shape.length);
// GLSL arguments for getIndices(): the update row `i`, plus the index
// component `j` when indices has rank 2.
var indicesString = "";
if (indicesRank === 1) {
indicesString = "i";
} else if (indicesRank === 2) {
indicesString = "i, j";
}
var indicesSnippet = "getIndices(" + indicesString + ")";
// GLSL arguments for getUpdates(): the update row `i`, plus the output
// column coords[1] when updates has rank 2.
var updatesString = "";
if (updatesRank === 1) {
updatesString = "i";
} else if (updatesRank === 2) {
updatesString = "i, coords[1]";
}
var updatesSnippet = "getUpdates(" + updatesString + ")";
// With a single index component, `strides` is a scalar in GLSL.
var strideString = sliceDim > 1 ? "strides[j]" : "strides";
// `strides` is stringified into a GLSL constructor (presumably an array,
// which joins with commas). The shader relies on an int-returning round()
// helper, presumably provided by the shader prelude; verify there.
this.userCode = "\n " + stridesType + " strides = " + stridesType + "(" + strides + ");\n\n void main() {\n " + dtype + " coords = getOutputCoords();\n float sum = 0.0;\n bool found = false;\n for (int i = 0; i < " + updateSize + "; i++) {\n int flattenedIndex = 0;\n for (int j = 0; j < " + sliceDim + "; j++) {\n int index = round(" + indicesSnippet + ");\n flattenedIndex += index * " + strideString + ";\n }\n if (flattenedIndex == coords[0]) {\n sum += " + updatesSnippet + ";\n found = true;\n }\n }\n setOutput(mix(getDefaultValue(), sum, float(found)));\n }\n ";
}
return ScatterProgram2;
}();
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Unpacked WebGL program that computes per-segment sums over windows of the
 * input.
 *
 * Output shape is [batchSize, numSegments * ceil(inSize / windowSize)].
 * Output column outIdx covers the input window starting at
 * floor(outIdx / numSegments) * windowSize. It accumulates the values whose
 * segment id equals outIdx mod numSegments.
 *
 * Values are processed four at a time: a vec4 of values is dotted with a 0/1
 * segment mask. The trailing windowSize % 4 values are handled by one of three
 * tail branches, whose conditions are baked into the GLSL as `true`/`false`
 * literals.
 */
var SegmentOpProgram = function() {
/**
 * @param segOpInfo object providing windowSize, batchSize, inSize and numSegments.
 * @param segOpType not read by this constructor; the generated shader always sums.
 */
function SegmentOpProgram2(segOpInfo, segOpType) {
this.variableNames = ["x", "segmentIds"];
var windowSize = segOpInfo.windowSize;
var batchSize = segOpInfo.batchSize;
var inSize = segOpInfo.inSize;
var numSegments = segOpInfo.numSegments;
var outSize = numSegments * Math.ceil(inSize / windowSize);
this.outputShape = [batchSize, outSize];
var initializationValue = "0.0";
var returnValue = "sumValue";
// Largest multiple of 4 not exceeding windowSize, and the leftover count.
var windowSizeNearestVec4 = Math.floor(windowSize / 4) * 4;
var windowSizeVec4Remainder = windowSize % 4;
var updateSnippet = "\n sumValue += dot(values, segFilter);\n ";
// When inSize is not a multiple of windowSize the last window overruns the
// input. Out-of-range reads then yield initializationValue for values, and
// -1.0 for segment ids, which matches no segment (assuming ids are
// non-negative; confirm with callers).
var checkValueOutOfBounds = "";
if (inSize % windowSize > 0) {
checkValueOutOfBounds = "\n if (inIdx < 0 || inIdx >= " + inSize + ") {\n return initializationValue;\n }\n ";
}
var checkSegmentIdOutOfBounds = "";
if (inSize % windowSize > 0) {
checkSegmentIdOutOfBounds = "\n if (inIdx < 0 || inIdx >= " + inSize + ") {\n return -1.0;\n }\n ";
}
this.userCode = "\n const float initializationValue = " + initializationValue + ";\n\n float getValue(int batch, int inIdx) {\n " + checkValueOutOfBounds + "\n return getX(batch, inIdx);\n }\n\n float getSegmentIdAtIndex(int inIdx) {\n " + checkSegmentIdOutOfBounds + "\n return getSegmentIds(inIdx);\n }\n\n void main() {\n ivec2 coords = getOutputCoords();\n int batch = coords[0];\n int outIdx = coords[1];\n int inOffset = int(floor(float(outIdx) / float(\n " + numSegments + ")) * float(" + windowSize + "));\n int currentSeg = int(mod(float(outIdx), float(" + numSegments + ")));\n\n float sumValue = 0.0;\n\n for (int i = 0; i < " + windowSizeNearestVec4 + "; i += 4) {\n int inIdx = inOffset + i;\n vec4 values = vec4(\n getValue(batch, inIdx),\n getValue(batch, inIdx + 1),\n getValue(batch, inIdx + 2),\n getValue(batch, inIdx + 3)\n );\n\n vec4 segFilter = vec4(\n int(getSegmentIdAtIndex(inIdx)) == currentSeg ? 1 : 0,\n int(getSegmentIdAtIndex(inIdx + 1)) == currentSeg ? 1 : 0,\n int(getSegmentIdAtIndex(inIdx + 2)) == currentSeg ? 1 : 0,\n int(getSegmentIdAtIndex(inIdx + 3)) == currentSeg ? 1 : 0\n );\n\n " + updateSnippet + "\n }\n\n int inIdx = inOffset + " + windowSizeNearestVec4 + ";\n if (" + (windowSizeVec4Remainder === 1) + ") {\n vec4 values = vec4(\n getValue(batch, inIdx),\n initializationValue,\n initializationValue,\n initializationValue\n );\n\n int inIdxSeg = int(getSegmentIdAtIndex(inIdx));\n\n vec4 segFilter = vec4(\n int(getSegmentIdAtIndex(inIdx)) == currentSeg ? 1 : 0,\n 0,\n 0,\n 0\n );\n\n " + updateSnippet + "\n } else if (" + (windowSizeVec4Remainder === 2) + ") {\n vec4 values = vec4(\n getValue(batch, inIdx),\n getValue(batch, inIdx + 1),\n initializationValue,\n initializationValue\n );\n\n vec4 segFilter = vec4(\n int(getSegmentIdAtIndex(inIdx)) == currentSeg ? 1 : 0,\n int(getSegmentIdAtIndex(inIdx + 1)) == currentSeg ? 
1 : 0,\n 0,\n 0\n );\n\n " + updateSnippet + "\n } else if (" + (windowSizeVec4Remainder === 3) + ") {\n vec4 values = vec4(\n getValue(batch, inIdx),\n getValue(batch, inIdx + 1),\n getValue(batch, inIdx + 2),\n initializationValue\n );\n\n vec4 segFilter = vec4(\n int(getSegmentIdAtIndex(inIdx)) == currentSeg ? 1 : 0,\n int(getSegmentIdAtIndex(inIdx + 1)) == currentSeg ? 1 : 0,\n int(getSegmentIdAtIndex(inIdx + 2)) == currentSeg ? 1 : 0,\n 0\n );\n\n " + updateSnippet + "\n }\n setOutput(" + returnValue + ");\n }\n ";
}
return SegmentOpProgram2;
}();
/**
* @license
* Copyright 2017 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var SelectProgram = function() {
  /**
   * Unpacked WebGL program for an element-wise select: each output element is
   * taken from `a` when the condition value `c` is >= 1.0, and from `b`
   * otherwise.
   *
   * @param cRank rank of the condition tensor; only the first `cRank` output
   *     coordinates are used to index `c`.
   * @param shape output shape; all of its coordinates index `a` and `b`.
   * @param rank output rank; ranks above 4 throw.
   */
  function SelectProgram2(cRank, shape, rank) {
    this.variableNames = ["c", "a", "b"];
    this.outputShape = shape;
    if (rank > 4) {
      throw Error("Where for rank " + rank + " is not yet supported");
    }
    // Rank 1 coordinates are a plain int; higher ranks use swizzled fields.
    let abCoords = "resRC";
    let cCoords = "resRC";
    if (rank !== 1) {
      const swizzles = ["resRC.x", "resRC.y", "resRC.z", "resRC.w"];
      const abParts = Array.from({ length: shape.length }, (_, axis) => "" + swizzles[axis]);
      cCoords = abParts.filter((_, axis) => axis < cRank).join();
      abCoords = abParts.join();
    }
    const dtype = getCoordsDataType(rank);
    this.userCode = "\n void main() {\n " + dtype + " resRC = getOutputCoords();\n float cVal = getC(" + cCoords + ");\n if (cVal >= 1.0) {\n setOutput(getA(" + abCoords + "));\n } else {\n setOutput(getB(" + abCoords + "));\n }\n }\n ";
  }
  return SelectProgram2;
}();
/**
* @license
* Copyright 2017 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var SliceProgram = function() {
  /**
   * Unpacked WebGL slice program. The slice start offsets are not baked into
   * `userCode`; they are uploaded at run time to the `start` uniform array by
   * the callback returned from getCustomSetupFunc.
   *
   * @param destSize shape of the slice (the output shape).
   */
  function SliceProgram2(destSize) {
    this.variableNames = ["source"];
    this.outputShape = destSize;
    this.rank = destSize.length;
    const rank = this.rank;
    const dtype = getCoordsDataType(rank);
    const uniformPart = "uniform int start[" + rank + "];";
    const sourceCoords = getCoords$1(rank);
    // One GLSL assignment per axis: sourceLoc.<axis> = start[i] + coords.<axis>.
    const offsetLines = [];
    for (let axis = 0; axis < destSize.length; axis++) {
      const field = coords[axis];
      offsetLines.push("sourceLoc." + field + " = start[" + axis + "] + coords." + field + ";");
    }
    const body = "\n " + dtype + " sourceLoc;\n " + dtype + " coords = getOutputCoords();\n " + offsetLines.join("\n") + "\n ";
    this.userCode = "\n " + uniformPart + "\n void main() {\n " + body + "\n setOutput(getSource(" + sourceCoords + "));\n }\n ";
  }
  /**
   * Returns a callback that uploads `start` to the program's `start` uniform.
   * The uniform location is looked up once and cached on this program. If the
   * lookup returns null, the callback uploads nothing.
   *
   * @param start per-axis start offsets; must have exactly `rank` entries.
   * @throws Error when start.length differs from the program rank.
   */
  SliceProgram2.prototype.getCustomSetupFunc = function(start) {
    if (start.length !== this.rank) {
      throw Error(`The rank (${this.rank}) of the program must match the length of start (${start.length})`);
    }
    return (gpgpu, webGLProgram) => {
      if (this.startLoc == null) {
        this.startLoc = gpgpu.getUniformLocationNoThrow(webGLProgram, "start");
        if (this.startLoc == null) {
          return;
        }
      }
      gpgpu.gl.uniform1iv(this.startLoc, start);
    };
  };
  return SliceProgram2;
}();
// GLSL field names for coordinate components, in axis order.
var coords = ["x", "y", "z", "w", "u", "v"];
/**
 * Returns the comma-separated GLSL argument list that reads every component
 * of SliceProgram's `sourceLoc` variable. For example, rank 2 yields
 * "sourceLoc.x,sourceLoc.y"; rank 1 yields just "sourceLoc".
 *
 * @param rank slice rank.
 * @throws Error for ranks above 6.
 */
function getCoords$1(rank) {
  if (rank === 1) {
    return "sourceLoc";
  }
  if (!(rank <= 6)) {
    throw Error("Slicing for rank " + rank + " is not yet supported");
  }
  return coords
    .slice(0, rank)
    .map((field) => `sourceLoc.${field}`)
    .join(",");
}
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Packed WebGL slice program. The slice start offsets are not baked into
 * `userCode`; they are uploaded at run time to the `start` uniform array by
 * the callback returned from getCustomSetupFunc.
 *
 * Each shader invocation fills one vec4:
 *   - x: the element at `coords`;
 *   - y: the next column;
 *   - z: the next row;
 *   - w: the next row and next column.
 * y/z/w are only written when that position lies inside `destSize`.
 */
var SlicePackedProgram = function() {
/**
 * @param destSize shape of the slice (the output shape).
 */
function SlicePackedProgram2(destSize) {
this.variableNames = ["source"];
this.packedInputs = true;
this.packedOutput = true;
this.outputShape = destSize;
this.rank = destSize.length;
var dtype = getCoordsDataType(this.rank);
var coords2 = getChannels("coords", this.rank);
var sourceLoc = getChannels("sourceLoc", this.rank);
var innerDims = this.rank === 1 ? "sourceLoc" : "vec2(" + sourceLoc.slice(-2).join() + ")";
// GLSL expression reading the value at the current sourceLoc from the packed input.
var getChannel = "getChannel(getSource(" + sourceLoc.join() + "), " + innerDims + ")";
// The shader moves between neighbouring positions by bumping coords /
// sourceLoc in place with ++/--. The column bump left by upperRow is undone
// at the start of lowerRow, before the row is advanced.
var upperRow = "\n result.x = " + getChannel + ";\n if (++" + coords2[this.rank - 1] + " < " + destSize[this.rank - 1] + ") {\n ++" + sourceLoc[this.rank - 1] + ";\n result.y = " + getChannel + ";\n --" + sourceLoc[this.rank - 1] + ";\n }\n ";
var lowerRow = this.rank === 1 ? "" : "\n --" + coords2[this.rank - 1] + ";\n if (++" + coords2[this.rank - 2] + " < " + destSize[this.rank - 2] + ") {\n ++" + sourceLoc[this.rank - 2] + ";\n result.z = " + getChannel + ";\n if (++" + coords2[this.rank - 1] + " < " + destSize[this.rank - 1] + ") {\n ++" + sourceLoc[this.rank - 1] + ";\n result.w = " + getChannel + ";\n }\n }\n ";
// Ranks <= 4 offset the whole coordinate vector in a single add. Higher ranks
// assign each component separately (presumably because their coordinate
// type does not support vector arithmetic; confirm in getCoordsDataType).
var sourceLocSetup = this.rank <= 4 ? "sourceLoc = coords +\n " + dtype + "(" + destSize.map(function(_, i) {
return "start[" + i + "]";
}).join() + ");" : destSize.map(function(_, i) {
return sourceLoc[i] + " = " + coords2[i] + " + start[" + i + "];";
}).join("\n");
this.userCode = "\n uniform int start[" + this.rank + "];\n void main() {\n " + dtype + " coords = getOutputCoords();\n " + dtype + " sourceLoc;\n " + sourceLocSetup + "\n vec4 result = vec4(0.);\n " + upperRow + "\n " + lowerRow + "\n setOutput(result);\n }\n ";
}
/**
 * Returns a callback that uploads `start` to the program's `start` uniform.
 * The uniform location is looked up once and cached on this program. If the
 * lookup returns null, the callback uploads nothing.
 *
 * @param start per-axis start offsets; must have exactly `rank` entries.
 * @throws Error when start.length differs from the program rank.
 */
SlicePackedProgram2.prototype.getCustomSetupFunc = function(start) {
var _this = this;
if (start.length !== this.rank) {
throw Error("The rank (" + this.rank + ") of the program must match the " + ("length of start (" + start.length + ")"));
}
return function(gpgpu, webGLProgram) {
if (_this.startLoc == null) {
_this.startLoc = gpgpu.getUniformLocationNoThrow(webGLProgram, "start");
if (_this.startLoc == null) {
return;
}
}
gpgpu.gl.uniform1iv(_this.startLoc, start);
};
};
return SlicePackedProgram2;
}();
/**
* @license
* Copyright 2017 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var StridedSliceProgram = function() {
  /**
   * Unpacked WebGL strided-slice program. For each axis i, the output element
   * at `coords` reads the input at coords[i] * strides[i] + begin[i].
   *
   * @param begin per-axis start offsets; inlined into the shader as a constant.
   * @param strides per-axis strides; inlined into the shader as a constant.
   * @param size output shape.
   */
  function StridedSliceProgram2(begin, strides, size) {
    this.variableNames = ["x"];
    this.outputShape = size;
    const rank = size.length;
    // begin, strides and the output coords all have the output's rank, so a
    // single GLSL coordinate type serves for all three.
    const dtype = getCoordsDataType(rank);
    let newCoords;
    if (rank === 1) {
      // Rank-1 coordinates are scalars, so the whole expression is scalar.
      newCoords = "coords * strides + begin";
    } else {
      newCoords = size
        .map((_, i) => "coords[" + i + "] * strides[" + i + "] + begin[" + i + "]")
        .join(",");
    }
    this.userCode = "\n " + dtype + " begin = " + dtype + "(" + begin + ");\n " + dtype + " strides = " + dtype + "(" + strides + ");\n\n void main() {\n " + dtype + " coords = getOutputCoords();\n setOutput(getX(" + newCoords + "));\n }\n ";
  }
  return StridedSliceProgram2;
}();
/**
* @license
* Copyright 2017 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var TextureManager = function() {
  /**
   * Pool of WebGL matrix textures. Textures are bucketed by a key built from
   * shape, physical texture type and packing (getKeyFromTextureShape).
   * Released textures are kept for reuse, unless the
   * WEBGL_DELETE_TEXTURE_THRESHOLD flag is set and the allocated byte count
   * exceeds it, in which case they are deleted.
   *
   * @param gpgpu context used to create and delete textures. Its `gl` and
   *     `textureConfig` are used to compute texture byte sizes.
   */
  function TextureManager2(gpgpu) {
    this.gpgpu = gpgpu;
    this.numUsedTextures = 0;
    this.numFreeTextures = 0;
    this._numBytesAllocated = 0;
    this._numBytesFree = 0;
    // shapeKey -> released textures available for reuse.
    this.freeTextures = {};
    this.logEnabled = false;
    // shapeKey -> textures currently handed out by acquireTexture.
    this.usedTextures = {};
  }
  /**
   * Returns a texture for the given shape/usage/packing. A free texture with
   * the same key is reused when available; otherwise a new one is created
   * through gpgpu.
   */
  TextureManager2.prototype.acquireTexture = function(shapeRC, usage, isPacked) {
    var physicalTexType = getPhysicalFromLogicalTextureType(usage, isPacked);
    var shapeKey = getKeyFromTextureShape(shapeRC, physicalTexType, isPacked);
    if (!(shapeKey in this.freeTextures)) {
      this.freeTextures[shapeKey] = [];
    }
    if (!(shapeKey in this.usedTextures)) {
      this.usedTextures[shapeKey] = [];
    }
    var texBytes = computeBytes(shapeRC, physicalTexType, this.gpgpu.gl, this.gpgpu.textureConfig, isPacked);
    if (this.freeTextures[shapeKey].length > 0) {
      this.numFreeTextures--;
      this.numUsedTextures++;
      this._numBytesFree -= texBytes;
      this.log();
      var reused = this.freeTextures[shapeKey].shift();
      this.usedTextures[shapeKey].push(reused);
      return reused;
    }
    var newTexture;
    if (physicalTexType === PhysicalTextureType.PACKED_2X2_FLOAT32) {
      newTexture = this.gpgpu.createPackedMatrixTexture(shapeRC[0], shapeRC[1]);
    } else if (physicalTexType === PhysicalTextureType.PACKED_2X2_FLOAT16) {
      newTexture = this.gpgpu.createFloat16PackedMatrixTexture(shapeRC[0], shapeRC[1]);
    } else if (physicalTexType === PhysicalTextureType.UNPACKED_FLOAT32) {
      newTexture = this.gpgpu.createFloat32MatrixTexture(shapeRC[0], shapeRC[1]);
    } else if (physicalTexType === PhysicalTextureType.UNPACKED_FLOAT16) {
      newTexture = this.gpgpu.createFloat16MatrixTexture(shapeRC[0], shapeRC[1]);
    } else if (physicalTexType === PhysicalTextureType.PACKED_4X1_UNSIGNED_BYTE) {
      newTexture = this.gpgpu.createUnsignedBytesMatrixTexture(shapeRC[0], shapeRC[1]);
    }
    this.usedTextures[shapeKey].push(newTexture);
    this.numUsedTextures++;
    this._numBytesAllocated += texBytes;
    this.log();
    return newTexture;
  };
  /**
   * Returns a texture obtained from acquireTexture to the pool. Depending on
   * WEBGL_DELETE_TEXTURE_THRESHOLD, it is either deleted or kept for reuse.
   * Does nothing after dispose().
   *
   * @throws Error if `texture` is not currently in use under this shape key.
   *     The check happens before any pool state or counter is touched, so a
   *     bad release cannot corrupt the accounting or recycle/delete a foreign
   *     texture.
   */
  TextureManager2.prototype.releaseTexture = function(texture, shape, logicalTexType, isPacked) {
    if (this.freeTextures == null) {
      return;
    }
    var physicalTexType = getPhysicalFromLogicalTextureType(logicalTexType, isPacked);
    var shapeKey = getKeyFromTextureShape(shape, physicalTexType, isPacked);
    // The used list may not exist at all if this key was never acquired.
    var texList = this.usedTextures[shapeKey];
    var texIndex = texList == null ? -1 : texList.indexOf(texture);
    if (texIndex < 0) {
      throw new Error("Cannot release a texture that was never provided by this texture manager");
    }
    if (!(shapeKey in this.freeTextures)) {
      this.freeTextures[shapeKey] = [];
    }
    var texBytes = computeBytes(shape, physicalTexType, this.gpgpu.gl, this.gpgpu.textureConfig, isPacked);
    var deleteTexThreshold = tf.env().get("WEBGL_DELETE_TEXTURE_THRESHOLD");
    if (deleteTexThreshold !== -1 && this._numBytesAllocated > deleteTexThreshold) {
      this.gpgpu.deleteMatrixTexture(texture);
      this._numBytesAllocated -= texBytes;
    } else {
      this.freeTextures[shapeKey].push(texture);
      this.numFreeTextures++;
      this._numBytesFree += texBytes;
    }
    this.numUsedTextures--;
    texList.splice(texIndex, 1);
    this.log();
  };
  /** Prints pool statistics to the console when `logEnabled` is set. */
  TextureManager2.prototype.log = function() {
    if (!this.logEnabled) {
      return;
    }
    var total = this.numFreeTextures + this.numUsedTextures;
    console.log("Free/Used", this.numFreeTextures + " / " + this.numUsedTextures, "(" + total + ")");
    // Guard the ratio so an empty pool reports 0% rather than NaN%.
    var freeRatio = this._numBytesAllocated > 0 ? this._numBytesFree / this._numBytesAllocated : 0;
    console.log("Bytes allocated: " + this._numBytesAllocated);
    console.log("Bytes unused: " + this._numBytesFree + " (" + Math.round(100 * freeRatio) + "%)");
  };
  Object.defineProperty(TextureManager2.prototype, "numBytesAllocated", {
    get: function() {
      return this._numBytesAllocated;
    },
    enumerable: true,
    configurable: true
  });
  Object.defineProperty(TextureManager2.prototype, "numBytesFree", {
    get: function() {
      return this._numBytesFree;
    },
    enumerable: true,
    configurable: true
  });
  TextureManager2.prototype.getNumUsedTextures = function() {
    return this.numUsedTextures;
  };
  TextureManager2.prototype.getNumFreeTextures = function() {
    return this.numFreeTextures;
  };
  /**
   * Deletes every pooled texture, free and in use, and resets all counters.
   * Later releaseTexture calls become no-ops.
   */
  TextureManager2.prototype.dispose = function() {
    var _this = this;
    if (this.freeTextures == null) {
      return;
    }
    for (var texShape in this.freeTextures) {
      this.freeTextures[texShape].forEach(function(tex) {
        _this.gpgpu.deleteMatrixTexture(tex);
      });
    }
    for (var usedShape in this.usedTextures) {
      this.usedTextures[usedShape].forEach(function(tex) {
        _this.gpgpu.deleteMatrixTexture(tex);
      });
    }
    this.freeTextures = null;
    this.usedTextures = null;
    this.numUsedTextures = 0;
    this.numFreeTextures = 0;
    this._numBytesAllocated = 0;
    this._numBytesFree = 0;
  };
  return TextureManager2;
}();
/**
 * Returns the number of bytes the texture manager accounts per texel for a
 * WebGL internal format.
 *
 * The format enums are read straight off the `gl` context object. RGBA is
 * counted as 16 bytes, the same as RGBA32F. NOTE(review): presumably because
 * RGBA is also used as the float internal format on some contexts; this
 * over-counts RGBA8 byte textures.
 *
 * @throws Error for any other internal format.
 */
function numBytesForInternalFormat(gl, internalFormat) {
  switch (internalFormat) {
    case gl.R32F:
      return 4;
    case gl.R16F:
      return 2;
    case gl.RGBA32F:
      return 16;
    case gl.RGBA:
      return 16;
    case gl.RGBA16F:
      return 8;
    default:
      throw new Error("Unknown internal format " + internalFormat);
  }
}
/**
 * Returns the byte size the texture manager accounts for a texture of the
 * given logical [rows, cols] shape. The result is physical width * height *
 * bytes per texel, using the packed or unpacked physical dimensions as
 * selected by `isPacked`.
 */
function computeBytes(shape, physicalTexType, gl, textureConfig, isPacked) {
  const internalFormat = internalFormatForPhysicalTexType(physicalTexType, textureConfig);
  const [width, height] = isPacked
    ? getPackedMatrixTextureShapeWidthHeight(shape[0], shape[1])
    : getUnpackedMatrixTextureShapeWidthHeight(shape[0], shape[1]);
  return width * height * numBytesForInternalFormat(gl, internalFormat);
}
/**
 * Maps a physical texture type to the WebGL internal format that the
 * matching create*MatrixTexture path uses, as reported by `textureConfig`.
 *
 * @throws Error for an unrecognised physical texture type.
 */
function internalFormatForPhysicalTexType(physicalTexType, textureConfig) {
  const types = PhysicalTextureType;
  if (physicalTexType === types.PACKED_2X2_FLOAT32) {
    return getInternalFormatForPackedMatrixTexture(textureConfig);
  }
  if (physicalTexType === types.PACKED_2X2_FLOAT16) {
    return getInternalFormatForFloat16PackedMatrixTexture(textureConfig);
  }
  if (physicalTexType === types.UNPACKED_FLOAT32) {
    return getInternalFormatForFloat32MatrixTexture(textureConfig);
  }
  if (physicalTexType === types.UNPACKED_FLOAT16) {
    return getInternalFormatForFloat16MatrixTexture(textureConfig);
  }
  if (physicalTexType === types.PACKED_4X1_UNSIGNED_BYTE) {
    return getInternalFormatForUnsignedBytesMatrixTexture(textureConfig);
  }
  throw new Error("Unknown physical texture type " + physicalTexType);
}
/**
 * Chooses the physical texture type for render targets. Float32 variants are
 * used when the WEBGL_RENDER_FLOAT32_ENABLED flag is on, otherwise float16.
 * The packed 2x2 layout is used when `isPacked` is truthy.
 */
function getPhysicalTextureForRendering(isPacked) {
  const renderFloat32 = tf.env().getBool("WEBGL_RENDER_FLOAT32_ENABLED");
  if (renderFloat32) {
    return isPacked ? PhysicalTextureType.PACKED_2X2_FLOAT32 : PhysicalTextureType.UNPACKED_FLOAT32;
  }
  return isPacked ? PhysicalTextureType.PACKED_2X2_FLOAT16 : PhysicalTextureType.UNPACKED_FLOAT16;
}
/**
 * Maps a logical texture usage to a physical texture type:
 *   - UPLOAD uses the packed 2x2 float32 layout;
 *   - RENDER, or a missing usage (null/undefined), defers to
 *     getPhysicalTextureForRendering;
 *   - DOWNLOAD and PIXELS use the 4x1 unsigned-byte layout.
 *
 * @throws Error for any other usage.
 */
function getPhysicalFromLogicalTextureType(logicalTexType, isPacked) {
  const usage = TextureUsage;
  if (logicalTexType === usage.UPLOAD) {
    return PhysicalTextureType.PACKED_2X2_FLOAT32;
  }
  const isRenderTarget = logicalTexType === usage.RENDER || logicalTexType == null;
  if (isRenderTarget) {
    return getPhysicalTextureForRendering(isPacked);
  }
  const isByteReadback = logicalTexType === usage.DOWNLOAD || logicalTexType === usage.PIXELS;
  if (isByteReadback) {
    return PhysicalTextureType.PACKED_4X1_UNSIGNED_BYTE;
  }
  throw new Error("Unknown logical texture type " + logicalTexType);
}
/**
 * Builds the texture-pool bucket key "<rows>_<cols>_<physicalType>_<isPacked>".
 */
function getKeyFromTextureShape(shapeRowsCol, physicalTexType, isPacked) {
  const rows = shapeRowsCol[0];
  const cols = shapeRowsCol[1];
  return `${rows}_${cols}_${physicalTexType}_${isPacked}`;
}
/**
* @license
* Copyright 2017 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var TileProgram = function() {
  /**
   * Unpacked WebGL tile program. Output dimension i is aShape[i] * reps[i].
   * Each output element reads the input at its coordinates taken modulo the
   * input shape (see getSourceCoords$2).
   *
   * @param aShape input shape.
   * @param reps per-axis repetition counts.
   */
  function TileProgram2(aShape, reps) {
    this.variableNames = ["A"];
    this.outputShape = Array.from({ length: aShape.length }, (_, axis) => aShape[axis] * reps[axis]);
    this.rank = this.outputShape.length;
    const dtype = getCoordsDataType(this.rank);
    const sourceCoords = getSourceCoords$2(aShape);
    this.userCode = "\n void main() {\n " + dtype + " resRC = getOutputCoords();\n setOutput(getA(" + sourceCoords + "));\n }\n ";
  }
  return TileProgram2;
}();
/**
 * Builds TileProgram's input-coordinate list: each output coordinate wrapped
 * with imod(coord, inputDim). Rank 1 uses the scalar `resRC`; higher ranks
 * use its x/y/z/w/u fields.
 *
 * @throws Error for ranks above 5.
 */
function getSourceCoords$2(aShape) {
  const rank = aShape.length;
  if (rank > 5) {
    throw Error("Tile for rank " + rank + " is not yet supported");
  }
  if (rank === 1) {
    return `imod(resRC, ${aShape[0]})`;
  }
  const fieldNames = ["resRC.x", "resRC.y", "resRC.z", "resRC.w", "resRC.u"];
  return Array.from({ length: rank }, (_, axis) => `imod(${fieldNames[axis]}, ${aShape[axis]})`).join(",");
}
/**
* @license
* Copyright 2017 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var UnaryOpProgram = function() {
  // Fixed GLSL surrounding the caller-supplied body of unaryOperation().
  const prologue = "\n float unaryOperation(float x) {\n ";
  const epilogue = "\n }\n\n void main() {\n float x = getAAtOutCoords();\n float y = unaryOperation(x);\n\n setOutput(y);\n }\n ";
  /**
   * Unpacked WebGL program that applies a scalar function element-wise.
   *
   * @param aShape input shape; also the output shape.
   * @param opSnippet GLSL body of `float unaryOperation(float x)`.
   */
  function UnaryOpProgram2(aShape, opSnippet) {
    this.variableNames = ["A"];
    this.outputShape = aShape;
    this.userCode = prologue + opSnippet + epilogue;
  }
  return UnaryOpProgram2;
}();
// GLSL bodies for a scalar `float unaryOperation(float x)`. They are
// presumably passed as `opSnippet` to UnaryOpProgram; confirm at call sites.
// Early-out that returns NaN inputs unchanged.
var CHECK_NAN_SNIPPET$2 = "if (isnan(x)) return x;";
var LINEAR = "return x;";
var ABS = "return abs(x);";
// max(x, 0) with NaN passthrough.
var RELU = CHECK_NAN_SNIPPET$2 + "\n return (x < 0.0) ? 0.0 : x;\n";
// min(max(x, 0), 6) with NaN passthrough.
var RELU6 = CHECK_NAN_SNIPPET$2 + "\n return (x < 0.0) ? 0.0 : min(6.0, x);\n";
var ELU = "return (x >= 0.0) ? x : (exp(x) - 1.0);";
// SELU constants are inlined from tf.backend_util when this module loads.
var SELU = "\n // Stable and Attracting Fixed Point (0, 1) for Normalized Weights.\n // see: https://arxiv.org/abs/1706.02515\n float scaleAlpha = " + tf.backend_util.SELU_SCALEALPHA + ";\n float scale = " + tf.backend_util.SELU_SCALE + ";\n return (x >= 0.0) ? scale * x : scaleAlpha * (exp(x) - 1.0);\n";
/**
 * Builds the GLSL body of a scalar step function: NaN passes through,
 * x > 0 maps to 1.0, and everything else maps to `alpha` (0 when omitted).
 *
 * @param alpha value emitted for non-positive inputs; interpolated into the
 *     shader source as text.
 */
function STEP(alpha) {
  const lowValue = alpha === undefined ? 0 : alpha;
  return `${CHECK_NAN_SNIPPET$2}\n return x > 0.0 ? 1.0 : float(${lowValue});\n `;
}
// More scalar GLSL bodies for `float unaryOperation(float x)`.
var NEG = "return -x;";
var CEIL = "return ceil(x);";
var FLOOR = "return floor(x);";
// NaN maps to 0.0 rather than propagating.
var SIGN = "\n if (isnan(x)) { return 0.0; }\n return sign(x);\n";
var IS_NAN = "return float(isnan(x));";
var IS_INF = "return float(isinf(x));";
var IS_FINITE = "return float(!isnan(x) && !isinf(x));";
// Round half to even: exact .5 fractions go to the even neighbour.
var ROUND = "\n // OpenGL ES does not support round function.\n // The algorithm is based on banker's rounding.\n float base = floor(x);\n if ((x - base) < 0.5) {\n return floor(x);\n } else if ((x - base) > 0.5) {\n return ceil(x);\n } else {\n if (mod(base, 2.0) == 0.0) {\n return base;\n } else {\n return base + 1.0;\n }\n }\n";
var EXP = "return exp(x);";
var EXPM1 = "return exp(x) - 1.0;";
// Negative inputs yield NaN explicitly.
var LOG = "if (x < 0.0) return NAN;\n return log(x);";
var LOG1P = "return log(1.0 + x);";
var SQRT = "return sqrt(x);";
var RSQRT = "return inversesqrt(x);";
var SIGMOID = "return 1.0 / (1.0 + exp(-1.0 * x));";
// Piecewise softplus: returns x for large x and exp(x) for very negative x,
// which avoids overflow in log(exp(x) + 1) at the extremes.
var SOFTPLUS = "\n float epsilon = 1.1920928955078125e-7;\n float threshold = log(epsilon) + 2.0;\n\n bool too_large = x > -threshold;\n bool too_small = x < threshold;\n\n float result;\n float exp_x = exp(x);\n\n if (too_large){\n result = x;\n }\n else if (too_small){\n result = exp_x;\n }\n else{\n result = log(exp_x + 1.0);\n }\n return result;\n";
// Inputs outside [-1, 1] yield NaN.
var ASIN = CHECK_NAN_SNIPPET$2 + "\n if (abs(x) > 1.) {\n return NAN;\n }\n return asin(x);\n";
var ACOS = CHECK_NAN_SNIPPET$2 + "\n if (abs(x) > 1.) {\n return NAN;\n }\n return acos(x);\n";
var ATAN = CHECK_NAN_SNIPPET$2 + "\n return atan(x);\n";
// Note: in SINH the local `e2x` holds exp(x), and in COSH it holds exp(-x)
// (not exp(2x)). Both formulas compute (e^x -/+ e^-x) / 2.
var SINH = "\n float e2x = exp(x);\n return (e2x - 1.0 / e2x) / 2.0;\n";
var COSH = "\n float e2x = exp(-x);\n return (e2x + 1.0 / e2x) / 2.0;\n";
// Uses exp(-2|x|) and restores the sign, so exp never overflows for large |x|.
var TANH = "\n float e2x = exp(-2.0 * abs(x));\n return sign(x) * (1.0 - e2x) / (1.0 + e2x);\n";
var ASINH = CHECK_NAN_SNIPPET$2 + "return log(x + sqrt(x * x + 1.0));";
var ACOSH = CHECK_NAN_SNIPPET$2 + "\n if (x < 1.0) return NAN;\n return log(x + sqrt(x * x - 1.0));";
var ATANH = CHECK_NAN_SNIPPET$2 + "\n if ((x < -1.0) || (x > 1.0)) return NAN;\n return (log(1.0 + x) - log(1.0 - x)) / 2.0;";
// Polynomial erf approximation; coefficients are inlined from tf.backend_util
// when this module loads.
var ERF = '\n // Error function is calculated approximately with elementary function.\n // See "Handbook of Mathematical Functions with Formulas,\n // Graphs, and Mathematical Tables", Abramowitz and Stegun.\n float p = ' + tf.backend_util.ERF_P + ";\n float a1 = " + tf.backend_util.ERF_A1 + ";\n float a2 = " + tf.backend_util.ERF_A2 + ";\n float a3 = " + tf.backend_util.ERF_A3 + ";\n float a4 = " + tf.backend_util.ERF_A4 + ";\n float a5 = " + tf.backend_util.ERF_A5 + ";\n\n float sign = sign(x);\n x = abs(x);\n float t = 1.0 / (1.0 + p * x);\n return sign * (1.0 - (((((a5*t + a4)*t) + a3)*t + a2)*t + a1)*t*exp(-x*x));\n";
var RECIPROCAL = "return 1.0 / x;";
// Values >= 1.0 count as true.
var LOGICAL_NOT = "return float(!(x >= 1.0));";
var CLONE = "return x;";
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
// GLSL bodies for a packed `vec4 unaryOperation(vec4 x)`. They are presumably
// passed as `opSnippet` to UnaryOpPackedProgram; confirm at call sites. Each
// one processes the four lanes of a packed texel at once.
var LINEAR$1 = "return x;";
// Negative lanes are forced to NaN, matching the unpacked LOG snippet.
var LOG$1 = "\n vec4 result = log(x);\n vec4 isNaN = vec4(lessThan(x, vec4(0.0)));\n result.r = isNaN.r == 1.0 ? NAN : result.r;\n result.g = isNaN.g == 1.0 ? NAN : result.g;\n result.b = isNaN.b == 1.0 ? NAN : result.b;\n result.a = isNaN.a == 1.0 ? NAN : result.a;\n\n return result;\n";
// ReLU via max/clamp instead of multiplying by a 0/1 mask. The mask form
// gave NaN for -Infinity lanes (-inf * 0) and -0.0 for negative lanes, which
// disagreed with the unpacked RELU/RELU6 snippets. NaN lanes are restored
// explicitly afterwards, so max/clamp's NaN behaviour does not matter.
var RELU$1 = "\n vec4 result = max(x, vec4(0.0));\n bvec4 isNaN = isnan(x);\n\n result.r = isNaN.r ? x.r : result.r;\n result.g = isNaN.g ? x.g : result.g;\n result.b = isNaN.b ? x.b : result.b;\n result.a = isNaN.a ? x.a : result.a;\n\n return result;\n";
var RELU6$1 = "\n vec4 result = clamp(x, vec4(0.0), vec4(6.0));\n bvec4 isNaN = isnan(x);\n\n result.r = isNaN.r ? x.r : result.r;\n result.g = isNaN.g ? x.g : result.g;\n result.b = isNaN.b ? x.b : result.b;\n result.a = isNaN.a ? x.a : result.a;\n\n return result;\n";
var ELU$1 = "\n vec4 result;\n\n result.r = (x.r >= 0.0) ? x.r : (exp(x.r) - 1.0);\n result.g = (x.g >= 0.0) ? x.g : (exp(x.g) - 1.0);\n result.b = (x.b >= 0.0) ? x.b : (exp(x.b) - 1.0);\n result.a = (x.a >= 0.0) ? x.a : (exp(x.a) - 1.0);\n\n return result;\n";
var UnaryOpPackedProgram = function() {
  // Fixed GLSL surrounding the caller-supplied body of the vec4 unaryOperation().
  const prologue = "\n vec4 unaryOperation(vec4 x) {\n ";
  const epilogue = "\n }\n\n void main() {\n vec4 x = getAAtOutCoords();\n vec4 y = unaryOperation(x);\n\n setOutput(y);\n }\n ";
  /**
   * Packed WebGL program that applies a function to every packed texel
   * (four values at once).
   *
   * @param aShape input shape; also the output shape.
   * @param opSnippet GLSL body of `vec4 unaryOperation(vec4 x)`.
   */
  function UnaryOpPackedProgram2(aShape, opSnippet) {
    this.variableNames = ["A"];
    this.packedInputs = true;
    this.packedOutput = true;
    this.outputShape = aShape;
    this.userCode = prologue + opSnippet + epilogue;
  }
  return UnaryOpPackedProgram2;
}();
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var UnpackProgram = function() {
  /**
   * WebGL program that reads a packed input texture "A" and writes an
   * unpacked output of `outputShape`. Each output element is selected from
   * its source texel with getChannel().
   *
   * @param {number[]} outputShape - Logical shape of the unpacked output.
   */
  function UnpackProgram2(outputShape) {
    this.variableNames = ["A"];
    this.packedInputs = true;
    this.packedOutput = false;
    this.outputShape = outputShape;
    var rank = outputShape.length;
    var rcChannels = getChannels("rc", rank);
    var coordType = getCoordsDataType(rank);
    var texelCoords = getSourceCoords(rank, rcChannels);
    var lastTwo = rcChannels.slice(-2);
    // Rank <= 1 selects the channel with `rc` itself; higher ranks use the two
    // innermost coordinates as a vec2.
    var channelSelector;
    if (rank <= 1) {
      channelSelector = "rc";
    } else {
      channelSelector = "vec2(" + lastTwo.join(",") + ")";
    }
    this.userCode = "\n void main() {\n " + coordType + " rc = getOutputCoords();\n vec4 packedInput = getA(" + texelCoords + ");\n\n setOutput(getChannel(packedInput, " + channelSelector + "));\n }\n ";
  }
  return UnpackProgram2;
}();
/**
* @license
* Copyright 2017 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
// Local shorthands for helpers exposed by `tf` that the WebGL backend calls.
var segment_util = tf.backend_util.segment_util;
var {split, tile, topkImpl, whereImpl} = tf.kernel_impls;
// Small tolerances for float32 vs float16. Their consumers are outside this
// chunk; presumably they back the backend's epsilon() per texture precision.
var EPSILON_FLOAT32 = 1e-7;
var EPSILON_FLOAT16 = 1e-4;
// One program cache per WebGL version. A backend that creates its own context
// shares the cache for that version (see the MathBackendWebGL constructor).
var binaryCaches = {};
/**
 * Returns the cache object for `webGLVersion`, creating an empty one on first
 * request. Later calls with the same version return the same object.
 *
 * @param {number} webGLVersion
 * @returns {Object}
 */
function getBinaryCache(webGLVersion) {
  if (!(webGLVersion in binaryCaches)) {
    binaryCaches[webGLVersion] = {};
  }
  return binaryCaches[webGLVersion];
}
/**
 * Returns the GLSL snippet implementing a fused activation.
 *
 * @param {string} activation - One of "linear", "relu", "elu", "relu6", "prelu".
 * @param {boolean=} packed - When truthy, return the packed (vec4) variant.
 *   Defaults to false.
 * @returns {string} The shader snippet.
 * @throws {Error} For any other activation name.
 */
function mapActivationToShaderProgram(activation, packed) {
  var usePacked = packed === void 0 ? false : packed;
  switch (activation) {
    case "linear":
      return usePacked ? LINEAR$1 : LINEAR;
    case "relu":
      return usePacked ? RELU$1 : RELU;
    case "elu":
      return usePacked ? ELU$1 : ELU;
    case "relu6":
      return usePacked ? RELU6$1 : RELU6;
    case "prelu":
      return usePacked ? PRELU$1 : PRELU;
    default:
      throw new Error("Activation " + activation + " has not been implemented for the WebGL backend.");
  }
}
// Default element-count threshold below which shouldExecuteOnCPU hands small,
// CPU-resident inputs to the CPU backend.
var CPU_HANDOFF_SIZE_THRESHOLD = 128;
// Bytes-per-screen-pixel multiplier used by numMBBeforeWarning.
var BEFORE_PAGING_CONSTANT = 600;
/**
 * Estimates how many MB of GPU memory may be used before warning, scaled by
 * the physical screen resolution of the current global scope.
 *
 * `screen` and `devicePixelRatio` are both read from `tf.env().global`. The
 * original code read `devicePixelRatio` from the bare `window` identifier,
 * which throws a ReferenceError wherever the env global has a `screen` but
 * `window` is not defined.
 *
 * @returns {number} 1024 when no screen is available; otherwise
 *   height * width * devicePixelRatio * BEFORE_PAGING_CONSTANT / 2^20.
 */
function numMBBeforeWarning() {
  var globalScope = tf.env().global;
  if (globalScope.screen == null) {
    return 1024;
  }
  // Missing devicePixelRatio is treated as 1 rather than producing NaN.
  var pixelRatio = globalScope.devicePixelRatio != null ? globalScope.devicePixelRatio : 1;
  return globalScope.screen.height * globalScope.screen.width * pixelRatio * BEFORE_PAGING_CONSTANT / 1024 / 1024;
}
// Shared-dimension size above which batchMatMul uses the multiply-and-sum path
// for vector-shaped operands.
var MATMUL_SHARED_DIM_THRESHOLD = 1e3;
var MathBackendWebGL = function(_super) {
__extends(MathBackendWebGL2, _super);
/**
 * WebGL backend constructor. `_super` is the base class handed to the
 * enclosing IIFE (presumably tf.KernelBackend).
 *
 * @param {GPGPUContext=} gpgpu - Existing GPGPU context to run on. When
 *   omitted, a WebGL context for the WEBGL_VERSION flag is created and owned
 *   by this backend (gpgpuCreatedLocally = true), and the per-version shared
 *   program cache from getBinaryCache() is used. A supplied context gets a
 *   private, empty cache.
 * @throws {Error} If the HAS_WEBGL flag is false.
 */
function MathBackendWebGL2(gpgpu) {
  var _this = _super.call(this) || this;
  // Async-read bookkeeping, keyed weakly by dataId objects.
  _this.pendingRead = new WeakMap();
  _this.pendingDisposal = new WeakSet();
  // Texture share counts for shallow slices, keyed by the owning dataId.
  _this.dataRefCount = new WeakMap();
  _this.numBytesInGPU = 0;
  _this.uploadWaitMs = 0;
  _this.downloadWaitMs = 0;
  _this.warnedAboutMemory = false;
  _this.warnedAboutCPUBackend = false;
  _this.pendingDeletes = 0;
  _this.disposed = false;
  var env = tf.env();
  if (!env.getBool("HAS_WEBGL")) {
    throw new Error("WebGL is not supported on this device");
  }
  if (gpgpu != null) {
    _this.gpgpu = gpgpu;
    _this.binaryCache = {};
    _this.gpgpuCreatedLocally = false;
    _this.canvas = gpgpu.gl.canvas;
  } else {
    var webGLVersion = env.getNumber("WEBGL_VERSION");
    var gl = getWebGLContext(webGLVersion);
    _this.binaryCache = getBinaryCache(webGLVersion);
    _this.gpgpu = new GPGPUContext(gl);
    _this.canvas = gl.canvas;
    _this.gpgpuCreatedLocally = true;
  }
  _this.textureManager = new TextureManager(_this.gpgpu);
  _this.numMBBeforeWarning = numMBBeforeWarning();
  _this.texData = new tf.DataStorage(_this, tf.engine());
  return _this;
}
/**
 * Number of live data ids: texture entries plus any held by the forwarded CPU
 * backend, minus disposals deferred until a pending read completes.
 */
MathBackendWebGL2.prototype.numDataIds = function() {
  var gpuCount = this.texData.numDataIds();
  var cpuCount = this.cpuBackend ? this.cpuBackend.numDataIds() : 0;
  return gpuCount + cpuCount - this.pendingDeletes;
};
/**
 * Registers new CPU-side `values` under a fresh dataId. The entry is marked for
 * upload and starts with refCount 1.
 *
 * @returns {Object} The new dataId.
 * @throws {Error} When writing non-null values with dtype complex64, or (under
 *   WEBGL_CHECK_NUMERICAL_PROBLEMS / DEBUG) when a value is not representable.
 */
MathBackendWebGL2.prototype.write = function(values, shape, dtype) {
  var env = tf.env();
  var shouldCheck = env.getBool("WEBGL_CHECK_NUMERICAL_PROBLEMS") || env.getBool("DEBUG");
  if (shouldCheck) {
    this.checkNumericalProblems(values);
  }
  if (dtype === "complex64" && values != null) {
    throw new Error("Cannot write to a complex64 dtype. Please use tf.complex(real, imag).");
  }
  var newId = {};
  var entry = {
    shape: shape,
    dtype: dtype,
    values: values,
    usage: TextureUsage.UPLOAD,
    refCount: 1,
    complexParentRefCount: 0
  };
  this.texData.set(newId, entry);
  return newId;
};
/** Increments the reference count of `dataId` (which must be registered). */
MathBackendWebGL2.prototype.incRef = function(dataId) {
  this.texData.get(dataId).refCount++;
};
/**
 * Decrements the reference count of `dataId` if it is still registered. This
 * only adjusts the count and does not dispose the data.
 */
MathBackendWebGL2.prototype.decRef = function(dataId) {
  if (!this.texData.has(dataId)) {
    return;
  }
  this.texData.get(dataId).refCount--;
};
/**
 * (Re)registers CPU-side `values` under an existing `dataId`, marked for upload
 * with refCount 1.
 *
 * @throws {Error} For complex64 dtype, or (under DEBUG) when a value is not
 *   representable.
 */
MathBackendWebGL2.prototype.move = function(dataId, values, shape, dtype) {
  if (tf.env().getBool("DEBUG")) {
    this.checkNumericalProblems(values);
  }
  if (dtype === "complex64") {
    throw new Error("Cannot write to a complex64 dtype. Please use tf.complex(real, imag).");
  }
  var entry = {
    shape: shape,
    dtype: dtype,
    values: values,
    usage: TextureUsage.UPLOAD,
    refCount: 1,
    complexParentRefCount: 0
  };
  this.texData.set(dataId, entry);
};
/**
 * Drops one reference to `tensorInfo.dataId`. When the count falls below 1,
 * the data is disposed. Unknown ids are ignored.
 */
MathBackendWebGL2.prototype.disposeIntermediateTensorInfo = function(tensorInfo) {
  var id = tensorInfo.dataId;
  if (!this.texData.has(id)) {
    return;
  }
  var entry = this.texData.get(id);
  entry.refCount--;
  if (entry.refCount < 1) {
    this.disposeData(id);
  }
};
/**
 * Synchronously returns the values of `dataId`, caching them on the CPU via
 * convertAndCacheOnCPU.
 *
 * - A shallow slice is first copied into its own texture with a CLONE program.
 * - Values already on the CPU are returned without a GPU download.
 * - complex64 data is read as its real and imaginary parts and merged.
 * - When a time() call is active, download time is added to downloadWaitMs.
 */
MathBackendWebGL2.prototype.readSync = function(dataId) {
  var texData = this.texData.get(dataId);
  var {values, dtype, complexTensorInfos, slice, shape, isPacked} = texData;
  if (slice != null) {
    // Shallow slices share their parent's texture (see shallowSlice), so
    // materialize a standalone copy and read that.
    var CloneProgram = isPacked ? UnaryOpPackedProgram : UnaryOpProgram;
    var copy = this.runWebGLProgram(new CloneProgram(shape, CLONE), [{dataId, shape, dtype}], dtype);
    var copied = this.readSync(copy.dataId);
    this.disposeIntermediateTensorInfo(copy);
    return copied;
  }
  if (values != null) {
    return this.convertAndCacheOnCPU(dataId);
  }
  if (dtype === "string") {
    return values;
  }
  var timed = this.activeTimers != null;
  var startMs = timed ? tf.util.now() : void 0;
  var downloaded = dtype === "complex64"
    ? tf.backend_util.mergeRealAndImagArrays(
        this.readSync(complexTensorInfos.real.dataId),
        this.readSync(complexTensorInfos.imag.dataId))
    : this.getValuesFromTexture(dataId);
  if (timed) {
    this.downloadWaitMs += tf.util.now() - startMs;
  }
  return this.convertAndCacheOnCPU(dataId, downloaded);
};
/**
 * Asynchronously downloads the values behind `dataId` and caches them on the
 * CPU via convertAndCacheOnCPU.
 *
 * This is a compiled async/await state machine. The `[op, value]` tuples
 * follow the tslib `__generator` protocol (presumably 2 = return,
 * 3 = jump to label, 4 = await; the helper is defined outside this chunk).
 *
 * - Concurrent reads of the same dataId are coalesced: later callers subscribe
 *   to `pendingRead` and receive the first read's result.
 * - A shallow slice is copied into its own texture with a CLONE program, and
 *   that copy is read instead.
 * - Data already on the CPU (`values != null`) is returned without download.
 * - complex64 reads its real and imaginary parts in parallel and merges them.
 * - disposeData() calls made while the read is pending are deferred via
 *   `pendingDisposal` and executed once the values are available.
 *
 * @param {Object} dataId - Key into this.texData.
 * @returns {Promise} Resolves to the converted values. Rejects when
 *   WEBGL_DOWNLOAD_FLOAT_ENABLED is false on WebGL 2.
 */
MathBackendWebGL2.prototype.read = function(dataId) {
return __awaiter(this, void 0, void 0, function() {
var subscribers_1, texData, values, shape, slice, dtype, complexTensorInfos, isPacked, program, res, data, buffer, tmpDownloadTarget, tmpData, vals, ps, realValues, imagValues, size, dTypeVals, subscribers;
var _a;
return __generator(this, function(_b) {
switch (_b.label) {
case 0:
// A read of this dataId is already in flight: wait for its result instead.
if (this.pendingRead.has(dataId)) {
subscribers_1 = this.pendingRead.get(dataId);
return [2, new Promise(function(resolve) {
return subscribers_1.push(resolve);
})];
}
texData = this.texData.get(dataId);
values = texData.values, shape = texData.shape, slice = texData.slice, dtype = texData.dtype, complexTensorInfos = texData.complexTensorInfos, isPacked = texData.isPacked;
// Shallow slice: copy into a standalone texture and read that. `res` is
// released right away; if the nested read has already registered itself in
// pendingRead, disposeData defers the release via pendingDisposal.
if (slice != null) {
program = void 0;
if (isPacked) {
program = new UnaryOpPackedProgram(shape, CLONE);
} else {
program = new UnaryOpProgram(shape, CLONE);
}
res = this.runWebGLProgram(program, [{dataId, shape, dtype}], dtype);
data = this.read(res.dataId);
this.disposeIntermediateTensorInfo(res);
return [2, data];
}
// Values are already on the CPU; nothing to download.
if (values != null) {
return [2, this.convertAndCacheOnCPU(dataId)];
}
if (!tf.env().getBool("WEBGL_DOWNLOAD_FLOAT_ENABLED") && tf.env().getNumber("WEBGL_VERSION") === 2) {
throw new Error("tensor.data() with WEBGL_DOWNLOAD_FLOAT_ENABLED=false and WEBGL_VERSION=2 not yet supported.");
}
buffer = null;
// With buffer support, decode into a temporary texture and copy it into a
// buffer that is downloaded after the fence (label 4).
if (dtype !== "complex64" && tf.env().get("WEBGL_BUFFER_SUPPORTED")) {
tmpDownloadTarget = this.decode(dataId);
tmpData = this.texData.get(tmpDownloadTarget.dataId);
buffer = (_a = this.gpgpu).createBufferFromTexture.apply(_a, [tmpData.texture].concat(getDenseTexShape(shape)));
}
// From here on, concurrent read() calls subscribe instead of downloading.
this.pendingRead.set(dataId, []);
// Non-complex data waits on a GPU fence before downloading (labels 1 -> 2).
if (!(dtype !== "complex64"))
return [3, 2];
return [4, this.gpgpu.createAndWaitForFence()];
case 1:
_b.sent();
_b.label = 2;
case 2:
// complex64: read the real and imaginary parts in parallel (label 3).
if (!(dtype === "complex64"))
return [3, 4];
return [4, Promise.all([
this.read(complexTensorInfos.real.dataId),
this.read(complexTensorInfos.imag.dataId)
])];
case 3:
ps = _b.sent();
realValues = ps[0];
imagValues = ps[1];
vals = tf.backend_util.mergeRealAndImagArrays(realValues, imagValues);
return [3, 5];
case 4:
// Download from the texture directly or from the buffer prepared above.
if (buffer == null) {
vals = this.getValuesFromTexture(dataId);
} else {
size = tf.util.sizeFromShape(shape);
vals = this.gpgpu.downloadFloat32MatrixFromBuffer(buffer, size);
}
_b.label = 5;
case 5:
// Release the temporary decode target, cache the values, notify waiting
// readers, then perform any disposal that was deferred during the read.
if (tmpDownloadTarget != null) {
this.disposeIntermediateTensorInfo(tmpDownloadTarget);
}
dTypeVals = this.convertAndCacheOnCPU(dataId, vals);
subscribers = this.pendingRead.get(dataId);
this.pendingRead.delete(dataId);
subscribers.forEach(function(resolve) {
return resolve(dTypeVals);
});
if (this.pendingDisposal.has(dataId)) {
this.pendingDisposal.delete(dataId);
this.disposeData(dataId);
this.pendingDeletes--;
}
return [2, dTypeVals];
}
});
});
};
/**
 * Throws on the first entry of `values` that canBeRepresented() rejects.
 * Does nothing for null/undefined input. When the device is
 * WEBGL_RENDER_FLOAT32_CAPABLE, the message suggests enabling float32
 * rendering.
 */
MathBackendWebGL2.prototype.checkNumericalProblems = function(values) {
  if (values == null) {
    return;
  }
  for (var idx = 0; idx < values.length; idx++) {
    var candidate = values[idx];
    if (canBeRepresented(candidate)) {
      continue;
    }
    var message = tf.env().getBool("WEBGL_RENDER_FLOAT32_CAPABLE")
      ? "The value " + candidate + " cannot be represented with your current settings. Consider enabling float32 rendering: 'tf.env().set('WEBGL_RENDER_FLOAT32_ENABLED', true);'"
      : "The value " + candidate + " cannot be represented on this device.";
    throw new Error(message);
  }
};
/**
 * Synchronously downloads the values of `dataId` from the GPU.
 *
 * With WEBGL_DOWNLOAD_FLOAT_ENABLED, the data is decoded and read back as
 * floats. Otherwise it is run through an encode-float program (the packed
 * variant when WEBGL_PACK is on and the data is packed) and read back as
 * byte-encoded floats. The result is truncated to the tensor's element count.
 */
MathBackendWebGL2.prototype.getValuesFromTexture = function(dataId) {
  var info = this.texData.get(dataId);
  var shape = info.shape;
  var dtype = info.dtype;
  var isPacked = info.isPacked;
  var numElements = tf.util.sizeFromShape(shape);
  var gpgpu = this.gpgpu;
  if (tf.env().getBool("WEBGL_DOWNLOAD_FLOAT_ENABLED")) {
    var decoded = this.decode(dataId);
    var decodedEntry = this.texData.get(decoded.dataId);
    var downloadArgs = [decodedEntry.texture].concat(getDenseTexShape(shape));
    var floats = gpgpu.downloadMatrixFromPackedTexture.apply(gpgpu, downloadArgs).subarray(0, numElements);
    this.disposeIntermediateTensorInfo(decoded);
    return floats;
  }
  var usePackedEncoder = tf.env().getBool("WEBGL_PACK") && isPacked === true;
  var encodeShape = usePackedEncoder ? getShapeAs3D(shape) : shape;
  var encoder = usePackedEncoder ? new EncodeFloatPackedProgram(encodeShape) : new EncodeFloatProgram(encodeShape);
  var encoded = this.runWebGLProgram(encoder, [{shape: encodeShape, dtype: dtype, dataId: dataId}], "float32");
  var encodedEntry = this.texData.get(encoded.dataId);
  var decodedFloats = gpgpu.downloadByteEncodedFloatMatrixFromOutputTexture(
    encodedEntry.texture, encodedEntry.texShape[0], encodedEntry.texShape[1]).subarray(0, numElements);
  this.disposeIntermediateTensorInfo(encoded);
  return decodedFloats;
};
/**
 * Runs `f` while collecting the timer entries recorded into `activeTimers`.
 * Returns upload/download wait times and kernel time.
 *
 * Timer lists nest: the outer-most call installs `programTimersStack`; nested
 * calls push their own list into the enclosing call's `activeTimers`.
 *
 * Compiled async state machine (tslib `__generator` protocol, presumably:
 * 2 = return, 3 = jump to label, 4 = await).
 *
 * @param {Function} f - Work to profile; called synchronously.
 * @returns {Promise<Object>} `{uploadWaitMs, downloadWaitMs, kernelMs, wallMs}`.
 *   When WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_RELIABLE > 0, kernelMs is the sum
 *   of all query times and `getExtraProfileInfo` is attached. Otherwise kernelMs
 *   is an `{error}` object. wallMs is left null here (presumably filled in by
 *   the caller). uploadWaitMs/downloadWaitMs are reset to 0 afterwards.
 *
 * NOTE(review): if `f` throws, activeTimers and programTimersStack are not
 * restored.
 * NOTE(review): queries and names are null-filtered independently, so an
 * entry with a name but no query would misalign getExtraProfileInfo's pairing.
 */
MathBackendWebGL2.prototype.time = function(f) {
return __awaiter(this, void 0, void 0, function() {
var oldActiveTimers, newActiveTimers, outerMostTime, flattenedActiveTimerQueries, flattenedActiveTimerNames, res, kernelMs_1;
return __generator(this, function(_a) {
switch (_a.label) {
case 0:
// Install a fresh timer list for `f`, linked into any enclosing time() call.
oldActiveTimers = this.activeTimers;
newActiveTimers = [];
outerMostTime = false;
if (this.programTimersStack == null) {
this.programTimersStack = newActiveTimers;
outerMostTime = true;
} else {
this.activeTimers.push(newActiveTimers);
}
this.activeTimers = newActiveTimers;
f();
// Collect the recorded queries and program names, dropping nulls.
flattenedActiveTimerQueries = tf.util.flatten(this.activeTimers.map(function(d) {
return d.query;
})).filter(function(d) {
return d != null;
});
flattenedActiveTimerNames = tf.util.flatten(this.activeTimers.map(function(d) {
return d.name;
})).filter(function(d) {
return d != null;
});
// Restore the enclosing call's timer list.
this.activeTimers = oldActiveTimers;
if (outerMostTime) {
this.programTimersStack = null;
}
res = {
uploadWaitMs: this.uploadWaitMs,
downloadWaitMs: this.downloadWaitMs,
kernelMs: null,
wallMs: null
};
if (!(tf.env().getNumber("WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_RELIABLE") > 0))
return [3, 2];
return [4, Promise.all(flattenedActiveTimerQueries)];
case 1:
// GPU timer queries resolved: total them and expose per-program timings.
kernelMs_1 = _a.sent();
res["kernelMs"] = tf.util.sum(kernelMs_1);
res["getExtraProfileInfo"] = function() {
return kernelMs_1.map(function(d, i) {
return {name: flattenedActiveTimerNames[i], ms: d};
}).map(function(d) {
return d.name + ": " + d.ms;
}).join(", ");
};
return [3, 3];
case 2:
res["kernelMs"] = {
error: "WebGL query timers are not supported in this environment."
};
_a.label = 3;
case 3:
this.uploadWaitMs = 0;
this.downloadWaitMs = 0;
return [2, res];
}
});
});
};
/**
 * Reports GPU memory usage. numBytesInGPU is the backend's own counter; the
 * allocated and free byte counts come from the texture manager.
 */
MathBackendWebGL2.prototype.memory = function() {
  var textures = this.textureManager;
  var report = {
    unreliable: false,
    numBytesInGPU: this.numBytesInGPU,
    numBytesInGPUAllocated: textures.numBytesAllocated,
    numBytesInGPUFree: textures.numBytesFree
  };
  return report;
};
/**
 * Starts a timer. Uses a GPU query when the disjoint timer extension is
 * reliable; otherwise records a wall-clock start `{startMs, endMs: null}`.
 */
MathBackendWebGL2.prototype.startTimer = function() {
  var useGpuQuery = tf.env().getNumber("WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_RELIABLE") > 0;
  return useGpuQuery ? this.gpgpu.beginQuery() : {startMs: tf.util.now(), endMs: null};
};
/**
 * Stops a timer started by startTimer and returns the same `query` object.
 * Either ends the GPU query or stamps `endMs` with the current time.
 */
MathBackendWebGL2.prototype.endTimer = function(query) {
  if (!(tf.env().getNumber("WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_RELIABLE") > 0)) {
    query.endMs = tf.util.now();
    return query;
  }
  this.gpgpu.endQuery();
  return query;
};
/**
 * Resolves to the elapsed time of a finished timer. GPU queries are awaited
 * through the GPGPU context; wall-clock timers return `endMs - startMs`.
 */
MathBackendWebGL2.prototype.getQueryTime = function(query) {
  return __awaiter(this, void 0, void 0, function() {
    return __generator(this, function() {
      var useGpuQuery = tf.env().getNumber("WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_RELIABLE") > 0;
      var elapsed = useGpuQuery ? this.gpgpu.waitForQueryAndGetTime(query) : query.endMs - query.startMs;
      return [2, elapsed];
    });
  });
};
/**
 * Disposes the data behind `dataId`.
 *
 * - Already scheduled for deferred disposal: no-op.
 * - A read is pending: schedule disposal for when the read finishes.
 * - Still referenced as part of a complex tensor: only decrement refCount.
 * - Otherwise: release GPU resources, drop one reference from each complex
 *   part, and remove the entry.
 */
MathBackendWebGL2.prototype.disposeData = function(dataId) {
  if (this.pendingDisposal.has(dataId)) {
    return;
  }
  if (this.pendingRead.has(dataId)) {
    this.pendingDisposal.add(dataId);
    this.pendingDeletes++;
    return;
  }
  if (!this.texData.has(dataId)) {
    return;
  }
  var entry = this.texData.get(dataId);
  if (entry.complexParentRefCount > 0) {
    entry.refCount--;
    return;
  }
  this.releaseGPUData(dataId);
  var parts = entry.complexTensorInfos;
  if (parts != null) {
    var backend = this;
    [parts.real, parts.imag].forEach(function(part) {
      backend.texData.get(part.dataId).complexParentRefCount--;
      backend.disposeIntermediateTensorInfo(part);
    });
  }
  this.texData.delete(dataId);
};
/**
 * Drops this entry's hold on its texture and clears its texture fields.
 *
 * Shallow slices share a texture with the data they were sliced from. The
 * count in dataRefCount is keyed by that owner (slice.origDataId, or dataId
 * itself). The texture goes back to the texture manager only when the last
 * holder releases it.
 */
MathBackendWebGL2.prototype.releaseGPUData = function(dataId) {
  var entry = this.texData.get(dataId);
  var texture = entry.texture;
  var texShape = entry.texShape;
  var owner = (entry.slice && entry.slice.origDataId) || dataId;
  var holders = this.dataRefCount.get(owner);
  if (holders > 1) {
    this.dataRefCount.set(owner, holders - 1);
  } else {
    this.dataRefCount.delete(owner);
    if (texture != null) {
      this.numBytesInGPU -= this.computeBytes(texShape, entry.dtype);
      this.textureManager.releaseTexture(texture, texShape, entry.usage, entry.isPacked);
    }
  }
  entry.texture = null;
  entry.texShape = null;
  entry.isPacked = false;
  entry.slice = null;
};
/** Ensures `dataId` is uploaded to the GPU and returns its texture. */
MathBackendWebGL2.prototype.getTexture = function(dataId) {
  this.uploadToGPU(dataId);
  var entry = this.texData.get(dataId);
  return entry.texture;
};
/**
 * Returns the texture-data record stored for `dataId` (shape, dtype, values,
 * texture, refCount, ...). This is the live record, not a copy.
 */
MathBackendWebGL2.prototype.getDataInfo = function(dataId) {
return this.texData.get(dataId);
};
/**
 * Returns the "cpu" backend used for small-op forwarding. Returns null when
 * WEBGL_CPU_FORWARD is off. The lookup result is memoized on this.cpuBackend.
 */
MathBackendWebGL2.prototype.getCPUBackend = function() {
  var forwardingEnabled = tf.env().getBool("WEBGL_CPU_FORWARD");
  if (forwardingEnabled && this.cpuBackend == null) {
    this.cpuBackend = tf.engine().findBackend("cpu");
  }
  return forwardingEnabled ? this.cpuBackend : null;
};
/**
 * True when a CPU backend is available and every input has no texture yet and
 * fewer than `sizeThreshold` elements (default CPU_HANDOFF_SIZE_THRESHOLD).
 * Warns once if the CPU backend cannot be found.
 */
MathBackendWebGL2.prototype.shouldExecuteOnCPU = function(inputs, sizeThreshold) {
  var threshold = sizeThreshold === void 0 ? CPU_HANDOFF_SIZE_THRESHOLD : sizeThreshold;
  var cpuBackend = this.getCPUBackend();
  if (cpuBackend == null) {
    if (!this.warnedAboutCPUBackend) {
      console.warn("Your application contains ops that are small enough to be executed on the CPU backend, however the CPU backend cannot be found. Consider importing the CPU backend (@tensorflow/tfjs-backend-cpu) for better performance.");
      this.warnedAboutCPUBackend = true;
    }
    return false;
  }
  return inputs.every((input) =>
    this.texData.get(input.dataId).texture == null &&
    tf.util.sizeFromShape(input.shape) < threshold);
};
/**
 * Returns the GPGPU context this backend runs on. It is either the one passed
 * to the constructor or the one the constructor created.
 */
MathBackendWebGL2.prototype.getGPGPUContext = function() {
return this.gpgpu;
};
/**
 * Slices `x` starting at `begin` with extent `size`.
 *
 * Small CPU-resident inputs are sliced on the CPU, and empty results become
 * empty tensors. A contiguous slice of unpacked data becomes a shallow slice
 * that shares the parent texture. Everything else runs a (packed) slice
 * program.
 */
MathBackendWebGL2.prototype.slice = function(x, begin, size) {
  if (this.shouldExecuteOnCPU([x])) {
    var cpuValues = sliceImplCPU(this.texData.get(x.dataId).values, begin, size, x.shape, x.dtype);
    return this.makeOutput(size, x.dtype, cpuValues);
  }
  if (tf.util.sizeFromShape(size) === 0) {
    return tf.tensor([], size, x.dtype);
  }
  var packedInput = this.texData.get(x.dataId).isPacked;
  var contiguous = tf.slice_util.isSliceContinous(x.shape, begin, size);
  if (!packedInput && contiguous) {
    this.uploadToGPU(x.dataId);
    return this.shallowSlice(x, begin, size);
  }
  var program;
  if (tf.env().getBool("WEBGL_PACK_ARRAY_OPERATIONS")) {
    program = new SlicePackedProgram(size);
  } else {
    program = new SliceProgram(size);
  }
  return this.compileAndRun(program, [x], null, program.getCustomSetupFunc(begin));
};
/**
 * Creates an output that aliases `x`'s texture. The output records a flat
 * offset and the id of the data that owns the texture, and that owner's share
 * count is bumped in dataRefCount.
 */
MathBackendWebGL2.prototype.shallowSlice = function(x, begin, size) {
  var source = this.texData.get(x.dataId);
  var out = this.makeOutput(size, x.dtype);
  var target = this.texData.get(out.dataId);
  Object.assign(target, source);
  target.shape = size;
  target.dtype = x.dtype;
  var parentSlice = source.slice;
  var offset = tf.slice_util.computeFlatOffset(begin, x.strides);
  if (parentSlice) {
    // Slicing a slice: offsets accumulate against the original owner.
    offset += parentSlice.flatOffset;
  }
  var owner = (parentSlice && parentSlice.origDataId) || x.dataId;
  target.slice = {flatOffset: offset, origDataId: owner};
  // An owner with no entry implicitly holds one share (its own).
  var shares = this.dataRefCount.get(owner) || 1;
  this.dataRefCount.set(owner, shares + 1);
  return out;
};
/**
 * Strided slice of `x` defined by `begin`, `end` and `strides`.
 *
 * The op may first be delegated to the CPU backend through tryRunOnCpuOrThrow
 * (defined outside this chunk). A zero-size result is returned as an empty
 * tensor that keeps `x.dtype`, matching slice(). The original dropped the
 * dtype, so e.g. an empty int32 slice came back as float32.
 */
MathBackendWebGL2.prototype.stridedSlice = function(x, begin, end, strides) {
  var _this = this;
  var cpuRes = this.tryRunOnCpuOrThrow([x], function() {
    return _this.cpuBackend.stridedSlice(x, begin, end, strides);
  });
  if (cpuRes) {
    return cpuRes;
  }
  var outShape = tf.slice_util.computeOutShape(begin, end, strides);
  if (outShape.some(function(axis) {
    return axis === 0;
  })) {
    return tf.tensor([], outShape, x.dtype);
  }
  var program = new StridedSliceProgram(begin, strides, outShape);
  return this.compileAndRun(program, [x]);
};
/** Reverses `x` along `axis`, using the packed program when array-op packing is on. */
MathBackendWebGL2.prototype.reverse = function(x, axis) {
  var program;
  if (tf.env().getBool("WEBGL_PACK_ARRAY_OPERATIONS")) {
    program = new ReversePackedProgram(x.shape, axis);
  } else {
    program = new ReverseProgram(x.shape, axis);
  }
  return this.compileAndRun(program, [x]);
};
/** Element-wise negation. Tries the CPU path first, then the packed or unpacked NEG shader. */
MathBackendWebGL2.prototype.neg = function(x) {
  var cpuRes = this.tryRunOnCpuOrThrow([x], () => this.cpuBackend.neg(x));
  if (cpuRes) {
    return cpuRes;
  }
  return tf.env().getBool("WEBGL_PACK_UNARY_OPERATIONS")
    ? this.packedUnaryOp(x, NEG, x.dtype)
    : this.compileAndRun(new UnaryOpProgram(x.shape, NEG), [x]);
};
/**
 * Batched matrix multiply of rank-3 `a` and `b`, with optional transposes.
 *
 * When one operand has an outer dimension of 1 and the shared dimension
 * exceeds MATMUL_SHARED_DIM_THRESHOLD, the product is computed as an
 * element-wise multiply followed by a sum over the shared axis. Otherwise the
 * packed matmul program is used.
 */
MathBackendWebGL2.prototype.batchMatMul = function(a, b, transposeA, transposeB) {
  var aOuter = transposeA ? a.shape[2] : a.shape[1];
  var bOuter = transposeB ? b.shape[1] : b.shape[2];
  var innerDim = transposeA ? a.shape[1] : a.shape[2];
  var batch = Math.max(a.shape[0], b.shape[0]);
  var isVectorCase = aOuter === 1 || bOuter === 1;
  if (isVectorCase && innerDim > MATMUL_SHARED_DIM_THRESHOLD) {
    var lhs = transposeA ? tf.transpose(a, [0, 2, 1]) : a;
    var rhs = transposeB ? tf.transpose(b, [0, 2, 1]) : b;
    if (bOuter === 1) {
      var rhsRow = rhs.as3D(batch, 1, innerDim);
      return tf.mul(lhs, rhsRow).sum(2, true);
    }
    var lhsCol = lhs.as3D(batch, innerDim, 1);
    return tf.mul(lhsCol, rhs).sum(1, true);
  }
  var dtype = tf.upcastType(a.dtype, b.dtype);
  var program = new MatMulPackedProgram(a.shape, b.shape, [batch, aOuter, bOuter], transposeA, transposeB);
  return this.compileAndRun(program, [a, b], dtype);
};
/**
 * Batched matmul with an optional fused bias and activation, run as a single
 * packed matmul program. The inputs are ordered [a, b, bias?, preluWeights?].
 */
MathBackendWebGL2.prototype.fusedBatchMatMul = function(_a) {
  var {a, b, transposeA, transposeB, bias, activation, preluActivationWeights} = _a;
  var aOuter = transposeA ? a.shape[2] : a.shape[1];
  var bOuter = transposeB ? b.shape[1] : b.shape[2];
  var batch = Math.max(a.shape[0], b.shape[0]);
  var dtype = tf.upcastType(a.dtype, b.dtype);
  var hasBias = bias != null;
  var hasPreluWeights = preluActivationWeights != null;
  var activationSnippet = activation ? mapActivationToShaderProgram(activation, true) : null;
  var program = new MatMulPackedProgram(a.shape, b.shape, [batch, aOuter, bOuter], transposeA, transposeB, hasBias, activationSnippet, hasPreluWeights);
  var programInputs = [a, b];
  if (bias) {
    programInputs.push(bias);
  }
  if (preluActivationWeights) {
    programInputs.push(preluActivationWeights);
  }
  return this.compileAndRun(program, programInputs, dtype);
};
/** Local response normalization of `x`, using the packed program when WEBGL_PACK_NORMALIZATION is on. */
MathBackendWebGL2.prototype.localResponseNormalization4D = function(x, radius, bias, alpha, beta) {
  var program;
  if (tf.env().getBool("WEBGL_PACK_NORMALIZATION")) {
    program = new LRNPackedProgram(x.shape, radius, bias, alpha, beta);
  } else {
    program = new LRNProgram(x.shape, radius, bias, alpha, beta);
  }
  return this.compileAndRun(program, [x]);
};
/** Gradient of local response normalization. Program inputs are (inputImage, outputImage, dy). */
MathBackendWebGL2.prototype.LRNGrad = function(dy, inputImage, outputImage, depthRadius, bias, alpha, beta) {
  var gradProgram = new LRNGradProgram(inputImage.shape, depthRadius, bias, alpha, beta);
  var programInputs = [inputImage, outputImage, dy];
  return this.compileAndRun(gradProgram, programInputs);
};
/**
 * Tiles `x` by `reps`. Numeric tensors use TileProgram. String tensors are read
 * back, decoded, and tiled with the shared `tile` kernel impl on a TensorBuffer.
 */
MathBackendWebGL2.prototype.tile = function(x, reps) {
  if (x.dtype !== "string") {
    return this.compileAndRun(new TileProgram(x.shape, reps), [x]);
  }
  var decoded = this.readSync(x.dataId).map(function(bytes) {
    return tf.util.decodeString(bytes);
  });
  return tile(tf.buffer(x.shape, x.dtype, decoded), reps);
};
/** Pads `x` with `constantValue`, using the packed program when array-op packing is on. */
MathBackendWebGL2.prototype.pad = function(x, paddings, constantValue) {
  var usePacked = tf.env().getBool("WEBGL_PACK_ARRAY_OPERATIONS");
  var ProgramCtor = usePacked ? PadPackedProgram : PadProgram;
  return this.compileAndRun(new ProgramCtor(x.shape, paddings, constantValue), [x]);
};
/** Gathers slices of `x` along `axis` at `indices`. Tries the CPU path first. */
MathBackendWebGL2.prototype.gather = function(x, indices, axis) {
  var cpuRes = this.tryRunOnCpuOrThrow([x, indices], () => this.cpuBackend.gather(x, indices, axis));
  if (cpuRes) {
    return cpuRes;
  }
  var gatherProgram = new GatherProgram(x.shape, indices.size, axis);
  return this.compileAndRun(gatherProgram, [x, indices]);
};
/**
 * batchToSpaceND built from reshape -> transpose -> reshape -> slice, using the
 * backend_util shape helpers. Limited to rank <= 4.
 */
MathBackendWebGL2.prototype.batchToSpaceND = function(x, blockShape, crops) {
  tf.util.assert(x.rank <= 4, () => "batchToSpaceND for rank > 4 with a WebGL backend not implemented yet");
  var blockSize = blockShape.reduce((acc, dim) => acc * dim);
  var bu = tf.backend_util;
  var numBlockDims = blockShape.length;
  var reshaped = bu.getReshaped(x.shape, blockShape, blockSize);
  var permuted = bu.getPermuted(reshaped.length, numBlockDims);
  var reshapedPermuted = bu.getReshapedPermuted(x.shape, blockShape, blockSize);
  var sliceBegin = bu.getSliceBeginCoords(crops, numBlockDims);
  var sliceSize = bu.getSliceSize(reshapedPermuted, crops, numBlockDims);
  var transposed = tf.transpose(x.reshape(reshaped), permuted);
  return transposed.reshape(reshapedPermuted).slice(sliceBegin, sliceSize);
};
/**
 * spaceToBatchND built from pad -> reshape -> transpose -> reshape. The batch
 * dimension and any dimensions after the block dims get zero padding.
 * Limited to rank <= 4.
 */
MathBackendWebGL2.prototype.spaceToBatchND = function(x, blockShape, paddings) {
  tf.util.assert(x.rank <= 4, () => "spaceToBatchND for rank > 4 with a WebGL backend not implemented yet");
  var blockSize = blockShape.reduce((acc, dim) => acc * dim);
  var fullPaddings = [[0, 0], ...paddings];
  for (var dim = blockShape.length + 1; dim < x.shape.length; ++dim) {
    fullPaddings.push([0, 0]);
  }
  var padded = x.pad(fullPaddings);
  var bu = tf.backend_util;
  var reshapedShape = bu.getReshaped(padded.shape, blockShape, blockSize, false);
  var permutation = bu.getPermuted(reshapedShape.length, blockShape.length, false);
  var flatShape = bu.getReshapedPermuted(padded.shape, blockShape, blockSize, false);
  var transposed = tf.transpose(padded.reshape(reshapedShape), permutation);
  return tf.reshape(transposed, flatShape);
};
/**
 * Reduces each row of a 2-D `x` with ReduceProgram. Passes repeat until the
 * output has a single column, each pass shrinking a row by roughly windowSize.
 */
MathBackendWebGL2.prototype.reduce = function(x, reduceType, dtype) {
  var current = x;
  for (;;) {
    var batchSize = current.shape[0];
    var inSize = current.shape[1];
    var windowSize = tf.backend_util.computeOptimalWindowSize(inSize);
    var outSize = Math.ceil(inSize / windowSize);
    var reduceInfo = {windowSize, inSize, batchSize, outSize};
    var output = this.compileAndRun(new ReduceProgram(reduceInfo, reduceType), [current], dtype);
    if (output.shape[1] === 1) {
      return output;
    }
    current = output;
  }
};
/**
 * Row-wise arg-reduction of a 2-D `x` producing int32 indices. Each pass after
 * the first reduces the previous pass's candidate indices (`bestIndicesA`)
 * against `x`, until a single column remains.
 */
MathBackendWebGL2.prototype.argReduce = function(x, reduceType, bestIndicesA) {
  var best = bestIndicesA === void 0 ? null : bestIndicesA;
  for (;;) {
    var shapeSource = best != null ? best : x;
    var batchSize = shapeSource.shape[0];
    var inSize = shapeSource.shape[1];
    var windowSize = tf.backend_util.computeOptimalWindowSize(inSize);
    var reduceInfo = {
      windowSize: windowSize,
      inSize: inSize,
      batchSize: batchSize,
      outSize: Math.ceil(inSize / windowSize)
    };
    var program = new ArgMinMaxProgram(reduceInfo, reduceType, best == null);
    var programInputs = best == null ? [x] : [x, best];
    var output = this.compileAndRun(program, programInputs, "int32");
    if (output.shape[1] === 1) {
      return output;
    }
    best = output;
  }
};
/**
 * Packed arg-reduction over the innermost axis. Passes repeat while the
 * output keeps the same rank as `x`, feeding each pass's indices into the next.
 */
MathBackendWebGL2.prototype.argReducePacked = function(x, reduceType, bestIndicesA) {
  var best = bestIndicesA === void 0 ? null : bestIndicesA;
  for (;;) {
    var inShape = best != null ? best.shape : x.shape;
    var inSize = inShape[inShape.length - 1];
    var windowSize = tf.backend_util.computeOptimalWindowSize(inSize);
    var program = new ArgMinMaxPackedProgram(inShape, windowSize, reduceType, best == null);
    var programInputs = best == null ? [x] : [x, best];
    var output = this.compileAndRun(program, programInputs, "int32");
    if (output.rank !== x.rank) {
      return output;
    }
    best = output;
  }
};
/** Sums `x` over the innermost `axes` by flattening to 2-D and reducing each row. */
MathBackendWebGL2.prototype.sum = function(x, axes) {
  tf.backend_util.assertAxesAreInnerMostDims("sum", axes, x.rank);
  var shapes = tf.backend_util.computeOutAndReduceShapes(x.shape, axes);
  var outShape = shapes[0];
  var rows = x.as2D(-1, tf.util.sizeFromShape(shapes[1]));
  return this.reduce(rows, "sum", tf.sumOutType(x.dtype)).reshape(outShape);
};
/** Product of `x` over `axes`. Tries the CPU path first, otherwise flattens to 2-D and reduces rows. */
MathBackendWebGL2.prototype.prod = function(x, axes) {
  var cpuRes = this.tryRunOnCpuOrThrow([x], () => this.cpuBackend.prod(x, axes));
  if (cpuRes) {
    return cpuRes;
  }
  var shapes = tf.backend_util.computeOutAndReduceShapes(x.shape, axes);
  var outShape = shapes[0];
  var rows = x.as2D(-1, tf.util.sizeFromShape(shapes[1]));
  return this.reduce(rows, "prod", tf.sumOutType(x.dtype)).reshape(outShape);
};
/**
 * Segment sum along axis 0. When needed, axis 0 is first moved innermost by a
 * transpose, the segment op runs on 2-D rows, and the transpose is undone.
 */
MathBackendWebGL2.prototype.unsortedSegmentSum = function(x, segmentIds, numSegments) {
  var permutation = tf.backend_util.getAxesPermutation([0], x.rank);
  var needsTranspose = permutation != null;
  var permutedX = needsTranspose ? tf.transpose(x, permutation) : x;
  var axis = needsTranspose ? tf.backend_util.getInnerMostAxes(1, x.rank)[0] : 0;
  var outShape = segment_util.computeOutShape(permutedX.shape, axis, numSegments);
  var rows = permutedX.as2D(-1, tf.util.sizeFromShape([permutedX.shape[axis]]));
  var summed = this.segOpCompute(rows, "unsortedSegmentSum", segmentIds, tf.sumOutType(x.dtype), numSegments).reshape(outShape);
  if (!needsTranspose) {
    return summed;
  }
  return tf.transpose(summed, tf.backend_util.getUndoAxesPermutation(permutation));
};
/**
 * Runs SegmentOpProgram over a 2-D `x` until the output has `numSegments`
 * columns. Passes after the first use tiled 0..numSegments-1 ids.
 */
MathBackendWebGL2.prototype.segOpCompute = function(x, segOpType, segmentIds, dtype, numSegments) {
  var current = x;
  var ids = segmentIds;
  for (;;) {
    var batchSize = current.shape[0];
    var inSize = current.shape[1];
    var windowSize = segment_util.segOpComputeOptimalWindowSize(inSize, numSegments);
    var segOpInfo = {windowSize, inSize, batchSize, numSegments};
    var output = this.compileAndRun(new SegmentOpProgram(segOpInfo, segOpType), [current, ids], dtype);
    if (output.shape[1] === numSegments) {
      return output;
    }
    ids = tf.range(0, numSegments).tile([inSize / windowSize]);
    current = output;
  }
};
/**
 * argMin/argMax over the innermost `axis`. Uses the packed reducer when
 * WEBGL_PACK_REDUCE is on and rank > 2; otherwise flattens to 2-D rows.
 */
MathBackendWebGL2.prototype.argMinMaxReduce = function(x, axis, reduceType) {
  var axes = [axis];
  var opName = "arg" + reduceType.charAt(0).toUpperCase() + reduceType.slice(1);
  tf.backend_util.assertAxesAreInnerMostDims(opName, axes, x.rank);
  var usePacked = tf.env().getBool("WEBGL_PACK_REDUCE") && x.rank > 2;
  if (usePacked) {
    return this.argReducePacked(x, reduceType);
  }
  var shapes = tf.backend_util.computeOutAndReduceShapes(x.shape, axes);
  var rows = x.as2D(-1, tf.util.sizeFromShape(shapes[1]));
  return this.argReduce(rows, reduceType).reshape(shapes[0]);
};
MathBackendWebGL2.prototype.argMin = function(x, axis) {
return this.argMinMaxReduce(x, axis, "min");
};
MathBackendWebGL2.prototype.argMax = function(x, axis) {
return this.argMinMaxReduce(x, axis, "max");
};
/**
 * Cumulative sum along the innermost axis, computed as ceil(log2(size))
 * CumSumProgram passes (parallel scan). An exclusive scan adds one more
 * shifting pass.
 *
 * Each pass disposes the intermediate result it consumed. The caller's input
 * `x` is never disposed; the original code disposed `x` on the first pass.
 *
 * NOTE(review): when no pass runs (size <= 1) and `exclusive` is false, `x`
 * itself is returned.
 *
 * @throws {Error} If `axis` is not the innermost axis.
 */
MathBackendWebGL2.prototype.cumsum = function(x, axis, exclusive, reverse) {
  if (axis !== x.rank - 1) {
    throw new Error("WebGL cumsum shader expects an inner-most axis=" + (x.rank - 1) + " " + ("but got axis=" + axis));
  }
  var size = x.shape[axis];
  var numPasses = Math.ceil(Math.log2(size));
  var result = x;
  for (var i = 0; i < numPasses; i++) {
    var program = new CumSumProgram(x.shape, false, reverse);
    var customSetup = program.getCustomSetupFunc(i);
    var prevResult = result;
    result = this.compileAndRun(program, [result], result.dtype, customSetup);
    // Only intermediates owned by this method may be freed.
    if (prevResult !== x) {
      prevResult.dispose();
    }
  }
  if (exclusive) {
    var shiftProgram = new CumSumProgram(x.shape, exclusive, reverse);
    var beforeShift = result;
    result = this.compileAndRun(shiftProgram, [result], result.dtype);
    if (beforeShift !== x) {
      beforeShift.dispose();
    }
  }
  return result;
};
/** Element-wise a == b, producing a bool tensor. */
MathBackendWebGL2.prototype.equal = function(a, b) {
  if (!tf.env().getBool("WEBGL_PACK_BINARY_OPERATIONS")) {
    return this.compileAndRun(new BinaryOpProgram(EQUAL, a.shape, b.shape), [a, b], "bool");
  }
  return this.packedBinaryOp(a, b, EQUAL$1, "bool");
};
/** Element-wise a < b (bool). Tries the CPU path first. */
MathBackendWebGL2.prototype.less = function(a, b) {
  var cpuRes = this.tryRunOnCpuOrThrow([a, b], () => this.cpuBackend.less(a, b));
  if (cpuRes) {
    return cpuRes;
  }
  if (!tf.env().getBool("WEBGL_PACK_BINARY_OPERATIONS")) {
    return this.compileAndRun(new BinaryOpProgram(LESS, a.shape, b.shape), [a, b], "bool");
  }
  return this.packedBinaryOp(a, b, LESS$1, "bool");
};
/** Element-wise a <= b, producing a bool tensor. */
MathBackendWebGL2.prototype.lessEqual = function(a, b) {
  if (!tf.env().getBool("WEBGL_PACK_BINARY_OPERATIONS")) {
    return this.compileAndRun(new BinaryOpProgram(LESS_EQUAL, a.shape, b.shape), [a, b], "bool");
  }
  return this.packedBinaryOp(a, b, LESS_EQUAL$1, "bool");
};
/** Element-wise a > b (bool). Tries the CPU path first. */
MathBackendWebGL2.prototype.greater = function(a, b) {
  var cpuRes = this.tryRunOnCpuOrThrow([a, b], () => this.cpuBackend.greater(a, b));
  if (cpuRes) {
    return cpuRes;
  }
  if (!tf.env().getBool("WEBGL_PACK_BINARY_OPERATIONS")) {
    return this.compileAndRun(new BinaryOpProgram(GREATER, a.shape, b.shape), [a, b], "bool");
  }
  return this.packedBinaryOp(a, b, GREATER$1, "bool");
};
/** Element-wise a >= b, producing a bool tensor. */
MathBackendWebGL2.prototype.greaterEqual = function(a, b) {
  if (!tf.env().getBool("WEBGL_PACK_BINARY_OPERATIONS")) {
    return this.compileAndRun(new BinaryOpProgram(GREATER_EQUAL, a.shape, b.shape), [a, b], "bool");
  }
  return this.packedBinaryOp(a, b, GREATER_EQUAL$1, "bool");
};
MathBackendWebGL2.prototype.logicalNot = function(x) {
var program = new UnaryOpProgram(x.shape, LOGICAL_NOT);
return this.compileAndRun(program, [x]);
};
MathBackendWebGL2.prototype.logicalAnd = function(a, b) {
if (tf.env().getBool("WEBGL_PACK_BINARY_OPERATIONS")) {
return this.packedBinaryOp(a, b, LOGICAL_AND$1, "bool");
}
var program = new BinaryOpProgram(LOGICAL_AND, a.shape, b.shape);
return this.compileAndRun(program, [a, b], "bool");
};
MathBackendWebGL2.prototype.logicalOr = function(a, b) {
if (tf.env().getBool("WEBGL_PACK_BINARY_OPERATIONS")) {
return this.packedBinaryOp(a, b, LOGICAL_OR$1, "bool");
}
var program = new BinaryOpProgram(LOGICAL_OR, a.shape, b.shape);
return this.compileAndRun(program, [a, b], "bool");
};
// Picks elements from `a` or `b` according to `condition`; the output dtype
// is the upcast of the two branch dtypes.
MathBackendWebGL2.prototype.select = function(condition, a, b) {
  const selectProgram = new SelectProgram(condition.rank, a.shape, a.rank);
  const outDtype = tf.upcastType(a.dtype, b.dtype);
  return this.compileAndRun(selectProgram, [condition, a, b], outDtype);
};
// Synchronous where(): downloads the condition values (dataSync) and runs
// whereImpl on them, warning that this blocks.
MathBackendWebGL2.prototype.where = function(condition) {
  tf.backend_util.warn("tf.where() in webgl locks the UI thread. Call tf.whereAsync() instead");
  const values = condition.dataSync();
  return whereImpl(condition.shape, values);
};
// top-k on values read back synchronously via dataSync, delegated to topkImpl.
MathBackendWebGL2.prototype.topk = function(x, k, sorted) {
  const values = x.dataSync();
  return topkImpl(values, x.shape, x.dtype, k, sorted);
};
// Reduces `x` with "min" over `axes`, which must be the innermost dims:
// x is viewed as 2-D [-1, reducedSize], reduced, then reshaped to outShape.
MathBackendWebGL2.prototype.min = function(x, axes) {
  tf.backend_util.assertAxesAreInnerMostDims("min", axes, x.rank);
  const [outShape, reduceShape] = tf.backend_util.computeOutAndReduceShapes(x.shape, axes);
  const x2D = x.as2D(-1, tf.util.sizeFromShape(reduceShape));
  return this.reduce(x2D, "min", x2D.dtype).reshape(outShape);
};
// Element-wise minimum; a truthy tryRunOnCpuOrThrow result is returned directly.
MathBackendWebGL2.prototype.minimum = function(a, b) {
  const cpuResult = this.tryRunOnCpuOrThrow([a, b], () => this.cpuBackend.minimum(a, b));
  if (cpuResult) {
    return cpuResult;
  }
  let program;
  if (tf.env().getBool("WEBGL_PACK_BINARY_OPERATIONS")) {
    program = new BinaryOpPackedProgram(MIN$1, a.shape, b.shape);
  } else {
    program = new BinaryOpProgram(MIN, a.shape, b.shape);
  }
  return this.compileAndRun(program, [a, b]);
};
// Element-wise modulo, packed or unpacked per WEBGL_PACK_BINARY_OPERATIONS.
MathBackendWebGL2.prototype.mod = function(a, b) {
  let program;
  if (tf.env().getBool("WEBGL_PACK_BINARY_OPERATIONS")) {
    program = new BinaryOpPackedProgram(MOD$1, a.shape, b.shape);
  } else {
    program = new BinaryOpProgram(MOD, a.shape, b.shape);
  }
  return this.compileAndRun(program, [a, b]);
};
// Element-wise maximum; a truthy tryRunOnCpuOrThrow result is returned directly.
MathBackendWebGL2.prototype.maximum = function(a, b) {
  const cpuResult = this.tryRunOnCpuOrThrow([a, b], () => this.cpuBackend.maximum(a, b));
  if (cpuResult) {
    return cpuResult;
  }
  let program;
  if (tf.env().getBool("WEBGL_PACK_BINARY_OPERATIONS")) {
    program = new BinaryOpPackedProgram(MAX$1, a.shape, b.shape);
  } else {
    program = new BinaryOpProgram(MAX, a.shape, b.shape);
  }
  return this.compileAndRun(program, [a, b]);
};
// Logical-AND reduction over the innermost `axes` (same 2-D scheme as min).
MathBackendWebGL2.prototype.all = function(x, axes) {
  tf.backend_util.assertAxesAreInnerMostDims("all", axes, x.rank);
  const [outShape, reduceShape] = tf.backend_util.computeOutAndReduceShapes(x.shape, axes);
  const x2D = x.as2D(-1, tf.util.sizeFromShape(reduceShape));
  return this.reduce(x2D, "all", x2D.dtype).reshape(outShape);
};
// Logical-OR reduction over the innermost `axes` (same 2-D scheme as min).
MathBackendWebGL2.prototype.any = function(x, axes) {
  tf.backend_util.assertAxesAreInnerMostDims("any", axes, x.rank);
  const [outShape, reduceShape] = tf.backend_util.computeOutAndReduceShapes(x.shape, axes);
  const x2D = x.as2D(-1, tf.util.sizeFromShape(reduceShape));
  return this.reduce(x2D, "any", x2D.dtype).reshape(outShape);
};
// Integer (floor) division; the result is always "int32".
MathBackendWebGL2.prototype.floorDiv = function(a, b) {
  const outputDtype = "int32";
  if (!tf.env().getBool("WEBGL_PACK_BINARY_OPERATIONS")) {
    const program = new BinaryOpProgram(INT_DIV, a.shape, b.shape);
    return this.compileAndRun(program, [a, b], outputDtype);
  }
  return this.packedBinaryOp(a, b, INT_DIV$1, outputDtype);
};
// Runs the packed unary shader `op` on `x`, producing `dtype`.
MathBackendWebGL2.prototype.packedUnaryOp = function(x, op, dtype) {
  return this.compileAndRun(new UnaryOpPackedProgram(x.shape, op), [x], dtype);
};
// Runs the packed binary shader `op` on `a`/`b`, producing `dtype`.
// `checkOutOfBounds` defaults to false when omitted.
MathBackendWebGL2.prototype.packedBinaryOp = function(a, b, op, dtype, checkOutOfBounds) {
  const checkOOB = checkOutOfBounds === undefined ? false : checkOutOfBounds;
  const binaryProgram = new BinaryOpPackedProgram(op, a.shape, b.shape, checkOOB);
  return this.compileAndRun(binaryProgram, [a, b], dtype);
};
// Builds a tensor-info that points at one component (real or imag) of a
// complex tensor's storage while keeping the complex tensor's shape.
MathBackendWebGL2.prototype.makeComplexComponentTensorInfo = function(complexTensor, complexPart) {
  const {dataId, dtype} = complexPart;
  return {dataId, dtype, shape: complexTensor.shape};
};
// Sums a list of tensors element-wise. A single tensor is returned as-is.
// When there are more inputs than WEBGL_MAX_TEXTURES_IN_SHADER allows, the
// list is split in half, each half summed recursively, and the two partial
// sums added.
MathBackendWebGL2.prototype.addN = function(tensors) {
  if (tensors.length === 1) {
    return tensors[0];
  }
  if (tensors.length > tf.env().get("WEBGL_MAX_TEXTURES_IN_SHADER")) {
    const mid = Math.floor(tensors.length / 2);
    const lower = this.addN(tensors.slice(0, mid));
    const upper = this.addN(tensors.slice(mid));
    return this.addN([lower, upper]);
  }
  const outDtype = tensors.map((t) => t.dtype).reduce((d1, d2) => tf.upcastType(d1, d2));
  const shapes = tensors.map((t) => t.shape);
  const firstShape = tensors[0].shape;
  const addProgram = tf.env().getBool("WEBGL_PACK")
    ? new AddNPackedProgram(firstShape, shapes)
    : new AddNProgram(firstShape, shapes);
  return this.compileAndRun(addProgram, tensors, outDtype);
};
// Element-wise power; output dtype is the upcast of both input dtypes.
MathBackendWebGL2.prototype.pow = function(a, b) {
  const powProgram = tf.env().getBool("WEBGL_PACK_BINARY_OPERATIONS")
    ? new BinaryOpPackedProgram(POW$1, a.shape, b.shape)
    : new BinaryOpProgram(POW, a.shape, b.shape);
  const outDtype = tf.upcastType(a.dtype, b.dtype);
  return this.compileAndRun(powProgram, [a, b], outDtype);
};
// Unary math kernels. ceil/floor/exp/expm1 first try a CPU implementation
// when shouldExecuteOnCPU([x]) is true (reading the cached texData values),
// then use the packed shader if WEBGL_PACK_UNARY_OPERATIONS is enabled.
MathBackendWebGL2.prototype.ceil = function(x) {
  if (this.shouldExecuteOnCPU([x])) {
    const cpuValues = ceilImplCPU(this.texData.get(x.dataId).values, x.dtype);
    return this.makeOutput(x.shape, x.dtype, cpuValues);
  }
  return tf.env().getBool("WEBGL_PACK_UNARY_OPERATIONS")
    ? this.packedUnaryOp(x, CEIL, x.dtype)
    : this.compileAndRun(new UnaryOpProgram(x.shape, CEIL), [x]);
};
MathBackendWebGL2.prototype.floor = function(x) {
  if (this.shouldExecuteOnCPU([x])) {
    const cpuValues = floorImplCPU(this.texData.get(x.dataId).values, x.dtype);
    return this.makeOutput(x.shape, x.dtype, cpuValues);
  }
  return tf.env().getBool("WEBGL_PACK_UNARY_OPERATIONS")
    ? this.packedUnaryOp(x, FLOOR, x.dtype)
    : this.compileAndRun(new UnaryOpProgram(x.shape, FLOOR), [x]);
};
MathBackendWebGL2.prototype.sign = function(x) {
  return this.compileAndRun(new UnaryOpProgram(x.shape, SIGN), [x]);
};
// The three predicates below produce "bool" tensors.
MathBackendWebGL2.prototype.isNaN = function(x) {
  return this.compileAndRun(new UnaryOpProgram(x.shape, IS_NAN), [x], "bool");
};
MathBackendWebGL2.prototype.isInf = function(x) {
  return this.compileAndRun(new UnaryOpProgram(x.shape, IS_INF), [x], "bool");
};
MathBackendWebGL2.prototype.isFinite = function(x) {
  return this.compileAndRun(new UnaryOpProgram(x.shape, IS_FINITE), [x], "bool");
};
MathBackendWebGL2.prototype.round = function(x) {
  return this.compileAndRun(new UnaryOpProgram(x.shape, ROUND), [x]);
};
MathBackendWebGL2.prototype.exp = function(x) {
  if (this.shouldExecuteOnCPU([x])) {
    const cpuValues = expImplCPU(this.texData.get(x.dataId).values, x.dtype);
    return this.makeOutput(x.shape, x.dtype, cpuValues);
  }
  return tf.env().getBool("WEBGL_PACK_UNARY_OPERATIONS")
    ? this.packedUnaryOp(x, EXP, x.dtype)
    : this.compileAndRun(new UnaryOpProgram(x.shape, EXP), [x]);
};
MathBackendWebGL2.prototype.expm1 = function(x) {
  if (this.shouldExecuteOnCPU([x])) {
    const cpuValues = expm1ImplCPU(this.texData.get(x.dataId).values, x.dtype);
    return this.makeOutput(x.shape, x.dtype, cpuValues);
  }
  return tf.env().getBool("WEBGL_PACK_UNARY_OPERATIONS")
    ? this.packedUnaryOp(x, EXPM1, x.dtype)
    : this.compileAndRun(new UnaryOpProgram(x.shape, EXPM1), [x]);
};
// Softmax of `logits` along `dim`: exp(logits - max) / sum(exp(logits - max)).
// The per-axis max is subtracted first (the usual overflow guard for exp).
MathBackendWebGL2.prototype.softmax = function(logits, dim) {
  const axes = tf.util.parseAxisParam([dim], logits.shape);
  const maxLogit = tf.max(logits, axes);
  const keepDimShape = tf.backend_util.expandShapeToKeepDim(maxLogit.shape, axes);
  const shifted = tf.sub(logits, maxLogit.reshape(keepDimShape));
  const exps = this.exp(shifted);
  const denominator = this.sum(exps, axes).reshape(keepDimShape);
  return tf.div(exps, denominator);
};
// Natural log: CPU path when shouldExecuteOnCPU([x]), else packed (LOG$1)
// or unpacked (LOG) shader per WEBGL_PACK_UNARY_OPERATIONS.
MathBackendWebGL2.prototype.log = function(x) {
  if (this.shouldExecuteOnCPU([x])) {
    const cpuValues = logImplCPU(this.texData.get(x.dataId).values, x.dtype);
    return this.makeOutput(x.shape, x.dtype, cpuValues);
  }
  return tf.env().getBool("WEBGL_PACK_UNARY_OPERATIONS")
    ? this.packedUnaryOp(x, LOG$1, x.dtype)
    : this.compileAndRun(new UnaryOpProgram(x.shape, LOG), [x]);
};
MathBackendWebGL2.prototype.log1p = function(x) {
  return this.compileAndRun(new UnaryOpProgram(x.shape, LOG1P), [x]);
};
MathBackendWebGL2.prototype.sqrt = function(x) {
  return this.compileAndRun(new UnaryOpProgram(x.shape, SQRT), [x]);
};
// Reciprocal square root; CPU path when shouldExecuteOnCPU([x]), no packed path.
MathBackendWebGL2.prototype.rsqrt = function(x) {
  if (!this.shouldExecuteOnCPU([x])) {
    return this.compileAndRun(new UnaryOpProgram(x.shape, RSQRT), [x]);
  }
  const cpuValues = rsqrtImplCPU(this.texData.get(x.dataId).values, x.dtype);
  return this.makeOutput(x.shape, x.dtype, cpuValues);
};
MathBackendWebGL2.prototype.reciprocal = function(x) {
  return this.compileAndRun(new UnaryOpProgram(x.shape, RECIPROCAL), [x]);
};
// Activation kernels. relu/relu6 switch on WEBGL_PACK, prelu/eluDer on
// WEBGL_PACK_BINARY_OPERATIONS, elu on WEBGL_PACK_UNARY_OPERATIONS.
MathBackendWebGL2.prototype.relu = function(x) {
  const reluProgram = tf.env().getBool("WEBGL_PACK")
    ? new UnaryOpPackedProgram(x.shape, RELU$1)
    : new UnaryOpProgram(x.shape, RELU);
  return this.compileAndRun(reluProgram, [x]);
};
MathBackendWebGL2.prototype.relu6 = function(x) {
  const relu6Program = tf.env().getBool("WEBGL_PACK")
    ? new UnaryOpPackedProgram(x.shape, RELU6$1)
    : new UnaryOpProgram(x.shape, RELU6);
  return this.compileAndRun(relu6Program, [x]);
};
MathBackendWebGL2.prototype.prelu = function(x, alpha) {
  let preluProgram;
  if (tf.env().getBool("WEBGL_PACK_BINARY_OPERATIONS")) {
    preluProgram = new BinaryOpPackedProgram(PRELU$1, x.shape, alpha.shape);
  } else {
    preluProgram = new BinaryOpProgram(PRELU, x.shape, alpha.shape);
  }
  return this.compileAndRun(preluProgram, [x, alpha]);
};
MathBackendWebGL2.prototype.elu = function(x) {
  if (!tf.env().getBool("WEBGL_PACK_UNARY_OPERATIONS")) {
    return this.compileAndRun(new UnaryOpProgram(x.shape, ELU), [x]);
  }
  return this.packedUnaryOp(x, ELU$1, x.dtype);
};
// Gradient of ELU, combining upstream gradient `dy` with forward output `y`.
MathBackendWebGL2.prototype.eluDer = function(dy, y) {
  let derProgram;
  if (tf.env().getBool("WEBGL_PACK_BINARY_OPERATIONS")) {
    derProgram = new BinaryOpPackedProgram(ELU_DER$1, dy.shape, y.shape);
  } else {
    derProgram = new BinaryOpProgram(ELU_DER, dy.shape, y.shape);
  }
  return this.compileAndRun(derProgram, [dy, y]);
};
MathBackendWebGL2.prototype.selu = function(x) {
  return this.compileAndRun(new UnaryOpProgram(x.shape, SELU), [x]);
};
// Clamps `x` to [min, max]; the bounds are supplied through the program's
// custom setup function rather than as inputs.
MathBackendWebGL2.prototype.clip = function(x, min, max) {
  const clipProgram = tf.env().getBool("WEBGL_PACK_CLIP")
    ? new ClipPackedProgram(x.shape)
    : new ClipProgram(x.shape);
  const setup = clipProgram.getCustomSetupFunc(min, max);
  return this.compileAndRun(clipProgram, [x], null, setup);
};
// Absolute value. The CPU shortcut is skipped for complex64 inputs
// (complex magnitudes go through complexAbs).
MathBackendWebGL2.prototype.abs = function(x) {
  if (this.shouldExecuteOnCPU([x]) && x.dtype !== "complex64") {
    const cpuValues = simpleAbsImplCPU(this.texData.get(x.dataId).values);
    return this.makeOutput(x.shape, x.dtype, cpuValues);
  }
  if (!tf.env().getBool("WEBGL_PACK_UNARY_OPERATIONS")) {
    return this.compileAndRun(new UnaryOpProgram(x.shape, ABS), [x]);
  }
  return this.packedUnaryOp(x, ABS, x.dtype);
};
// |x| for complex tensors: feeds the real and imaginary components (stored
// in texData.complexTensorInfos) as two separate inputs.
MathBackendWebGL2.prototype.complexAbs = function(x) {
  const complexInfos = this.texData.get(x.dataId).complexTensorInfos;
  const absProgram = new ComplexAbsProgram(x.shape);
  const realPart = this.makeComplexComponentTensorInfo(x, complexInfos.real);
  const imagPart = this.makeComplexComponentTensorInfo(x, complexInfos.imag);
  return this.compileAndRun(absProgram, [realPart, imagPart]);
};
MathBackendWebGL2.prototype.sigmoid = function(x) {
  return this.compileAndRun(new UnaryOpProgram(x.shape, SIGMOID), [x]);
};
MathBackendWebGL2.prototype.softplus = function(x) {
  return this.compileAndRun(new UnaryOpProgram(x.shape, SOFTPLUS), [x]);
};
// Trigonometric / hyperbolic / error-function kernels: each is a single
// unpacked UnaryOpProgram with the matching shader snippet.
MathBackendWebGL2.prototype.asin = function(x) {
  return this.compileAndRun(new UnaryOpProgram(x.shape, ASIN), [x]);
};
MathBackendWebGL2.prototype.acos = function(x) {
  return this.compileAndRun(new UnaryOpProgram(x.shape, ACOS), [x]);
};
MathBackendWebGL2.prototype.atan = function(x) {
  return this.compileAndRun(new UnaryOpProgram(x.shape, ATAN), [x]);
};
MathBackendWebGL2.prototype.sinh = function(x) {
  return this.compileAndRun(new UnaryOpProgram(x.shape, SINH), [x]);
};
MathBackendWebGL2.prototype.cosh = function(x) {
  return this.compileAndRun(new UnaryOpProgram(x.shape, COSH), [x]);
};
MathBackendWebGL2.prototype.tanh = function(x) {
  return this.compileAndRun(new UnaryOpProgram(x.shape, TANH), [x]);
};
MathBackendWebGL2.prototype.asinh = function(x) {
  return this.compileAndRun(new UnaryOpProgram(x.shape, ASINH), [x]);
};
MathBackendWebGL2.prototype.acosh = function(x) {
  return this.compileAndRun(new UnaryOpProgram(x.shape, ACOSH), [x]);
};
MathBackendWebGL2.prototype.atanh = function(x) {
  return this.compileAndRun(new UnaryOpProgram(x.shape, ATANH), [x]);
};
MathBackendWebGL2.prototype.erf = function(x) {
  return this.compileAndRun(new UnaryOpProgram(x.shape, ERF), [x]);
};
// Step function; the shader snippet is generated from `alpha` by STEP().
MathBackendWebGL2.prototype.step = function(x, alpha) {
  return this.compileAndRun(new UnaryOpProgram(x.shape, STEP(alpha)), [x]);
};
/**
 * Computes a convolution (used for 1x1 / stride-1 convs by the callers in this
 * file) as a fused batched matrix multiply of x and the reshaped filter.
 *
 * Two paths:
 *  - common path: x is reshaped with tf.reshape to [1, rows, inChannels];
 *  - in-place path (x packed, xShape[2] odd, WEBGL_LAZILY_UNPACK and
 *    WEBGL_PACK_BINARY_OPERATIONS enabled, matmul not going to be unpacked):
 *    x's texture is reinterpreted as [1, rows', inChannels] by temporarily
 *    bumping the second-to-last dim of its texData.shape, avoiding the reshape
 *    the code labels "expensive".
 *
 * The temporary texData.shape change on the in-place path is always undone,
 * even if fusedBatchMatMul or one of the assertions throws; previously a throw
 * left x's texData with the bumped shape.
 *
 * @param x input tensor (indexed as xShape[0..3])
 * @param filter filter tensor, reshaped to [1, inChannels, outChannels]
 * @param convInfo conv metadata (inChannels, outChannels, dataFormat, outShape)
 * @param bias optional bias, forwarded to fusedBatchMatMul
 * @param activation optional fused activation, forwarded to fusedBatchMatMul
 * @param preluActivationWeights optional PReLU weights, forwarded to fusedBatchMatMul
 * @returns tensor of shape convInfo.outShape
 * @throws if the in-place reshape is not free or the matmul result is not packed
 */
MathBackendWebGL2.prototype.conv2dByMatMul = function(x, filter, convInfo, bias, activation, preluActivationWeights) {
  const xShape = x.shape;
  const xTexData = this.texData.get(x.dataId);
  const sharedMatMulDim = convInfo.inChannels;
  const outerShapeX = xShape[0] * xShape[1] * xShape[2];
  const outerShapeFilter = convInfo.outChannels;
  const isChannelsLast = convInfo.dataFormat === "channelsLast";
  const transposeA = false;
  const transposeB = false;
  const batchMatMulWillBeUnpacked = (outerShapeX === 1 || outerShapeFilter === 1) && sharedMatMulDim > MATMUL_SHARED_DIM_THRESHOLD;
  const reshapeWillBeExpensive = xShape[2] % 2 !== 0 && !!xTexData.isPacked;
  if (batchMatMulWillBeUnpacked || !tf.env().getBool("WEBGL_LAZILY_UNPACK") || !tf.env().getBool("WEBGL_PACK_BINARY_OPERATIONS") || !reshapeWillBeExpensive) {
    const flatRows = isChannelsLast ? xShape[0] * xShape[1] * xShape[2] : xShape[0] * xShape[2] * xShape[3];
    const xReshapedFlat = tf.reshape(x, [1, flatRows, convInfo.inChannels]);
    const filterReshapedFlat = tf.reshape(filter, [1, convInfo.inChannels, convInfo.outChannels]);
    const result = this.fusedBatchMatMul({
      a: xReshapedFlat,
      b: filterReshapedFlat,
      transposeA,
      transposeB,
      bias,
      activation,
      preluActivationWeights
    });
    return tf.reshape(result, convInfo.outShape);
  }
  // In-place path: view x's existing packed texture with one extra row per
  // (odd-sized) xShape[2] run instead of physically reshaping it.
  const paddedRows = isChannelsLast ? xShape[0] * xShape[1] * (xShape[2] + 1) : xShape[0] * xShape[2] * (xShape[3] + 1);
  const xReshaped = {
    dataId: x.dataId,
    shape: [1, paddedRows, convInfo.inChannels],
    dtype: x.dtype
  };
  const originalXTexDataShape = xTexData.shape;
  let pointwiseConv;
  let pointwiseConvTexData;
  try {
    // Work on a copy so the original shape array itself is never mutated.
    xTexData.shape = xTexData.shape.slice();
    xTexData.shape[xTexData.shape.length - 2]++;
    tf.util.assert(isReshapeFree(xTexData.shape, xReshaped.shape), function() {
      return "packed reshape " + xTexData.shape + " to " + xReshaped.shape + " isn't free";
    });
    const filterReshaped = tf.reshape(filter, [1, convInfo.inChannels, convInfo.outChannels]);
    pointwiseConv = this.fusedBatchMatMul({
      a: xReshaped,
      b: filterReshaped,
      transposeA,
      transposeB,
      bias,
      activation,
      preluActivationWeights
    });
    pointwiseConvTexData = this.texData.get(pointwiseConv.dataId);
    tf.util.assert(pointwiseConvTexData.isPacked, function() {
      return "batchMatMul result is expected to be packed";
    });
  } finally {
    // Always undo the temporary shape change on x, success or failure.
    xTexData.shape = originalXTexDataShape;
  }
  // Relabel the packed matmul output with the convolution's output shape.
  pointwiseConvTexData.shape = convInfo.outShape;
  return tf.engine().makeTensorFromDataId(pointwiseConv.dataId, convInfo.outShape, pointwiseConv.dtype);
};
// Convolution via im2col + packed matmul, for a batch-1 input (x is squeezed
// on axis 0). The patch matrix has shape [1, filterW*filterH*inChannels,
// outH*outW] and is multiplied (transposed) by the filter reshaped to
// [1, sharedDim, -1]. Bias and PReLU weights are appended as extra inputs.
MathBackendWebGL2.prototype.conv2dWithIm2Row = function(x, filter, convInfo, bias, activation, preluActivationWeights) {
  const {filterWidth, filterHeight, inChannels, outWidth, outHeight, dataFormat} = convInfo;
  const sharedDim = filterWidth * filterHeight * inChannels;
  const numCols = outHeight * outWidth;
  const x2ColShape = [sharedDim, numCols];
  const xSqueezed = x.squeeze([0]);
  const w2Row = filter.reshape([1, sharedDim, -1]);
  const im2ColProgram = new Im2ColPackedProgram(x2ColShape, xSqueezed.shape, convInfo);
  const im2Col = this.compileAndRun(im2ColProgram, [xSqueezed]).reshape([1, x2ColShape[0], x2ColShape[1]]);
  const hasPreluActivationWeights = preluActivationWeights != null;
  const fusedActivation = activation ? mapActivationToShaderProgram(activation, true) : null;
  // transposeA = true, transposeB = false
  const matmulProgram = new MatMulPackedProgram(im2Col.shape, w2Row.shape, [1, numCols, convInfo.outChannels], true, false, bias != null, fusedActivation, hasPreluActivationWeights);
  const matmulInputs = [im2Col, w2Row];
  if (bias) {
    matmulInputs.push(bias);
  }
  if (hasPreluActivationWeights) {
    matmulInputs.push(preluActivationWeights);
  }
  const product = this.compileAndRun(matmulProgram, matmulInputs);
  const finalShape = dataFormat === "channelsLast"
    ? [1, outHeight, outWidth, convInfo.outChannels]
    : [1, convInfo.outChannels, outHeight, outWidth];
  return product.reshape(finalShape);
};
// Fused 2-D convolution (optional bias / activation / PReLU weights).
// 1x1 filters with unit stride and dilation and SAME/VALID padding go through
// conv2dByMatMul; batch-1 inputs with WEBGL_CONV_IM2COL through im2col;
// everything else uses the direct Conv2DProgram.
MathBackendWebGL2.prototype.fusedConv2d = function({input, filter, convInfo, bias, activation, preluActivationWeights}) {
  const isPointwise = convInfo.filterHeight === 1 && convInfo.filterWidth === 1 &&
    convInfo.dilationHeight === 1 && convInfo.dilationWidth === 1 &&
    convInfo.strideHeight === 1 && convInfo.strideWidth === 1 &&
    (convInfo.padInfo.type === "SAME" || convInfo.padInfo.type === "VALID");
  if (isPointwise) {
    return this.conv2dByMatMul(input, filter, convInfo, bias, activation, preluActivationWeights);
  }
  if (tf.env().getBool("WEBGL_CONV_IM2COL") && input.shape[0] === 1) {
    return this.conv2dWithIm2Row(input, filter, convInfo, bias, activation, preluActivationWeights);
  }
  const hasBias = bias != null;
  const hasPreluActivationWeights = preluActivationWeights != null;
  const fusedActivation = activation ? mapActivationToShaderProgram(activation, false) : null;
  const convProgram = new Conv2DProgram(convInfo, hasBias, fusedActivation, hasPreluActivationWeights);
  const convInputs = [input, filter];
  if (bias) {
    convInputs.push(bias);
  }
  if (preluActivationWeights) {
    convInputs.push(preluActivationWeights);
  }
  return this.compileAndRun(convProgram, convInputs);
};
// Plain 2-D convolution; same path selection as fusedConv2d without fusion.
MathBackendWebGL2.prototype.conv2d = function(x, filter, convInfo) {
  const isPointwise = convInfo.filterHeight === 1 && convInfo.filterWidth === 1 &&
    convInfo.dilationHeight === 1 && convInfo.dilationWidth === 1 &&
    convInfo.strideHeight === 1 && convInfo.strideWidth === 1 &&
    (convInfo.padInfo.type === "SAME" || convInfo.padInfo.type === "VALID");
  if (isPointwise) {
    return this.conv2dByMatMul(x, filter, convInfo);
  }
  if (tf.env().getBool("WEBGL_CONV_IM2COL") && x.shape[0] === 1) {
    return this.conv2dWithIm2Row(x, filter, convInfo);
  }
  return this.compileAndRun(new Conv2DProgram(convInfo), [x, filter]);
};
// Gradient of conv2d w.r.t. its input.
MathBackendWebGL2.prototype.conv2dDerInput = function(dy, filter, convInfo) {
  return this.compileAndRun(new Conv2DDerInputProgram(convInfo), [dy, filter]);
};
// Gradient of conv2d w.r.t. its filter.
MathBackendWebGL2.prototype.conv2dDerFilter = function(x, dy, convInfo) {
  return this.compileAndRun(new Conv2DDerFilterProgram(convInfo), [x, dy]);
};
// Fused depthwise convolution. The packed program is used when
// WEBGL_PACK_DEPTHWISECONV is on, strideWidth <= 2 and the channel
// multiplier (outChannels / inChannels) is 1.
MathBackendWebGL2.prototype.fusedDepthwiseConv2D = function({input, filter, convInfo, bias, activation, preluActivationWeights}) {
  const shouldPackDepthwiseConv = tf.env().getBool("WEBGL_PACK_DEPTHWISECONV") && convInfo.strideWidth <= 2 && convInfo.outChannels / convInfo.inChannels === 1;
  const fusedActivation = activation ? mapActivationToShaderProgram(activation, shouldPackDepthwiseConv) : null;
  const convInputs = [input, filter];
  const hasBias = bias != null;
  const hasPreluActivationWeights = preluActivationWeights != null;
  if (hasBias) {
    convInputs.push(bias);
  }
  if (hasPreluActivationWeights) {
    convInputs.push(preluActivationWeights);
  }
  const program = shouldPackDepthwiseConv
    ? new DepthwiseConvPacked2DProgram(convInfo, hasBias, fusedActivation, hasPreluActivationWeights)
    : new DepthwiseConv2DProgram(convInfo, hasBias, fusedActivation, hasPreluActivationWeights);
  return this.compileAndRun(program, convInputs);
};
// Unfused depthwise convolution; same packing criterion as above.
MathBackendWebGL2.prototype.depthwiseConv2D = function(x, filter, convInfo) {
  const usePacked = tf.env().getBool("WEBGL_PACK_DEPTHWISECONV") && convInfo.strideWidth <= 2 && convInfo.outChannels / convInfo.inChannels === 1;
  const program = usePacked ? new DepthwiseConvPacked2DProgram(convInfo) : new DepthwiseConv2DProgram(convInfo);
  return this.compileAndRun(program, [x, filter]);
};
MathBackendWebGL2.prototype.depthwiseConv2DDerInput = function(dy, filter, convInfo) {
  return this.compileAndRun(new DepthwiseConv2DDerInputProgram(convInfo), [dy, filter]);
};
MathBackendWebGL2.prototype.depthwiseConv2DDerFilter = function(x, dy, convInfo) {
  return this.compileAndRun(new DepthwiseConv2DDerFilterProgram(convInfo), [x, dy]);
};
MathBackendWebGL2.prototype.conv3d = function(x, filter, convInfo) {
  return this.compileAndRun(new Conv3DProgram(convInfo), [x, filter]);
};
MathBackendWebGL2.prototype.conv3dDerInput = function(dy, filter, convInfo) {
  return this.compileAndRun(new Conv3DDerInputProgram(convInfo), [dy, filter]);
};
MathBackendWebGL2.prototype.conv3dDerFilter = function(x, dy, convInfo) {
  return this.compileAndRun(new Conv3DDerFilterProgram(convInfo), [x, dy]);
};
// Splits `x` along `axis` into x.shape[axis] tensors, each sliced with size 1
// on that axis and reshaped to drop it.
MathBackendWebGL2.prototype.unstack = function(x, axis) {
  const count = x.shape[axis];
  const squeezedShape = new Array(x.rank - 1);
  let write = 0;
  x.shape.forEach((dim, i) => {
    if (i !== axis) {
      squeezedShape[write++] = dim;
    }
  });
  const sliceBegin = new Array(x.rank).fill(0);
  const sliceSize = x.shape.slice();
  sliceSize[axis] = 1;
  const pieces = new Array(count);
  for (let idx = 0; idx < pieces.length; idx++) {
    sliceBegin[axis] = idx;
    pieces[idx] = this.slice(x, sliceBegin, sliceSize).reshape(squeezedShape);
  }
  return pieces;
};
// 3-D pooling and image-resize kernels.
MathBackendWebGL2.prototype.avgPool3d = function(x, convInfo) {
  return this.compileAndRun(new Pool3DProgram(convInfo, "avg", false), [x], "float32");
};
MathBackendWebGL2.prototype.avgPool3dBackprop = function(dy, x, convInfo) {
  return this.compileAndRun(new AvgPool3DBackpropProgram(convInfo), [dy], x.dtype);
};
MathBackendWebGL2.prototype.maxPool3d = function(x, convInfo) {
  return this.compileAndRun(new Pool3DProgram(convInfo, "max", false), [x], "float32");
};
// Max-pool gradient: first computes argmax positions (Pool3DProgram with
// getPositions = true), routes dy through them, then disposes the positions.
MathBackendWebGL2.prototype.maxPool3dBackprop = function(dy, x, y, convInfo) {
  const positionsProgram = new Pool3DProgram(convInfo, "max", true);
  const positions = this.compileAndRun(positionsProgram, [x]);
  const backpropProgram = new MaxPool3DBackpropProgram(convInfo);
  const dx = this.compileAndRun(backpropProgram, [dy, positions], x.dtype);
  positions.dispose();
  return dx;
};
MathBackendWebGL2.prototype.resizeBilinear = function(x, newHeight, newWidth, alignCorners) {
  const resizeProgram = tf.env().getBool("WEBGL_PACK_IMAGE_OPERATIONS")
    ? new ResizeBilinearPackedProgram(x.shape, newHeight, newWidth, alignCorners)
    : new ResizeBilinearProgram(x.shape, newHeight, newWidth, alignCorners);
  return this.compileAndRun(resizeProgram, [x], "float32");
};
MathBackendWebGL2.prototype.resizeBilinearBackprop = function(dy, x, alignCorners) {
  return this.compileAndRun(new ResizeBilinearBackpropProgram(dy, x, alignCorners), [dy]);
};
MathBackendWebGL2.prototype.resizeNearestNeighbor = function(x, newHeight, newWidth, alignCorners) {
  return this.compileAndRun(new ResizeNearestNeighborProgram(x.shape, newHeight, newWidth, alignCorners), [x]);
};
MathBackendWebGL2.prototype.resizeNearestNeighborBackprop = function(dy, x, alignCorners) {
  return this.compileAndRun(new ResizeNearestNeigborBackpropProgram(dy, x, alignCorners), [dy]);
};
// Draws `numSamples` indices per row of a [batch, outcomes] probability
// matrix; unnormalized logits are softmaxed first. Output dtype is "int32".
MathBackendWebGL2.prototype.multinomial = function(logits, normalized, numSamples, seed) {
  const probs = normalized ? logits : tf.softmax(logits);
  const [batchSize, numOutcomes] = probs.shape;
  const sampler = new MultinomialProgram(batchSize, numOutcomes, numSamples);
  const setup = sampler.getCustomSetupFunc(seed);
  return this.compileAndRun(sampler, [probs], "int32", setup);
};
MathBackendWebGL2.prototype.oneHot = function(indices, depth, onValue, offValue) {
  return this.compileAndRun(new OneHotProgram(indices.size, depth, onValue, offValue), [indices]);
};
MathBackendWebGL2.prototype.diag = function(x) {
  return this.compileAndRun(new DiagProgram(x.size), [x]);
};
MathBackendWebGL2.prototype.cropAndResize = function(image, boxes, boxIndex, cropSize, method, extrapolationValue) {
  const cropProgram = new CropAndResizeProgram(image.shape, boxes.shape, cropSize, method, extrapolationValue);
  return this.compileAndRun(cropProgram, [image, boxes, boxIndex], "float32");
};
// Rearranges depth into blockSize x blockSize spatial blocks, for either
// "NHWC" or (otherwise) NCHW layout.
MathBackendWebGL2.prototype.depthToSpace = function(x, blockSize, dataFormat) {
  tf.util.assert(blockSize > 1, () => "blockSize should be > 1 for depthToSpace, but was: " + blockSize);
  const isNHWC = dataFormat === "NHWC";
  const batchSize = x.shape[0];
  const inputHeight = x.shape[isNHWC ? 1 : 2];
  const inputWidth = x.shape[isNHWC ? 2 : 3];
  const inputDepth = x.shape[isNHWC ? 3 : 1];
  const outputHeight = inputHeight * blockSize;
  const outputWidth = inputWidth * blockSize;
  const outputDepth = inputDepth / (blockSize * blockSize);
  const outputShape = isNHWC
    ? [batchSize, outputHeight, outputWidth, outputDepth]
    : [batchSize, outputDepth, outputHeight, outputWidth];
  return this.compileAndRun(new DepthToSpaceProgram(outputShape, blockSize, dataFormat), [x]);
};
// Delegates to the file-level split() helper.
MathBackendWebGL2.prototype.split = function(x, sizeSplits, axis) {
  const parts = split(x, sizeSplits, axis);
  return parts;
};
/**
 * Scatters `updates` into a new tensor of `shape` at the positions given by
 * `indices`, using ScatterProgram with a scalar 0 as its default-value input
 * (presumably the value for positions no index touches — confirm in
 * ScatterProgram).
 *
 * The empty-output check now runs before any other work. The old order
 * computed `outputSize / sliceSize` (0 / 0 = NaN when the output is empty) and
 * reshaped both inputs, only to discard the results.
 *
 * @param indices index tensor, flattened to [numUpdates, sliceRank]
 * @param updates update values, flattened to [numUpdates, sliceSize]
 * @param shape output shape
 * @returns tensor of `shape`
 */
MathBackendWebGL2.prototype.scatterND = function(indices, updates, shape) {
  const {sliceRank, numUpdates, sliceSize, strides, outputSize} = tf.backend_util.calculateShapes(updates, indices, shape);
  if (outputSize === 0) {
    return tf.backend_util.reshapeTensor(tf.tensor([]), shape);
  }
  const flattenShape = [outputSize / sliceSize, sliceSize];
  const flattenIndices = indices.reshape([numUpdates, sliceRank]);
  const flattenX = updates.reshape([numUpdates, sliceSize]);
  const defaultValue = tf.scalar(0);
  const program = new ScatterProgram(numUpdates, sliceRank, flattenIndices.rank, flattenX.rank, strides, flattenShape);
  const res = this.compileAndRun(program, [flattenX, flattenIndices, defaultValue]);
  return res.reshape(shape);
};
// Builds a dense tensor of `outputShape` from sparse indices/values via
// ScatterProgram, with sumDupeIndices disabled and the caller's defaultValue
// passed as the third input.
MathBackendWebGL2.prototype.sparseToDense = function(sparseIndices, sparseValues, outputShape, defaultValue) {
  const {sliceRank, numUpdates, strides, outputSize} = tf.backend_util.calculateShapes(sparseValues, sparseIndices, outputShape);
  const sumDupeIndices = false;
  const scatter = new ScatterProgram(numUpdates, sliceRank, sparseIndices.rank, sparseValues.rank, strides, [outputSize, 1], sumDupeIndices);
  const dense = this.compileAndRun(scatter, [sparseValues, sparseIndices, defaultValue]);
  return dense.reshape(outputShape);
};
// Gathers slices of `x` addressed by the last dimension of `indices`; shapes
// come from backend_util.prepareAndValidate.
MathBackendWebGL2.prototype.gatherND = function(x, indices) {
  const indicesShape = indices.shape;
  const sliceRank = indicesShape[indicesShape.length - 1];
  const [resultShape, numSlices, sliceSize, strides] = tf.backend_util.prepareAndValidate(x, indices);
  const flatIndices = indices.reshape([numSlices, sliceRank]);
  const flatX = x.reshape([x.size / sliceSize, sliceSize]);
  const gather = new GatherNDProgram(sliceRank, strides, [numSlices, sliceSize]);
  const gathered = this.compileAndRun(gather, [flatX, flatIndices]);
  return gathered.reshape(resultShape);
};
// Creates a tensor of `shape` filled with `value`. dtype falls back to
// tf.util.inferDtype(value). String tensors are built on the host via
// engine.makeTensor; others run a FillProgram with no inputs.
MathBackendWebGL2.prototype.fill = function(shape, value, dtype) {
  const outDtype = dtype || tf.util.inferDtype(value);
  if (outDtype !== "string") {
    const fillProgram = new FillProgram(shape, value);
    const setup = fillProgram.getCustomSetupFunc(value);
    return this.compileAndRun(fillProgram, [], outDtype, setup);
  }
  const hostValues = tf.util.getArrayFromDType(outDtype, tf.util.sizeFromShape(shape));
  hostValues.fill(value);
  return tf.engine().makeTensor(hostValues, shape, outDtype, this);
};
MathBackendWebGL2.prototype.onesLike = function(x) {
  if (x.dtype !== "string") {
    return this.fill(x.shape, 1, x.dtype);
  }
  throw new Error("onesLike is not supported under string dtype");
};
// Zeros (or empty strings for string tensors) shaped like `x`.
MathBackendWebGL2.prototype.zerosLike = function(x) {
  const zero = x.dtype === "string" ? "" : 0;
  return this.fill(x.shape, zero, x.dtype);
};
MathBackendWebGL2.prototype.linspace = function(start, stop, num) {
  const sequence = tf.backend_util.linspaceImpl(start, stop, num);
  return sequence;
};
// Writes `values` to a new data id, clears its texData `usage`, and returns
// a plain {dataId, shape, dtype} tensor-info.
MathBackendWebGL2.prototype.makeTensorInfo = function(shape, dtype, values) {
  const dataId = this.write(values, shape, dtype);
  const entry = this.texData.get(dataId);
  entry.usage = null;
  return {dataId, shape, dtype};
};
// Like makeTensorInfo but wraps the result in an engine tensor.
MathBackendWebGL2.prototype.makeOutput = function(shape, dtype, values) {
  const {dataId} = this.makeTensorInfo(shape, dtype, values);
  return tf.engine().makeTensorFromDataId(dataId, shape, dtype, this);
};
// Converts a packed tensor to the unpacked layout.
MathBackendWebGL2.prototype.unpackTensor = function(input) {
  return this.runWebGLProgram(new UnpackProgram(input.shape), [input], input.dtype);
};
// Converts to the packed layout; the trailing `true` is
// preventEagerUnpackingOfOutput, keeping the result packed.
MathBackendWebGL2.prototype.packTensor = function(input) {
  return this.runWebGLProgram(new PackProgram(input.shape), [input], input.dtype, null, true);
};
// Reshapes a packed tensor on the GPU by viewing both the source and target
// shapes as 3-D [batch, rows, cols] (via getBatchDim / getRowsCols).
MathBackendWebGL2.prototype.packedReshape = function(input, afterShape) {
  const to3D = (s) => [getBatchDim(s)].concat(getRowsCols(s));
  const input3DShape = to3D(input.shape);
  const input3D = {dtype: input.dtype, shape: input3DShape, dataId: input.dataId};
  const reshapeProgram = new ReshapePackedProgram(to3D(afterShape), input3DShape);
  const output = this.runWebGLProgram(reshapeProgram, [input3D], input.dtype, null, true);
  return {dataId: output.dataId, shape: afterShape, dtype: output.dtype};
};
// Runs a DecodeMatrix program (packed or not, following texData.isPacked)
// over the data viewed as 3-D, and returns a tensor-info with the original
// shape pointing at the decoded output.
MathBackendWebGL2.prototype.decode = function(dataId) {
  const {isPacked, shape, dtype} = this.texData.get(dataId);
  const shapeAs3D = getShapeAs3D(shape);
  const decoder = isPacked ? new DecodeMatrixPackedProgram(shapeAs3D) : new DecodeMatrixProgram(shapeAs3D);
  const out = this.runWebGLProgram(decoder, [{shape: shapeAs3D, dtype, dataId}], dtype, null, true);
  return {dtype, shape, dataId: out.dataId};
};
/**
 * Runs a GPGPU program on this backend and returns the output TensorInfo.
 *
 * Visible steps:
 * 1. Allocate the output via makeTensorInfo and configure its texData
 *    (packing, dense texture shape, usage).
 * 2. Short-circuit zero-size outputs.
 * 3. Bring each input into the layout the program expects.
 * 4. Upload inputs and output, then compile the program or reuse the binary
 *    cached under makeShaderKey.
 * 5. Run the program, dispose temporaries and record timing.
 * 6. Optionally unpack the output.
 *
 * @param program - Program object. This method reads outputShape,
 *   packedOutput, outPackingScheme, outTexUsage, packedInputs and
 *   constructor.name (used as the timer entry name).
 * @param inputs - TensorInfo-like objects ({dataId, shape, dtype}).
 * @param outputDtype - dtype of the newly allocated output.
 * @param customSetup - Forwarded unchanged to runProgram.
 * @param preventEagerUnpackingOfOutput - When true, a packed output is
 *   returned packed even if WEBGL_LAZILY_UNPACK is off. Defaults to false.
 * @returns TensorInfo of the output, or of an unpacked copy of it.
 * @throws {Error} If any input has dtype "complex64".
 */
MathBackendWebGL2.prototype.runWebGLProgram = function(program, inputs, outputDtype, customSetup, preventEagerUnpackingOfOutput) {
var _this = this;
if (preventEagerUnpackingOfOutput === void 0) {
preventEagerUnpackingOfOutput = false;
}
var output = this.makeTensorInfo(program.outputShape, outputDtype);
var outData = this.texData.get(output.dataId);
if (program.packedOutput) {
outData.isPacked = true;
}
// DENSE packing: the texture shape is getDenseTexShape's result doubled in each dimension.
if (program.outPackingScheme === PackingScheme.DENSE) {
var texelShape = getDenseTexShape(program.outputShape);
outData.texShape = texelShape.map(function(d) {
return d * 2;
});
}
if (program.outTexUsage != null) {
outData.usage = program.outTexUsage;
}
// Zero-size output: attach an empty typed array and skip the GPU entirely.
if (tf.util.sizeFromShape(output.shape) === 0) {
outData.values = tf.util.getTypedArrayFromDType(output.dtype, 0);
return output;
}
// Layout-conversion temporaries created below; disposed after the program runs.
var dataToDispose = [];
var inputsData = inputs.map(function(input) {
if (input.dtype === "complex64") {
throw new Error("GPGPUProgram does not support complex64 input. For complex64 dtypes, please separate the program into real and imaginary parts.");
}
var texData = _this.texData.get(input.dataId);
if (texData.texture == null) {
// Small, not-yet-uploaded inputs to an unpacked program are passed as
// uniform values (size <= WEBGL_SIZE_UPLOAD_UNIFORM) instead of textures.
if (!program.packedInputs && tf.util.sizeFromShape(input.shape) <= tf.env().getNumber("WEBGL_SIZE_UPLOAD_UNIFORM")) {
return {
shape: input.shape,
texData: null,
isUniform: true,
uniformValues: texData.values
};
}
// Not uploaded yet: flag it packed so uploadToGPU below encodes it packed.
if (program.packedInputs) {
texData.isPacked = true;
texData.shape = input.shape;
}
} else if (!!texData.isPacked !== !!program.packedInputs) {
// Texture layout differs from what the program expects: convert via a temporary.
input = texData.isPacked ? _this.unpackTensor(input) : _this.packTensor(input);
dataToDispose.push(input);
texData = _this.texData.get(input.dataId);
} else if (texData.isPacked && !isReshapeFree(texData.shape, input.shape)) {
// Packed texture stored for a different shape: temporarily present the
// texture's shape, reshape to the requested shape, then restore the caller's
// shape field. NOTE(review): if packedReshape throws, the caller's input.shape
// is left set to texData.shape.
var savedInput = input;
var targetShape = input.shape;
input.shape = texData.shape;
input = _this.packedReshape(input, targetShape);
dataToDispose.push(input);
texData = _this.texData.get(input.dataId);
savedInput.shape = targetShape;
}
_this.uploadToGPU(input.dataId);
return {shape: input.shape, texData, isUniform: false};
});
this.uploadToGPU(output.dataId);
var outputData = {shape: output.shape, texData: outData, isUniform: false};
// Compiled binaries are cached per shader key, so compileProgram runs once per key.
var key = makeShaderKey(program, inputsData, outputData);
var binary = this.getAndSaveBinary(key, function() {
return compileProgram(_this.gpgpu, program, inputsData, outputData);
});
var shouldTimeProgram = this.activeTimers != null;
var query;
if (shouldTimeProgram) {
query = this.startTimer();
}
runProgram(this.gpgpu, binary, inputsData, outputData, customSetup);
// NOTE(review): if compilation or runProgram throws, these temporaries are not disposed.
dataToDispose.forEach(function(info) {
return _this.disposeIntermediateTensorInfo(info);
});
if (shouldTimeProgram) {
query = this.endTimer(query);
this.activeTimers.push({name: program.constructor.name, query: this.getQueryTime(query)});
}
// Eager unpacking: unless lazily unpacking or explicitly prevented, a packed
// output is replaced by an unpacked copy and the packed original is disposed.
if (!tf.env().getBool("WEBGL_LAZILY_UNPACK") && outData.isPacked && preventEagerUnpackingOfOutput === false) {
var unpacked = this.unpackTensor(output);
this.disposeIntermediateTensorInfo(output);
return unpacked;
}
return output;
};
/**
 * Runs `program` through runWebGLProgram and wraps the result in an engine
 * Tensor.
 * @param program - GPGPU program.
 * @param inputs - Input TensorInfos. The first input's dtype is the fallback
 *   output dtype.
 * @param outputDtype - Optional output dtype.
 * @param customSetup - Forwarded to runWebGLProgram.
 * @param preventEagerUnpackingOfOutput - Defaults to false.
 * @returns Tensor built by `tf.engine().makeTensorFromDataId`.
 */
MathBackendWebGL2.prototype.compileAndRun = function(program, inputs, outputDtype, customSetup, preventEagerUnpackingOfOutput) {
  const keepPacked = preventEagerUnpackingOfOutput === void 0 ? false : preventEagerUnpackingOfOutput;
  const resolvedDtype = outputDtype || inputs[0].dtype;
  const result = this.runWebGLProgram(program, inputs, resolvedDtype, customSetup, keepPacked);
  return tf.engine().makeTensorFromDataId(result.dataId, result.shape, result.dtype);
};
/**
 * Memoizes compiled binaries in `this.binaryCache`. `getBinary` is invoked
 * only when `key` is not yet present (checked with the `in` operator).
 * @param {string} key - Cache key.
 * @param {Function} getBinary - Producer called on a cache miss.
 * @returns The cached binary for `key`.
 */
MathBackendWebGL2.prototype.getAndSaveBinary = function(key, getBinary) {
  if (key in this.binaryCache) {
    return this.binaryCache[key];
  }
  this.binaryCache[key] = getBinary();
  return this.binaryCache[key];
};
/** @returns The texture manager stored on this backend. */
MathBackendWebGL2.prototype.getTextureManager = function() {
  const { textureManager } = this;
  return textureManager;
};
/**
 * Releases this backend's GPU resources. Calling it a second time does nothing.
 *
 * Cached WebGL programs are deleted only when IS_TEST is false. The texture
 * manager is always disposed. A DOM canvas is removed from the page; any
 * other canvas reference is just dropped. The gpgpu context is disposed
 * only if this backend created it (gpgpuCreatedLocally).
 */
MathBackendWebGL2.prototype.dispose = function() {
  if (this.disposed) {
    return;
  }
  if (!tf.env().getBool("IS_TEST")) {
    for (const cacheKey of Object.keys(this.binaryCache)) {
      this.gpgpu.deleteProgram(this.binaryCache[cacheKey].webGLProgram);
      delete this.binaryCache[cacheKey];
    }
  }
  this.textureManager.dispose();
  const isDomCanvas = this.canvas != null && (typeof HTMLCanvasElement !== "undefined" && this.canvas instanceof HTMLCanvasElement);
  if (isDomCanvas) {
    this.canvas.remove();
  } else {
    this.canvas = null;
  }
  if (this.gpgpuCreatedLocally) {
    this.gpgpu.program = null;
    this.gpgpu.dispose();
  }
  this.disposed = true;
};
/**
 * Returns this backend's float precision (32 or 16). The value is computed
 * once and cached in `floatPrecisionValue`.
 *
 * When WEBGL_RENDER_FLOAT32_ENABLED is falsy, the GPU is probed by computing
 * abs(1e-8). A result > 0 reports 32. Every other path reports 16.
 * NOTE(review): reporting 16 whenever WEBGL_RENDER_FLOAT32_ENABLED is truthy
 * is the original behavior, kept as-is, but looks counter-intuitive — confirm.
 *
 * DEBUG is switched off only for the duration of the probe. It is now
 * restored in a `finally`, so a failing probe cannot leave DEBUG disabled for
 * the rest of the session.
 *
 * @returns {16|32}
 */
MathBackendWebGL2.prototype.floatPrecision = function() {
  if (this.floatPrecisionValue == null) {
    this.floatPrecisionValue = tf.tidy(() => {
      if (!tf.env().get("WEBGL_RENDER_FLOAT32_ENABLED")) {
        // Presumably disabled so debug checks do not reject the tiny probe value.
        const debugFlag = tf.env().getBool("DEBUG");
        tf.env().set("DEBUG", false);
        let underflowCheckValue;
        try {
          underflowCheckValue = this.abs(tf.scalar(1e-8)).dataSync()[0];
        } finally {
          tf.env().set("DEBUG", debugFlag);
        }
        if (underflowCheckValue > 0) {
          return 32;
        }
      }
      return 16;
    });
  }
  return this.floatPrecisionValue;
};
/**
 * Returns EPSILON_FLOAT32 when floatPrecision() is 32, otherwise
 * EPSILON_FLOAT16.
 */
MathBackendWebGL2.prototype.epsilon = function() {
  const precision = this.floatPrecision();
  if (precision === 32) {
    return EPSILON_FLOAT32;
  }
  return EPSILON_FLOAT16;
};
/**
 * Ensures the data behind `dataId` has a GPU texture.
 *
 * - If texData.texture is already set, nothing happens.
 * - A missing texShape is derived from the logical shape (and isPacked).
 * - If CPU `values` exist:
 *   1. They are written into a temporary dense texture (usage PIXELS for
 *      Uint8Array values, UPLOAD otherwise).
 *   2. The temporary is encoded into the final layout by
 *      EncodeMatrixPackedProgram / EncodeMatrixProgram.
 *   3. This entry adopts the encoded texture, texShape, isPacked and usage,
 *      and its CPU values are cleared.
 * - If there are no values, an empty texture is acquired via acquireTexture.
 *
 * When timers are active, time spent on the values path is added to
 * this.uploadWaitMs.
 *
 * @param dataId - Key into this.texData.
 */
MathBackendWebGL2.prototype.uploadToGPU = function(dataId) {
var _a;
var texData = this.texData.get(dataId);
var shape = texData.shape, dtype = texData.dtype, values = texData.values, texture = texData.texture, usage = texData.usage, isPacked = texData.isPacked;
if (texture != null) {
return;
}
var shouldTimeProgram = this.activeTimers != null;
var start;
if (shouldTimeProgram) {
start = tf.util.now();
}
var texShape = texData.texShape;
if (texShape == null) {
texShape = getTextureShapeFromLogicalShape(shape, isPacked);
texData.texShape = texShape;
}
if (values != null) {
var shapeAs3D = getShapeAs3D(shape);
var program = void 0;
// texShape is ordered [height, width].
var width = texShape[1], height = texShape[0];
var isByteArray = values instanceof Uint8Array;
if (isPacked) {
// Packed layout: the dense upload texture uses the packed width/height helper.
_a = getPackedMatrixTextureShapeWidthHeight(texShape[0], texShape[1]), width = _a[0], height = _a[1];
program = new EncodeMatrixPackedProgram(shapeAs3D, [height, width], isByteArray);
} else {
program = new EncodeMatrixProgram(shapeAs3D, [height, width], isByteArray);
}
var tempDenseInputHandle = this.makeTensorInfo([height, width], dtype);
if (isByteArray) {
this.texData.get(tempDenseInputHandle.dataId).usage = TextureUsage.PIXELS;
} else {
this.texData.get(tempDenseInputHandle.dataId).usage = TextureUsage.UPLOAD;
}
this.gpgpu.uploadDenseMatrixToTexture(this.getTexture(tempDenseInputHandle.dataId), width, height, values);
var preventEagerUnpacking = true;
var encodedOutputTarget = this.runWebGLProgram(program, [tempDenseInputHandle], dtype, null, preventEagerUnpacking);
var outputTexData = this.texData.get(encodedOutputTarget.dataId);
texData.texture = outputTexData.texture;
texData.texShape = outputTexData.texShape;
texData.isPacked = outputTexData.isPacked;
texData.usage = outputTexData.usage;
this.disposeIntermediateTensorInfo(tempDenseInputHandle);
// The encoded texture now belongs to this entry, so only the program output's
// bookkeeping entry is deleted (its texture is not disposed).
this.texData.delete(encodedOutputTarget.dataId);
texData.values = null;
if (shouldTimeProgram) {
this.uploadWaitMs += tf.util.now() - start;
}
} else {
var newTexture = this.acquireTexture(texShape, usage, dtype, isPacked);
texData.texture = newTexture;
}
};
/**
 * Releases the GPU data for `dataId`. If float32 values are supplied, they
 * are converted to the entry's dtype and cached as its CPU values.
 * @param dataId - Key into this.texData.
 * @param float32Values - Optional values read back from the GPU.
 * @returns The entry's (possibly updated) CPU values.
 */
MathBackendWebGL2.prototype.convertAndCacheOnCPU = function(dataId, float32Values) {
  const entry = this.texData.get(dataId);
  const targetDtype = entry.dtype;
  this.releaseGPUData(dataId);
  if (float32Values != null) {
    entry.values = float32ToTypedArray(float32Values, targetDtype);
  }
  return entry.values;
};
/**
 * Accounts for the texture's byte size in numBytesInGPU, warns once when the
 * total exceeds numMBBeforeWarning megabytes, then delegates allocation to
 * the texture manager.
 * @returns The texture returned by textureManager.acquireTexture.
 */
MathBackendWebGL2.prototype.acquireTexture = function(texShape, texType, dtype, isPacked) {
  this.numBytesInGPU += this.computeBytes(texShape, dtype);
  const warnThresholdBytes = this.numMBBeforeWarning * 1024 * 1024;
  const shouldWarn = !this.warnedAboutMemory && this.numBytesInGPU > warnThresholdBytes;
  if (shouldWarn) {
    const mb = (this.numBytesInGPU / 1024 / 1024).toFixed(2);
    this.warnedAboutMemory = true;
    console.warn("High memory usage in GPU: " + mb + " MB, most likely due to a memory leak");
  }
  return this.textureManager.acquireTexture(texShape, texType, isPacked);
};
/**
 * Byte size of a 2D texture: shape[0] * shape[1] * bytesPerElement(dtype).
 */
MathBackendWebGL2.prototype.computeBytes = function(shape, dtype) {
  const rows = shape[0];
  const cols = shape[1];
  return rows * cols * tf.util.bytesPerElement(dtype);
};
/**
 * Runs `fn` when `shouldExecuteOnCPU(inputs)` is true.
 *
 * An error thrown by `fn` is swallowed and null is returned (presumably so
 * the caller falls back to the GPU path). The exception is when IS_TEST is
 * set: then a "CPU forwarding failed" error is thrown, and the original error
 * is now attached as its `cause` instead of being discarded.
 *
 * @param inputs - Inputs passed to shouldExecuteOnCPU.
 * @param {Function} fn - CPU implementation to attempt.
 * @returns The result of `fn`, or null if the CPU path was skipped or failed.
 * @throws {Error} "CPU forwarding failed" (with `cause`) under IS_TEST.
 */
MathBackendWebGL2.prototype.tryRunOnCpuOrThrow = function(inputs, fn) {
  if (this.shouldExecuteOnCPU(inputs)) {
    try {
      return fn();
    } catch (e) {
      if (tf.env().getBool("IS_TEST")) {
        // Keep the underlying failure reachable for debugging.
        throw new Error("CPU forwarding failed", {cause: e});
      }
      // Deliberate best-effort outside tests: fall through to `return null`.
    }
  }
  return null;
};
return MathBackendWebGL2;
}(tf.KernelBackend);
/**
 * Converts float32 values to the CPU representation of `dtype`.
 *
 * "float32" and "complex64" return `a` itself (no copy). "int32" and "bool"
 * allocate an Int32Array / Uint8Array of the same length holding
 * Math.round of each element.
 *
 * @param a - Source values (indexable, with `length`).
 * @param {string} dtype - Target dtype.
 * @returns The converted values.
 * @throws {Error} "Unknown dtype <dtype>" for any other dtype.
 */
function float32ToTypedArray(a, dtype) {
  switch (dtype) {
    case "float32":
    case "complex64":
      return a;
    case "int32":
    case "bool": {
      const rounded = dtype === "int32" ? new Int32Array(a.length) : new Uint8Array(a.length);
      rounded.forEach((_, idx) => {
        rounded[idx] = Math.round(a[idx]);
      });
      return rounded;
    }
    default:
      throw new Error("Unknown dtype " + dtype);
  }
}
/** @license See the LICENSE file. */
// Hard-coded version string for this section of the bundle.
// NOTE(review): presumably the bundled tfjs-backend-webgl version — confirm against the package manifest.
var version = "2.7.0";
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Sets the WEBGL_FORCE_F16_TEXTURES environment flag to true.
 */
function forceHalfFloat() {
  const environment = tf.env();
  environment.set("WEBGL_FORCE_F16_TEXTURES", true);
}
/**
* @license
* Copyright 2020 Google Inc. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
// Register a "webgl" backend factory only when tf.device_util.isBrowser()
// is true. The trailing 2 is passed as registerBackend's third argument
// (presumably the registration priority — confirm against the tfjs API).
if (tf.device_util.isBrowser()) {
  const createWebGLBackend = function() {
    return new MathBackendWebGL();
  };
  tf.registerBackend("webgl", createWebGLBackend, 2);
}
// Object bundling the forceHalfFloat helper.
var webgl = {forceHalfFloat};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Identity kernel. Increments the backend reference count for `x.dataId`
 * and returns a new TensorInfo sharing that dataId, shape and dtype.
 * @param args - {inputs: {x}, backend}.
 * @returns {{dataId: *, shape: *, dtype: string}}
 */
function identity(args) {
  const { inputs: { x }, backend } = args;
  backend.incRef(x.dataId);
  const { dataId, shape, dtype } = x;
  return {dataId, shape, dtype};
}
var identityConfig = {
  kernelName: tf.Identity,
  backendName: "webgl",
  kernelFunc: identity
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Complex kernel. Allocates a complex64 TensorInfo (with the real part's
 * shape) whose texData records identity views of the real and imaginary
 * inputs in `complexTensorInfos`. Each part's `complexParentRefCount` is
 * incremented.
 * @param args - {inputs: {real, imag}, backend}.
 * @returns The complex64 TensorInfo.
 */
function complex(args) {
  const { inputs, backend } = args;
  const { real: realInput, imag: imagInput } = inputs;
  const complexInfo = backend.makeTensorInfo(realInput.shape, "complex64");
  const complexData = backend.texData.get(complexInfo.dataId);
  const [realView, imagView] = [realInput, imagInput].map((part) => {
    const view = identity({inputs: {x: part}, backend});
    backend.texData.get(view.dataId).complexParentRefCount++;
    return view;
  });
  complexData.complexTensorInfos = {real: realView, imag: imagView};
  return complexInfo;
}
var complexConfig = {
  kernelName: tf.Complex,
  backendName: "webgl",
  kernelFunc: complex
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
// GLSL NaN-propagation snippets spliced into op shader bodies.
// Unary: return x unchanged when it is NaN.
var CHECK_NAN_SNIPPET_UNARY = "if (isnan(x)) return x;";
// Binary (unpacked): return whichever operand is NaN, checking `a` first.
var CHECK_NAN_SNIPPET_BINARY = "\n if (isnan(a)) return a;\n if (isnan(b)) return b;\n";
// Binary (packed): expects a vec4 `result` and a vec4 `isNaN` mask already in
// scope; every lane whose mask component is > 0 is overwritten with NAN.
var CHECK_NAN_SNIPPET_BINARY_PACKED = "\n result.r = isNaN.r > 0. ? NAN : result.r;\n result.g = isNaN.g > 0. ? NAN : result.g;\n result.b = isNaN.b > 0. ? NAN : result.b;\n result.a = isNaN.a > 0. ? NAN : result.a;\n";
/**
 * Builds a kernel function that runs `opSnippet` over input `x` with a
 * UnaryOpProgram. The output has `x`'s dtype.
 * @param {string} opSnippet - GLSL body for UnaryOpProgram.
 * @returns {Function} Kernel taking {inputs: {x}, backend}.
 */
function unaryKernelFunc(opSnippet) {
  return function(kernelArgs) {
    const { x } = kernelArgs.inputs;
    const program = new UnaryOpProgram(x.shape, opSnippet);
    return kernelArgs.backend.runWebGLProgram(program, [x], x.dtype);
  };
}
/**
 * Builds a WebGL kernel function for an element-wise binary op.
 *
 * @param _a - Options object:
 *   opSnippet: GLSL body for BinaryOpProgram. It is also used per part for
 *     complex64 inputs.
 *   packedOpSnippet: optional GLSL body for BinaryOpPackedProgram.
 *   checkOutOfBounds: forwarded to BinaryOpPackedProgram (default false).
 *   supportsComplex: when true, a complex64 `a` is handled part-wise
 *     (default false).
 *   cpuKernelImpl: optional CPU implementation, used when
 *     shouldExecuteOnCPU([a, b]) is true.
 *   dtype: optional forced output dtype; otherwise upcastType(a.dtype, b.dtype).
 * @returns {Function} Kernel taking {inputs: {a, b}, backend}.
 */
function binaryKernelFunc(_a) {
var opSnippet = _a.opSnippet, packedOpSnippet = _a.packedOpSnippet, _b = _a.checkOutOfBounds, checkOutOfBounds = _b === void 0 ? false : _b, _c = _a.supportsComplex, supportsComplex = _c === void 0 ? false : _c, cpuKernelImpl = _a.cpuKernelImpl, dtype = _a.dtype;
return function(_a2) {
var inputs = _a2.inputs, backend = _a2.backend;
var _b2 = inputs, a = _b2.a, b = _b2.b;
var webglBackend = backend;
// Complex path: run the unpacked program separately on the real and the
// imaginary parts (each handle keeps the parent's logical shape), then
// combine them with complex().
// NOTE(review): only a.dtype is checked, yet b's complexTensorInfos is read —
// this assumes b is complex64 as well; confirm callers guarantee it.
if (supportsComplex && a.dtype === "complex64") {
var aData = webglBackend.texData.get(a.dataId);
var bData = webglBackend.texData.get(b.dataId);
var _c2 = [
[aData.complexTensorInfos.real, bData.complexTensorInfos.real],
[aData.complexTensorInfos.imag, bData.complexTensorInfos.imag]
].map(function(complexParts) {
var aPart = complexParts[0], bPart = complexParts[1];
var aHandle = {
dataId: aPart.dataId,
dtype: aPart.dtype,
shape: a.shape
};
var bHandle = {
dataId: bPart.dataId,
dtype: bPart.dtype,
shape: b.shape
};
var program2 = new BinaryOpProgram(opSnippet, a.shape, b.shape);
return webglBackend.runWebGLProgram(program2, [aHandle, bHandle], tf.upcastType(aPart.dtype, bPart.dtype));
}), real2 = _c2[0], imag2 = _c2[1];
// complex() takes identity (ref-counted) views of both parts, so the
// intermediate part results can be disposed here.
var complexOutput = complex({inputs: {real: real2, imag: imag2}, backend: webglBackend});
webglBackend.disposeIntermediateTensorInfo(real2);
webglBackend.disposeIntermediateTensorInfo(imag2);
return complexOutput;
}
var $dtype = dtype || tf.upcastType(a.dtype, b.dtype);
// CPU path: compute values directly and attach them to a fresh TensorInfo
// without running a shader.
if (webglBackend.shouldExecuteOnCPU([a, b]) && cpuKernelImpl != null) {
var aData = webglBackend.texData.get(a.dataId);
var bData = webglBackend.texData.get(b.dataId);
var _d = cpuKernelImpl(a.shape, b.shape, aData.values, bData.values, $dtype), outValues = _d[0], outShape = _d[1];
var out = webglBackend.makeTensorInfo(outShape, $dtype);
var outData = webglBackend.texData.get(out.dataId);
outData.values = outValues;
return out;
}
// GPU path: use the packed program only if WEBGL_PACK_BINARY_OPERATIONS is
// on and a packed snippet was provided.
var shouldUsePackedProgram = tf.env().getBool("WEBGL_PACK_BINARY_OPERATIONS") && packedOpSnippet != null;
var program;
if (shouldUsePackedProgram) {
program = new BinaryOpPackedProgram(packedOpSnippet, a.shape, b.shape, checkOutOfBounds);
} else {
program = new BinaryOpProgram(opSnippet, a.shape, b.shape);
}
return webglBackend.runWebGLProgram(program, [a, b], $dtype);
};
}
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
// Addition snippet, shared by the unpacked and packed binary programs.
var ADD = "return a + b;";
// Add kernel. supportsComplex enables binaryKernelFunc's part-wise complex64
// path; addImplCPU is used when binaryKernelFunc takes its CPU path.
var addKernelFunc = binaryKernelFunc({
opSnippet: ADD,
packedOpSnippet: ADD,
supportsComplex: true,
cpuKernelImpl: addImplCPU
});
var addConfig = {
kernelName: tf.Add,
backendName: "webgl",
kernelFunc: addKernelFunc
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
// atan(a, b) with NaN propagation.
// Unpacked: return the NaN operand early (CHECK_NAN_SNIPPET_BINARY).
var ATAN2 = CHECK_NAN_SNIPPET_BINARY + "\n return atan(a, b);\n";
// Packed: build a per-lane isNaN mask clamped to [0, 1], then overwrite the
// NaN lanes of `result` via CHECK_NAN_SNIPPET_BINARY_PACKED.
var ATAN2_PACKED = "\n vec4 result = atan(a, b);\n vec4 isNaN = min(vec4(isnan(a)) + vec4(isnan(b)), vec4(1.0));\n " + CHECK_NAN_SNIPPET_BINARY_PACKED + "\n return result;\n";
var atan2 = binaryKernelFunc({opSnippet: ATAN2, packedOpSnippet: ATAN2_PACKED});
var atan2Config = {
kernelName: tf.Atan2,
backendName: "webgl",
kernelFunc: atan2
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * AvgPool kernel (dilation fixed at 1).
 *
 * A 1x1 window whose output shape equals its input shape returns an identity
 * view. Otherwise a Pool2DProgram in "avg" mode runs with a float32 output.
 *
 * @param args - {inputs: {x}, backend, attrs: {filterSize, strides, pad, dimRoundingMode}}.
 * @returns Output TensorInfo.
 * @throws If `x` is complex (assertNotComplex) or if the strides/dilations assertion fails.
 */
function avgPool(args) {
  const { inputs, backend, attrs } = args;
  const { x } = inputs;
  assertNotComplex(x, "avgPool");
  const { filterSize, strides, pad, dimRoundingMode } = attrs;
  const dilations = 1;
  tf.util.assert(
    tf.backend_util.eitherStridesOrDilationsAreOne(strides, dilations),
    () => "Error in avgPool: Either strides or dilations must be 1. " + ("Got strides " + strides + " and dilations '" + dilations + "'")
  );
  const convInfo = tf.backend_util.computePool2DInfo(x.shape, filterSize, strides, dilations, pad, dimRoundingMode);
  const isNoOpWindow = convInfo.filterWidth === 1 && convInfo.filterHeight === 1 && tf.util.arraysEqual(convInfo.inShape, convInfo.outShape);
  if (isNoOpWindow) {
    return identity({inputs: {x}, backend});
  }
  const program = new Pool2DProgram(convInfo, "avg", false);
  return backend.runWebGLProgram(program, [x], "float32");
}
var avgPoolConfig = {
  kernelName: tf.AvgPool,
  backendName: "webgl",
  kernelFunc: avgPool
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * AvgPoolBackprop kernel. Computes pool info from the forward input's shape
 * (dilation 1) and runs AvgPool2DBackpropProgram over `dy`. The output uses
 * the forward input's dtype.
 * @param args - {inputs: {dy, input}, backend, attrs: {filterSize, strides, pad}}.
 * @returns Gradient TensorInfo.
 */
function avgPoolBackprop(args) {
  const { inputs: { dy, input }, backend, attrs } = args;
  assertNotComplex([dy, input], "avgPoolBackprop");
  const { filterSize, strides, pad } = attrs;
  const convInfo = tf.backend_util.computePool2DInfo(input.shape, filterSize, strides, 1, pad);
  const program = new AvgPool2DBackpropProgram(convInfo);
  return backend.runWebGLProgram(program, [dy], input.dtype);
}
var avgPoolBackpropConfig = {
  kernelName: tf.AvgPoolBackprop,
  backendName: "webgl",
  kernelFunc: avgPoolBackprop
};
/**
* @license
* Copyright 2017 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var BatchNormProgram = function() {
  /**
   * Unpacked batch-normalization shader.
   *
   * x, mean and variance are always bound. offset and scale are bound only
   * when their shapes are given; otherwise the constants 0.0 / 1.0 are
   * inlined. Each provided shape is checked for broadcast compatibility with
   * xShape.
   */
  function BatchNormProgram2(xShape, meanShape, varianceShape, offsetShape, scaleShape, varianceEpsilon) {
    this.outputShape = [];
    this.variableNames = ["x", "mean", "variance"];
    const assertBroadcastable = (shape) => tf.backend_util.assertAndGetBroadcastShape(xShape, shape);
    assertBroadcastable(meanShape);
    assertBroadcastable(varianceShape);
    const hasOffset = offsetShape != null;
    if (hasOffset) {
      assertBroadcastable(offsetShape);
      this.variableNames.push("offset");
    }
    const hasScale = scaleShape != null;
    if (hasScale) {
      assertBroadcastable(scaleShape);
      this.variableNames.push("scale");
    }
    const offsetSnippet = hasOffset ? "getOffsetAtOutCoords()" : "0.0";
    const scaleSnippet = hasScale ? "getScaleAtOutCoords()" : "1.0";
    this.outputShape = xShape;
    this.userCode = "\n void main() {\n float x = getXAtOutCoords();\n float mean = getMeanAtOutCoords();\n float variance = getVarianceAtOutCoords();\n float offset = " + offsetSnippet + ";\n float scale = " + scaleSnippet + ";\n float inv = scale * inversesqrt(variance + float(" + varianceEpsilon + "));\n setOutput(dot(vec3(x, -mean, offset), vec3(inv, inv, 1)));\n }\n ";
  }
  return BatchNormProgram2;
}();
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var BatchNormPackedProgram = function() {
  /**
   * Packed (vec4) batch-normalization shader.
   *
   * x, mean and variance are always bound. offset and scale are bound only
   * when their shapes are given; otherwise vec4(0.0) / vec4(1.0) are inlined.
   * Each provided shape is checked for broadcast compatibility with xShape.
   * Both inputs and output are packed.
   */
  function BatchNormPackedProgram2(xShape, meanShape, varianceShape, offsetShape, scaleShape, varianceEpsilon) {
    this.packedInputs = true;
    this.packedOutput = true;
    this.variableNames = ["x", "mean", "variance"];
    const assertBroadcastable = (shape) => tf.backend_util.assertAndGetBroadcastShape(xShape, shape);
    assertBroadcastable(meanShape);
    assertBroadcastable(varianceShape);
    let offsetSnippet = "vec4(0.0)";
    if (offsetShape != null) {
      assertBroadcastable(offsetShape);
      this.variableNames.push("offset");
      offsetSnippet = "getOffsetAtOutCoords()";
    }
    let scaleSnippet = "vec4(1.0)";
    if (scaleShape != null) {
      assertBroadcastable(scaleShape);
      this.variableNames.push("scale");
      scaleSnippet = "getScaleAtOutCoords()";
    }
    this.outputShape = xShape;
    this.userCode = "\n void main() {\n vec4 offset = " + offsetSnippet + ";\n vec4 scale = " + scaleSnippet + ";\n\n vec4 x = getXAtOutCoords();\n vec4 mean = getMeanAtOutCoords();\n vec4 variance = getVarianceAtOutCoords();\n\n vec4 inv = scale * inversesqrt(variance + vec4(" + varianceEpsilon + "));\n\n setOutput((x - mean) * inv + offset);\n }\n ";
  }
  return BatchNormPackedProgram2;
}();
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * FusedBatchNorm kernel.
 *
 * Checks that mean/variance (and offset/scale when present) have equal ranks.
 * varianceEpsilon defaults to 1e-3. The kernel then runs
 * BatchNormPackedProgram when WEBGL_PACK_NORMALIZATION is on, otherwise
 * BatchNormProgram. Inputs are passed in the order [x, mean, variance,
 * offset?, scale?], and the output takes x's dtype.
 *
 * @param args - {inputs: {x, mean, variance, offset?, scale?}, backend, attrs: {varianceEpsilon?}}.
 * @returns Output TensorInfo.
 */
var batchNorm = function(args) {
  const { inputs, backend, attrs } = args;
  const { x, mean, variance, offset, scale } = inputs;
  tf.util.assert(mean.shape.length === variance.shape.length, () => "Batch normalization gradient requires mean and variance to have equal ranks.");
  tf.util.assert(offset == null || mean.shape.length === offset.shape.length, () => "Batch normalization gradient requires mean and offset to have equal ranks.");
  tf.util.assert(scale == null || mean.shape.length === scale.shape.length, () => "Batch normalization gradient requires mean and scale to have equal ranks.");
  const requestedEpsilon = attrs.varianceEpsilon;
  const varianceEpsilon = requestedEpsilon == null ? 1e-3 : requestedEpsilon;
  const finalInputs = [x, mean, variance];
  let offsetShape = null;
  if (offset != null) {
    offsetShape = offset.shape;
    finalInputs.push(offset);
  }
  let scaleShape = null;
  if (scale != null) {
    scaleShape = scale.shape;
    finalInputs.push(scale);
  }
  const ProgramCtor = tf.env().getBool("WEBGL_PACK_NORMALIZATION") ? BatchNormPackedProgram : BatchNormProgram;
  const program = new ProgramCtor(x.shape, mean.shape, variance.shape, offsetShape, scaleShape, varianceEpsilon);
  return backend.runWebGLProgram(program, finalInputs, finalInputs[0].dtype);
};
var batchNormConfig = {
  kernelName: tf.FusedBatchNorm,
  backendName: "webgl",
  kernelFunc: batchNorm
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
// Element-wise a != b encoded as 0.0/1.0, with the output dtype forced to
// "bool". No packed snippet is given, so binaryKernelFunc always uses
// BinaryOpProgram here.
var NOT_EQUAL = "return float(a != b);";
var notEqual = binaryKernelFunc({opSnippet: NOT_EQUAL, dtype: "bool"});
var notEqualConfig = {
kernelName: tf.NotEqual,
backendName: "webgl",
kernelFunc: notEqual
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Real kernel. Returns an identity view of the real part recorded in the
 * complex input's texData (`complexTensorInfos.real`).
 * @param args - {inputs: {input}, backend}.
 * @returns TensorInfo sharing the real part's data.
 */
function real(args) {
  const { backend } = args;
  const { complexTensorInfos } = backend.texData.get(args.inputs.input.dataId);
  return identity({inputs: {x: complexTensorInfos.real}, backend});
}
var realConfig = {
  kernelName: tf.Real,
  backendName: "webgl",
  kernelFunc: real
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
// GLSL snippet: convert to int (dropping any fractional part), then store it back as float.
var TO_INT = "return float(int(x));";
/**
 * Casts `input` to int32 on the GPU with a UnaryOpProgram running TO_INT.
 * @param input - Source TensorInfo.
 * @param backend - WebGL backend.
 * @returns {{dataId: *, shape: *, dtype: string}} int32 TensorInfo.
 */
function int(input, backend) {
  const program = new UnaryOpProgram(input.shape, TO_INT);
  const { dataId, shape, dtype } = backend.runWebGLProgram(program, [input], "int32");
  return {dataId, shape, dtype};
}
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * WebGL implementation of the Cast kernel.
 *
 * Conversion paths, tried in this order:
 *  - to complex64: identity when already complex; otherwise pair a float32
 *    copy of x (real plane) with tf.zeros (imaginary plane).
 *  - from complex64: take the real plane, then cast that.
 *  - no encoding loss (per tf.util.hasEncodingLoss): reuse identity's result
 *    and only relabel the dtype.
 *  - to int32: run the float(int(x)) shader via `int`.
 *  - to bool: notEqual against a scalar bool tensor built from
 *    tf.util.getTypedArrayFromDType("bool", 1) (presumably zero-filled).
 *
 * @param {{inputs: {x: Object}, backend: Object, attrs: {dtype: string}}} args
 * @returns {Object} TensorInfo with the requested dtype.
 * @throws {Error} when no conversion path applies.
 */
function cast(args) {
  const { inputs: { x }, backend, attrs: { dtype } } = args;
  const inputIsComplex = x.dtype === "complex64";

  if (dtype === "complex64") {
    if (inputIsComplex) {
      return identity({inputs: {x}, backend});
    }
    const imagZeros = tf.zeros(x.shape);
    const realAsFloat = cast({inputs: {x}, backend, attrs: {dtype: "float32"}});
    const combined = complex({inputs: {real: realAsFloat, imag: imagZeros}, backend});
    imagZeros.dispose();
    backend.disposeIntermediateTensorInfo(realAsFloat);
    return combined;
  }

  if (inputIsComplex) {
    // Drop the imaginary plane, then cast the real plane to the target dtype.
    const realPlane = real({inputs: {input: x}, backend});
    const converted = cast({inputs: {x: realPlane}, backend, attrs: {dtype}});
    backend.disposeIntermediateTensorInfo(realPlane);
    return converted;
  }

  if (!tf.util.hasEncodingLoss(x.dtype, dtype)) {
    // Lossless conversion: keep the underlying data, change only the label.
    const { dataId, shape } = identity({inputs: {x}, backend});
    return {dataId, shape, dtype};
  }

  switch (dtype) {
    case "int32":
      return int(x, backend);
    case "bool": {
      const scalarZero = backend.makeTensorInfo([], "bool", tf.util.getTypedArrayFromDType("bool", 1));
      const nonZero = notEqual({inputs: {a: x, b: scalarZero}, backend});
      backend.disposeIntermediateTensorInfo(scalarZero);
      return nonZero;
    }
    default:
      throw new Error(`Error in Cast: failed to cast ${x.dtype} to ${dtype}`);
  }
}
/**
 * Kernel config binding `cast` to the tf.Cast kernel for the "webgl"
 * backend (presumably passed to a kernel registry elsewhere in the bundle).
 */
var castConfig = {
kernelName: tf.Cast,
backendName: "webgl",
kernelFunc: cast
};
/**
* @license
* Copyright 2017 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var ConcatProgram = function() {
  /**
   * Unpacked shader program concatenating 2D inputs along axis 1.
   *
   * Input k is bound in the shader as `T<k>`. Each output column yC is sent to
   * the first input whose cumulative column range covers it, and is sampled at
   * yC minus that input's starting column.
   *
   * @param {Array<Array<number>>} shapes [rows, cols] shape of each input.
   */
  function ConcatProgram2(shapes) {
    this.outputShape = tf.backend_util.computeOutShape(shapes, 1);
    this.variableNames = shapes.map((unused, idx) => `T${idx}`);

    // ends[k] is the exclusive end column of input k (running sum of widths);
    // the final input needs no upper bound, so only shapes.length - 1 entries.
    const ends = new Array(shapes.length - 1);
    ends[0] = shapes[0][1];
    for (let k = 1; k < ends.length; k++) {
      ends[k] = ends[k - 1] + shapes[k][1];
    }

    const branches = [`if (yC < ${ends[0]}) setOutput(getT0(yR, yC));`];
    ends.forEach((end, k) => {
      if (k > 0) {
        branches.push(`else if (yC < ${end}) setOutput(getT${k}(yR, yC-${ends[k - 1]}));`);
      }
    });
    branches.push(`else setOutput(getT${ends.length}(yR, yC-${ends[ends.length - 1]}));`);

    const prologue = "\n void main() {\n ivec2 coords = getOutputCoords();\n int yR = coords.x;\n int yC = coords.y;\n\n ";
    this.userCode = prologue + branches.join("\n ") + "\n }\n ";
  }
  return ConcatProgram2;
}();
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Packed-texture shader program that concatenates N inputs along `axis`.
 *
 * Inputs are bound in the shader as T0..T(N-1). The generated
 * getValue(<one int per channel>) returns the element at an output
 * coordinate by finding the input whose range along `axis` contains it, and
 * sampling that input at the coordinate shifted back by the input's start
 * offset (via shiftedChannels).
 *
 * main() fills one packed output texel with a 2x2 block over the last two
 * output dimensions:
 *   r = base, g = last + 1, a = both + 1, b = second-to-last + 1.
 * Each extra lane is guarded by the output bounds.
 */
var ConcatPackedProgram = function() {
function ConcatPackedProgram2(shapes, axis) {
this.packedInputs = true;
this.packedOutput = true;
this.outputShape = [];
this.outputShape = tf.backend_util.computeOutShape(shapes, axis);
var shape = this.outputShape;
var rank = shape.length;
// Coordinate type and per-dimension coordinate expressions come from helpers
// defined elsewhere (presumably an ivecN/int type name and "coords.x"-style
// accessors — confirm against getCoordsDataType/getChannels).
var dtype = getCoordsDataType(rank);
var coords2 = getChannels("coords", rank);
var channels = ["x", "y", "z", "w", "u", "v"].slice(0, rank);
this.variableNames = shapes.map(function(_, i2) {
return "T" + i2;
});
// offsets[k]: exclusive end of input k along `axis` (running sum of sizes).
// The last input needs no upper bound, hence shapes.length - 1 entries.
var offsets = new Array(shapes.length - 1);
offsets[0] = shapes[0][axis];
for (var i = 1; i < offsets.length; i++) {
offsets[i] = offsets[i - 1] + shapes[i][axis];
}
var channel = channels[axis];
var lastChannels = channels.slice(-2);
var allChannels = channels.join();
var getValueSnippet = "if (" + channel + " < " + offsets[0] + ") {\n return getChannel(\n getT0(" + allChannels + "), vec2(" + lastChannels.join() + "));\n }";
// Each middle branch also checks the lower bound (>= offsets[i - 1]). That is
// redundant given the earlier returns, but it makes the branches mutually
// exclusive — presumably so correctness does not depend on GPU branch ordering.
for (var i = 1; i < offsets.length; i++) {
var shift_1 = offsets[i - 1];
getValueSnippet += "\n if (" + channel + " < " + offsets[i] + " && " + channel + " >= " + offsets[i - 1] + ") {\n return getChannel(\n getT" + i + "(" + shiftedChannels(channels, channel, shift_1) + "),\n vec2(" + shiftedChannels(lastChannels, channel, shift_1) + "));\n }";
}
var lastIndex = offsets.length;
var shift = offsets[offsets.length - 1];
getValueSnippet += "\n return getChannel(\n getT" + lastIndex + "(" + shiftedChannels(channels, channel, shift) + "),\n vec2(" + shiftedChannels(lastChannels, channel, shift) + "));";
// main() indexes coords2[rank - 2] and shape[rank - 2], so rank must be >= 2.
// NOTE(review): lane `a` is guarded only by the row bound; when col + 1 is out
// of range it reads past the last column — presumably a padding lane that is
// ignored. Confirm.
this.userCode = "\n float getValue(" + channels.map(function(x) {
return "int " + x;
}) + ") {\n " + getValueSnippet + "\n }\n\n void main() {\n " + dtype + " coords = getOutputCoords();\n vec4 result = vec4(getValue(" + coords2 + "), 0., 0., 0.);\n\n " + coords2[rank - 1] + " = " + coords2[rank - 1] + " + 1;\n if (" + coords2[rank - 1] + " < " + shape[rank - 1] + ") {\n result.g = getValue(" + coords2 + ");\n }\n\n " + coords2[rank - 2] + " = " + coords2[rank - 2] + " + 1;\n if (" + coords2[rank - 2] + " < " + shape[rank - 2] + ") {\n result.a = getValue(" + coords2 + ");\n }\n\n " + coords2[rank - 1] + " = " + coords2[rank - 1] + " - 1;\n if (" + coords2[rank - 2] + " < " + shape[rank - 2] + " &&\n " + coords2[rank - 1] + " < " + shape[rank - 1] + ") {\n result.b = getValue(" + coords2 + ");\n }\n setOutput(result);\n }\n ";
}
return ConcatPackedProgram2;
}();
/**
 * Renders a comma-separated GLSL argument list from `channels`. The entry
 * equal to `channel` becomes "<channel> - <shift>", which makes a
 * concatenated input be indexed relative to its own start offset.
 *
 * @param {string[]} channels coordinate names, e.g. ["x", "y", "z"].
 * @param {string} channel name to shift; if it is absent nothing is shifted.
 * @param {number} shift amount subtracted from that coordinate.
 * @returns {string} e.g. "x,y - 4,z".
 */
function shiftedChannels(channels, channel, shift) {
  const target = channels.indexOf(channel);
  return channels
    .map((name, position) => (position === target ? `${name} - ${shift}` : name))
    .join();
}
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * WebGL Imag kernel: returns the imaginary plane of a complex64 tensor.
 *
 * It reads `complexTensorInfos.imag` from the input's texture record and
 * passes that plane through `identity`.
 *
 * @param {{inputs: {input: Object}, backend: Object}} args
 * @returns {Object} TensorInfo produced by identity for the imaginary plane.
 */
function imag(args) {
  const { inputs, backend } = args;
  const record = backend.texData.get(inputs.input.dataId);
  const imagPlane = record.complexTensorInfos.imag;
  return identity({ inputs: { x: imagPlane }, backend });
}
var imagConfig = {
kernelName: tf.Imag,
backendName: "webgl",
kernelFunc: imag
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Reshapes packed texture data with a ReshapePackedProgram.
 *
 * Both the source and target shapes are viewed as [batch, rows, cols] (using
 * getBatchDim / getRowsCols). The program then runs with eager unpacking of
 * its output suppressed.
 *
 * @param {Object} input packed TensorInfo to reshape.
 * @param {number[]} afterShape requested output shape.
 * @param {Object} backend WebGL backend.
 * @returns {Object} {dataId, shape: afterShape, dtype} for the program output.
 */
function packedReshape(input, afterShape, backend) {
  const toBatchRowsCols = (shape) => [getBatchDim(shape), ...getRowsCols(shape)];
  const input3DShape = toBatchRowsCols(input.shape);
  const program = new ReshapePackedProgram(toBatchRowsCols(afterShape), input3DShape);
  // Same data, viewed through the 3D shape the program expects.
  const input3D = { dtype: input.dtype, shape: input3DShape, dataId: input.dataId };
  const preventEagerUnpackingOfOutput = true;
  const { dataId, dtype } = backend.runWebGLProgram(program, [input3D], input.dtype, null, preventEagerUnpackingOfOutput);
  return { dataId, shape: afterShape, dtype };
}
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * WebGL implementation of the Reshape kernel.
 *
 * The requested shape is passed through tf.util.inferFromImplicitShape
 * (presumably resolving a -1 entry), and the element count is asserted to
 * be unchanged. After that, one of two things happens:
 *  - If the data is packed and the new shape is reshape-free neither for
 *    x.shape nor for the existing texture's shape, a packed reshape shader runs.
 *  - Otherwise a new view over the same dataId is returned, after
 *    incrementing its reference count.
 *
 * @param {{inputs: {x: Object}, backend: Object, attrs: {shape: number[]}}} args
 * @returns {Object} TensorInfo with the resolved shape.
 */
function reshape(args) {
  const { inputs: { x }, backend: webglBackend, attrs: { shape } } = args;
  const xSize = tf.util.sizeFromShape(x.shape);
  const $shape = tf.util.inferFromImplicitShape(shape, xSize);
  const $xSize = tf.util.sizeFromShape($shape);
  tf.util.assert(xSize === $xSize, () => `The new shape (${$shape}) has ${$xSize} elements and the old shape (${x.shape}) has ${xSize} elements. The new shape and old shape must have the same number of elements.`);

  const texInfo = webglBackend.texData.get(x.dataId);
  const needsShader = texInfo.isPacked &&
    !(isReshapeFree(x.shape, $shape) || (texInfo.texture !== null && isReshapeFree(texInfo.shape, $shape)));
  if (needsShader) {
    return packedReshape(x, $shape, webglBackend);
  }
  // Zero-copy view: the returned TensorInfo shares x's data.
  webglBackend.incRef(x.dataId);
  return {dataId: x.dataId, shape: $shape, dtype: x.dtype};
}
var reshapeConfig = {
kernelName: tf.Reshape,
backendName: "webgl",
kernelFunc: reshape
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Core WebGL concatenation of `inputs` along `axis`. Callers have already
 * removed empty inputs and validated the shapes.
 *
 * Strategy, in order:
 *  1. complex64: concatenate the real and imaginary planes separately and
 *     recombine them with `complex`.
 *  2. More inputs than WEBGL_MAX_TEXTURES_IN_SHADER: split in half, recurse,
 *     then concatenate the two partial results.
 *  3. WEBGL_PACK_ARRAY_OPERATIONS with rank > 1: a single ConcatPackedProgram.
 *  4. Otherwise: collapse each input to 2D [-1, prod(shape[axis:])],
 *     concatenate along axis 1 with ConcatProgram, and reshape to the final
 *     output shape.
 * Every intermediate created here is released before returning.
 *
 * @param {Object[]} inputs non-empty list of TensorInfos sharing a dtype.
 * @param {number} axis resolved (non-negative) concat axis.
 * @param {Object} backend WebGL backend.
 * @returns {Object} concatenated TensorInfo.
 */
function concatImpl(inputs, axis, backend) {
  const dtype = inputs[0].dtype;
  const release = (info) => backend.disposeIntermediateTensorInfo(info);

  if (dtype === "complex64") {
    const realParts = inputs.map((t) => real({inputs: {input: t}, backend}));
    const imagParts = inputs.map((t) => imag({inputs: {input: t}, backend}));
    const realJoined = concatImpl(realParts, axis, backend);
    const imagJoined = concatImpl(imagParts, axis, backend);
    const joined = complex({inputs: {real: realJoined, imag: imagJoined}, backend});
    realParts.forEach(release);
    imagParts.forEach(release);
    release(realJoined);
    release(imagJoined);
    return joined;
  }

  if (inputs.length > tf.env().getNumber("WEBGL_MAX_TEXTURES_IN_SHADER")) {
    // Too many samplers for one shader: divide and conquer.
    const half = Math.floor(inputs.length / 2);
    const left = concatImpl(inputs.slice(0, half), axis, backend);
    const right = concatImpl(inputs.slice(half), axis, backend);
    const merged = concatImpl([left, right], axis, backend);
    release(left);
    release(right);
    return merged;
  }

  const shapes = inputs.map((t) => t.shape);
  if (tf.env().getBool("WEBGL_PACK_ARRAY_OPERATIONS") && inputs[0].shape.length > 1) {
    return backend.runWebGLProgram(new ConcatPackedProgram(shapes, axis), inputs, dtype);
  }

  // Any n-D concat reduces to a 2D concat along axis 1: collapse the
  // dimensions before `axis` into rows and the rest into columns.
  const outShape = tf.backend_util.computeOutShape(shapes, axis);
  const flattened = inputs.map((x) => reshape({
    inputs: {x},
    attrs: {shape: [-1, tf.util.sizeFromShape(x.shape.slice(axis))]},
    backend
  }));
  const flatProgram = new ConcatProgram(flattened.map((t) => t.shape));
  const flatResult = backend.runWebGLProgram(flatProgram, flattened, dtype);
  flattened.forEach(release);
  const reshaped = reshape({inputs: {x: flatResult}, attrs: {shape: outShape}, backend});
  release(flatResult);
  return reshaped;
}
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * WebGL Concat kernel: concatenates `inputs` along `attrs.axis`.
 *
 * - The axis is resolved with tf.util.parseAxisParam against inputs[0].shape.
 * - An empty result short-circuits to an empty tensor of inputs[0].dtype.
 * - Empty inputs are dropped. If exactly one input remains, it is passed
 *   through `identity` instead of being returned as-is, so the kernel output
 *   is a distinct TensorInfo. This matches how `cast` forwards data and how
 *   `reshape` takes a reference (incRef) when it returns a view. Returning
 *   the input object itself would alias it without any reference on its data.
 * - Otherwise shapes are validated and the work goes to concatImpl.
 *
 * @param {{inputs: Object[], backend: Object, attrs: {axis: number}}} args
 * @returns {Object} TensorInfo of the concatenated tensor.
 */
function concat(args) {
  const { inputs, backend, attrs } = args;
  const $axis = tf.util.parseAxisParam(attrs.axis, inputs[0].shape)[0];
  const outShape = tf.backend_util.computeOutShape(inputs.map((t) => t.shape), $axis);
  if (tf.util.sizeFromShape(outShape) === 0) {
    return backend.makeTensorInfo(outShape, inputs[0].dtype, []);
  }

  // Tensors with a 0 in their shape contribute nothing; drop them.
  const $inputs = inputs.filter((t) => tf.util.sizeFromShape(t.shape) > 0);
  if ($inputs.length === 1) {
    return identity({inputs: {x: $inputs[0]}, backend});
  }

  const shapes = $inputs.map((t) => t.shape);
  tf.backend_util.assertParamsConsistent(shapes, $axis);
  return concatImpl($inputs, $axis, backend);
}
var concatConfig = {
kernelName: tf.Concat,
backendName: "webgl",
kernelFunc: concat
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var COS = CHECK_NAN_SNIPPET_UNARY + "\n return cos(x);\n";
var cos = unaryKernelFunc(COS);
var cosConfig = {
kernelName: tf.Cos,
backendName: "webgl",
kernelFunc: cos
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var DIV = "\nif (a == b) {\n return 1.0;\n};\nreturn a / b;";
var DIV_PACKED = "\n // vec4 one = vec4(equal(a, b));\n // return one + (vec4(1.0) - one) * a / b;\n vec4 result = a / b;\n if(a.x == b.x) {\n result.x = 1.;\n }\n if(a.y == b.y) {\n result.y = 1.;\n }\n if(a.z == b.z) {\n result.z = 1.;\n }\n if(a.w == b.w) {\n result.w = 1.;\n }\n\n return result;\n";
var div = binaryKernelFunc({opSnippet: DIV, packedOpSnippet: DIV_PACKED, checkOutOfBounds: true});
var divConfig = {
kernelName: tf.Div,
backendName: "webgl",
kernelFunc: div
};
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Shader program computing one component ("real" or "imag") of a 1D DFT over
 * the inner dimension of a [batch, innerDim] complex input. The input planes
 * are bound as `real` and `imag`.
 *
 * For output (batch, index) it accumulates, over i in [0, innerDim):
 *   unaryOpComplex(real[i], cos(x), imag[i], sin(x)) / resultDenominator
 * with x = exponentMultiplier * (index / innerDim) * i.
 * Forward transform: exponentMultiplier = -2*PI, denominator 1.0.
 * Inverse transform: +2*PI, and every term is divided by innerDim.
 *
 * @throws {Error} if `component` is neither "real" nor "imag".
 */
var FFTProgram = function() {
function FFTProgram2(component, inputShape, inverse) {
this.variableNames = ["real", "imag"];
var innerDim = inputShape[1];
this.outputShape = inputShape;
var exponentMultiplierSnippet = inverse ? "2.0 * " + Math.PI : "-2.0 * " + Math.PI;
var resultDenominator = inverse ? innerDim + ".0" : "1.0";
var opString;
// (real + i*imag) * (expR + i*expI):
//   real part      = real*expR - imag*expI
//   imaginary part = real*expI + imag*expR
if (component === "real") {
opString = "return real * expR - imag * expI;";
} else if (component === "imag") {
opString = "return real * expI + imag * expR;";
} else {
throw new Error('FFT component must be either "real" or "imag", got ' + component + ".");
}
// innerDim is baked into the shader source, both as the loop bound and in the index ratio.
this.userCode = "\n const float exponentMultiplier = " + exponentMultiplierSnippet + ";\n\n float unaryOpComplex(float real, float expR, float imag, float expI) {\n " + opString + "\n }\n\n float mulMatDFT(int batch, int index) {\n float indexRatio = float(index) / float(" + innerDim + ");\n float exponentMultiplierTimesIndexRatio =\n exponentMultiplier * indexRatio;\n\n float result = 0.0;\n\n for (int i = 0; i < " + innerDim + "; i++) {\n // x = (-2|2 * PI / N) * index * i;\n float x = exponentMultiplierTimesIndexRatio * float(i);\n float expR = cos(x);\n float expI = sin(x);\n float real = getReal(batch, i);\n float imag = getImag(batch, i);\n\n result +=\n unaryOpComplex(real, expR, imag, expI) / " + resultDenominator + ";\n }\n\n return result;\n }\n\n void main() {\n ivec2 coords = getOutputCoords();\n setOutput(mulMatDFT(coords[0], coords[1]));\n }\n ";
}
return FFTProgram2;
}();
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Shared DFT implementation behind the WebGL FFT and IFFT kernels.
 *
 * All outer dimensions of `x` are collapsed into one batch dimension. Two
 * FFTProgram passes then compute the real and imaginary output planes over
 * the innermost dimension, and the complex result is reshaped back to x.shape.
 *
 * @param {Object} x complex64 TensorInfo whose texData carries
 *     complexTensorInfos.real / complexTensorInfos.imag.
 * @param {boolean} inverse true for the inverse transform (FFTProgram then
 *     uses a +2*PI exponent and divides each term by the inner size).
 * @param {Object} backend WebGL backend.
 * @returns {Object} complex64 TensorInfo shaped like `x`. Ownership passes to
 *     the caller; every intermediate created here is released before return.
 */
function fftImpl(x, inverse, backend) {
  const xData = backend.texData.get(x.dataId);
  const inputSize = tf.util.sizeFromShape(x.shape);
  // Collapse all outer dimensions into a single batch dimension.
  const innerDimensionSize = x.shape[x.shape.length - 1];
  const batch = inputSize / innerDimensionSize;
  // input2D is only used for its 2D shape. It may still hold a reference on
  // x's data (reshape incRefs when it returns a view), so it is released below.
  const input2D = reshape({inputs: {x}, backend, attrs: {shape: [batch, innerDimensionSize]}});
  const xShape = input2D.shape;
  const realProgram = new FFTProgram("real", xShape, inverse);
  const imagProgram = new FFTProgram("imag", xShape, inverse);
  // Both passes read x's real and imaginary planes directly, viewed as 2D.
  const inputs = [
    {
      dataId: xData.complexTensorInfos.real.dataId,
      dtype: xData.complexTensorInfos.real.dtype,
      shape: xShape
    },
    {
      dataId: xData.complexTensorInfos.imag.dataId,
      dtype: xData.complexTensorInfos.imag.dtype,
      shape: xShape
    }
  ];
  const realPart = backend.runWebGLProgram(realProgram, inputs, "float32");
  const imagPart = backend.runWebGLProgram(imagProgram, inputs, "float32");
  const complexOutput = complex({inputs: {real: realPart, imag: imagPart}, backend});
  backend.disposeIntermediateTensorInfo(realPart);
  backend.disposeIntermediateTensorInfo(imagPart);
  const complexOutputReshaped = reshape({inputs: {x: complexOutput}, backend, attrs: {shape: x.shape}});
  // Release only the intermediates. The reshaped tensor is the kernel output
  // and must stay alive; disposing it here would hand the caller freed data.
  backend.disposeIntermediateTensorInfo(input2D);
  backend.disposeIntermediateTensorInfo(complexOutput);
  return complexOutputReshaped;
}
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * WebGL FFT kernel: forward DFT over the innermost dimension of inputs.input.
 * @param {{inputs: {input: Object}, backend: Object}} args
 * @returns {Object} result of fftImpl with inverse = false.
 */
function fft(args) {
  const { inputs: { input }, backend } = args;
  const inverse = false;
  return fftImpl(input, inverse, backend);
}
var fftConfig = {
kernelName: tf.FFT,
backendName: "webgl",
kernelFunc: fft
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var FlipLeftRightProgram = function() {
  /**
   * Shader program that mirrors a 4D image tensor along dimension 2 (width).
   *
   * Output column x reads input column (width - 1 - x), so the first and last
   * columns swap. The earlier `width - x` mapping sent column 0 out of range,
   * where it fell back to itself, and shifted every other column by one.
   * The bounds check is kept as a guard; with the corrected index it always
   * passes for in-range x.
   *
   * @param {number[]} imageShape 4-element shape; imageShape[2] is the width.
   */
  function FlipLeftRightProgram2(imageShape) {
    this.variableNames = ["Image"];
    this.outputShape = [];
    const imageWidth = imageShape[2];
    this.outputShape = imageShape;
    this.userCode = "\n void main() {\n ivec4 coords = getOutputCoords();\n int x = coords[2];\n\n int coordX = " + imageWidth + " - x - 1;\n float outputValue;\n if(coordX >= 0 && coordX < " + imageWidth + ") {\n outputValue = getImage(coords[0], coords[1], coordX, coords[3]);\n } else {\n outputValue = getImage(coords[0], coords[1], coords[2], coords[3]);\n }\n setOutput(outputValue);\n }\n ";
  }
  return FlipLeftRightProgram2;
}();
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/** Kernel config for tf.FlipLeftRight on the "webgl" backend. */
var flipLeftRightConfig = {
  kernelName: tf.FlipLeftRight,
  backendName: "webgl",
  /**
   * Runs FlipLeftRightProgram over inputs.image.
   * @returns {Object} program output TensorInfo with the image's dtype.
   */
  kernelFunc: function({ inputs, backend }) {
    const { image } = inputs;
    return backend.runWebGLProgram(new FlipLeftRightProgram(image.shape), [image], image.dtype);
  }
};
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Unpacked shader program that converts an uploaded pixel texture `A` into a
 * [height, width, channels] output of integer-valued floats.
 *
 * For each output (texR, texC, depth) it:
 *  - samples A at (texC, texR) + halfCR, normalized by the texture size;
 *    halfCR comes from the shader prelude, presumably vec2(0.5) for texel
 *    centers;
 *  - selects the r/g/b/a channel by depth;
 *  - rescales with floor(value * 255.0 + 0.5), i.e. rounds value*255 to the
 *    nearest integer (0..255 assuming normalized [0, 1] samples).
 * NOTE(review): for depth > 3 `value` is left unassigned in the shader;
 * presumably callers never request more than 4 channels. Confirm.
 */
var FromPixelsProgram = function() {
function FromPixelsProgram2(outputShape) {
this.variableNames = ["A"];
// Supplies GLSL-version-specific names such as the texture sampling function.
var glsl = getGlslDifferences();
var height = outputShape[0], width = outputShape[1];
this.outputShape = outputShape;
this.userCode = "\n void main() {\n ivec3 coords = getOutputCoords();\n int texR = coords[0];\n int texC = coords[1];\n int depth = coords[2];\n vec2 uv = (vec2(texC, texR) + halfCR) / vec2(" + width + ".0, " + height + ".0);\n\n vec4 values = " + glsl.texture2D + "(A, uv);\n float value;\n if (depth == 0) {\n value = values.r;\n } else if (depth == 1) {\n value = values.g;\n } else if (depth == 2) {\n value = values.b;\n } else if (depth == 3) {\n value = values.a;\n }\n\n setOutput(floor(value * 255.0 + 0.5));\n }\n ";
}
return FromPixelsProgram2;
}();
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Packed-output variant of FromPixelsProgram. The input texture `A` is read
 * unpacked; the output is packed.
 *
 * Each output texel holds a 2x2 block: for row, col in {0, 1} it samples
 * column coords[1] + row of pixel row coords[0], takes channel
 * coords[2] + col, and stores floor(value * 255.0 + 0.5) in lane
 * row * 2 + col. The vec4 is written straight to glsl.output rather than
 * through setOutput.
 * NOTE(review): lanes whose channel index exceeds 3 leave `value` unassigned;
 * presumably those lanes are padding and ignored. Confirm.
 */
var FromPixelsPackedProgram = function() {
function FromPixelsPackedProgram2(outputShape) {
this.variableNames = ["A"];
this.packedInputs = false;
this.packedOutput = true;
var glsl = getGlslDifferences();
var height = outputShape[0], width = outputShape[1];
this.outputShape = outputShape;
this.userCode = "\n void main() {\n ivec3 coords = getOutputCoords();\n int texR = coords[0];\n int texC = coords[1];\n int depth = coords[2];\n\n vec4 result = vec4(0.);\n\n for(int row=0; row<=1; row++) {\n for(int col=0; col<=1; col++) {\n texC = coords[1] + row;\n depth = coords[2] + col;\n\n vec2 uv = (vec2(texC, texR) + halfCR) /\n vec2(" + width + ".0, " + height + ".0);\n vec4 values = " + glsl.texture2D + "(A, uv);\n float value;\n if (depth == 0) {\n value = values.r;\n } else if (depth == 1) {\n value = values.g;\n } else if (depth == 2) {\n value = values.b;\n } else if (depth == 3) {\n value = values.a;\n }\n\n result[row * 2 + col] = floor(value * 255.0 + 0.5);\n }\n }\n\n " + glsl.output + " = result;\n }\n ";
}
return FromPixelsPackedProgram2;
}();
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Kernel config binding `fromPixels` to the tf.FromPixels kernel for the
 * "webgl" backend. `fromPixels` is a function declaration further down, so it
 * is hoisted and already defined when this object literal is evaluated.
 */
var fromPixelsConfig = {
kernelName: tf.FromPixels,
backendName: "webgl",
kernelFunc: fromPixels
};
// Lazily created 2D canvas context, reused across calls to rasterize
// <img>/<video> sources before upload.
var fromPixels2DContext;
/**
 * WebGL FromPixels kernel: uploads a pixel source to a temporary int32
 * texture and converts it to a [height, width, numChannels] int32 tensor.
 *
 * HTMLImageElement / HTMLVideoElement inputs are first drawn onto the shared
 * 2D canvas (video uses videoWidth/videoHeight). The canvas is then uploaded
 * in their place. Other sources are uploaded directly. The packed or unpacked
 * conversion program is chosen by the WEBGL_PACK flag, and the temporary
 * texture is disposed afterwards.
 *
 * @param {{inputs: {pixels: Object}, backend: Object, attrs: {numChannels: number}}} args
 * @returns {Object} int32 TensorInfo of shape [height, width, numChannels].
 */
function fromPixels(args) {
  const { inputs, backend, attrs: { numChannels } } = args;
  let { pixels } = inputs;
  const isVideo = typeof HTMLVideoElement !== "undefined" && pixels instanceof HTMLVideoElement;
  const isImage = typeof HTMLImageElement !== "undefined" && pixels instanceof HTMLImageElement;
  const width = isVideo ? pixels.videoWidth : pixels.width;
  const height = isVideo ? pixels.videoHeight : pixels.height;

  if (isImage || isVideo) {
    if (fromPixels2DContext == null) {
      fromPixels2DContext = document.createElement("canvas").getContext("2d");
    }
    const { canvas } = fromPixels2DContext;
    canvas.width = width;
    canvas.height = height;
    fromPixels2DContext.drawImage(pixels, 0, 0, width, height);
    pixels = canvas;
  }

  const pixelHandle = backend.makeTensorInfo([height, width], "int32");
  // Tag the temporary texture as pixel storage before uploading into it.
  backend.texData.get(pixelHandle.dataId).usage = TextureUsage.PIXELS;
  backend.gpgpu.uploadPixelDataToTexture(backend.getTexture(pixelHandle.dataId), pixels);

  const outShape = [height, width, numChannels];
  const program = tf.env().getBool("WEBGL_PACK") ? new FromPixelsPackedProgram(outShape) : new FromPixelsProgram(outShape);
  const result = backend.runWebGLProgram(program, [pixelHandle], "int32");
  backend.disposeData(pixelHandle.dataId);
  return result;
}
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * WebGL IFFT kernel: inverse DFT over the innermost dimension of inputs.input.
 * @param {{inputs: {input: Object}, backend: Object}} args
 * @returns {Object} result of fftImpl with inverse = true.
 */
function ifft(args) {
  const { inputs: { input }, backend } = args;
  const inverse = true;
  return fftImpl(input, inverse, backend);
}
var ifftConfig = {
kernelName: tf.IFFT,
backendName: "webgl",
kernelFunc: ifft
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Shader program reducing windows of `windowSize` consecutive inner values to
 * one output per window. Output (batch, outIdx) is the sum of x(batch, inIdx)
 * for inIdx in [outIdx * windowSize, (outIdx + 1) * windowSize). When
 * `divisor` is non-null, each value is first multiplied by 1 / divisor.
 *
 * The window is consumed four values at a time with dot(values, ones). The
 * 1-3 leftover values go through a tail branch selected at JS time: the
 * remainder comparisons are interpolated into the GLSL as literal true/false.
 *
 * @param reduceInfo object providing windowSize, batchSize, inSize, outSize.
 * @param divisor optional scale; null/undefined means a plain windowed sum.
 */
var MeanProgram = function() {
function MeanProgram2(reduceInfo, divisor) {
this.variableNames = ["x"];
var windowSize = reduceInfo.windowSize, batchSize = reduceInfo.batchSize, inSize = reduceInfo.inSize, outSize = reduceInfo.outSize;
this.outputShape = [batchSize, outSize];
var windowSizeNearestVec4 = Math.floor(windowSize / 4) * 4;
var windowSizeVec4Remainder = windowSize % 4;
var updateSnippet = "sumValue += dot(values, ones);";
if (divisor != null) {
var denominator = 1 / divisor;
// toPrecision(2) renders an integral value such as 1 as "1.0", so the GLSL
// sees a float literal (presumably to avoid an int * vec4 type error).
updateSnippet = "sumValue += dot(values * " + (tf.util.isInt(denominator) ? denominator.toPrecision(2) : denominator) + ", ones);";
}
// When inSize is not a multiple of windowSize, the last window runs past the
// input; reads outside [0, inSize) then contribute 0.0.
var checkOutOfBounds = "";
if (inSize % windowSize > 0) {
checkOutOfBounds = "\n if (inIdx < 0 || inIdx >= " + inSize + ") {\n return 0.0;\n }\n ";
}
this.userCode = "\n const vec4 ones = vec4(1.0, 1.0, 1.0, 1.0);\n\n float getValue(int batch, int inIdx) {\n " + checkOutOfBounds + "\n return getX(batch, inIdx);\n }\n\n void main() {\n ivec2 coords = getOutputCoords();\n int batch = coords[0];\n int outIdx = coords[1];\n int inOffset = outIdx * " + windowSize + ";\n\n float sumValue = 0.0;\n\n for (int i = 0; i < " + windowSizeNearestVec4 + "; i += 4) {\n int inIdx = inOffset + i;\n vec4 values = vec4(\n getValue(batch, inIdx),\n getValue(batch, inIdx + 1),\n getValue(batch, inIdx + 2),\n getValue(batch, inIdx + 3)\n );\n\n " + updateSnippet + "\n }\n\n int inIdx = inOffset + " + windowSizeNearestVec4 + ";\n if (" + (windowSizeVec4Remainder === 1) + ") {\n vec4 values = vec4(getValue(batch, inIdx), 0.0, 0.0, 0.0);\n\n " + updateSnippet + "\n } else if (" + (windowSizeVec4Remainder === 2) + ") {\n vec4 values = vec4(\n getValue(batch, inIdx),\n getValue(batch, inIdx + 1), 0.0, 0.0);\n\n " + updateSnippet + "\n } else if (" + (windowSizeVec4Remainder === 3) + ") {\n vec4 values = vec4(\n getValue(batch, inIdx),\n getValue(batch, inIdx + 1),\n getValue(batch, inIdx + 2), 0.0);\n\n " + updateSnippet + "\n }\n setOutput(sumValue);\n }\n ";
}
return MeanProgram2;
}();
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Plans a multi-pass reduction over the second dimension of `inShape`.
 *
 * Each stage reduces `inSize` columns, in windows of `windowSize` chosen by
 * `computeOptimalWindowSize`, down to `outSize = ceil(inSize / windowSize)`
 * columns. Each stage's output size is the next stage's input size. Stages
 * are appended until a single column remains, and at least one stage is
 * always produced.
 *
 * @param {number[]} inShape Shape whose element [1] is the length to reduce.
 * @returns {{inSize: number, windowSize: number, outSize: number}[]} Stages
 *     in execution order.
 */
function getReductionStages(inShape) {
  const stages = [];
  let remaining = inShape[1];
  do {
    const windowSize = tf.backend_util.computeOptimalWindowSize(remaining);
    const outSize = Math.ceil(remaining / windowSize);
    stages.push({inSize: remaining, windowSize, outSize});
    remaining = outSize;
  } while (remaining !== 1);
  return stages;
}
/**
 * Reduces each row of a 2-D tensor info `x` ([batch, inSize]) to one value by
 * running one WebGL program per stage from `getReductionStages`.
 *
 * For "mean", only the first stage passes `inSize` as a divisor to
 * `MeanProgram`, so every value is scaled once and the later stages only sum.
 * For any other reduction type, a `ReduceProgram` with that type is used for
 * every stage. Intermediate results of earlier stages are disposed, but `x`
 * itself is never disposed.
 *
 * @param {Object} x Input tensor info; `x.shape[0]` is the batch size.
 * @param {string} dtype Output dtype passed to every program run.
 * @param {string} reductionType "mean" or a type understood by ReduceProgram.
 * @param {Object} backend The WebGL backend.
 * @returns {Object} Tensor info produced by the final stage.
 */
function reduce(x, dtype, reductionType, backend) {
  const batchSize = x.shape[0];
  let result = x;
  getReductionStages(x.shape).forEach(({inSize, windowSize, outSize}, stageIndex) => {
    const reduceInfo = {windowSize, inSize, batchSize, outSize};
    let program;
    if (reductionType !== "mean") {
      program = new ReduceProgram(reduceInfo, reductionType);
    } else if (stageIndex === 0) {
      program = new MeanProgram(reduceInfo, inSize);
    } else {
      program = new MeanProgram(reduceInfo);
    }
    const stageInput = result;
    result = backend.runWebGLProgram(program, [stageInput], dtype);
    // Free the previous stage's output, but never the caller's tensor.
    if (stageInput.dataId !== x.dataId) {
      backend.disposeIntermediateTensorInfo(stageInput);
    }
  });
  return result;
}
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * GPU max-reduction over the trailing dimensions of `x` described by
 * `reduceShape`.
 *
 * `x` is viewed as [batchSize, inSize] (inSize = product of reduceShape) and
 * reduced with `reduce(..., "max", ...)`. The result is then reshaped to
 * `outShape`. Both intermediates are disposed before returning.
 *
 * @param {Object} x Input tensor info.
 * @param {number[]} reduceShape Shape of the reduced (inner-most) dimensions.
 * @param {number[]} outShape Shape of the returned tensor.
 * @param {Object} backend The WebGL backend.
 * @returns {Object} Tensor info of shape `outShape` and dtype `x.dtype`.
 */
function maxImpl$1(x, reduceShape, outShape, backend) {
  const inSize = tf.util.sizeFromShape(reduceShape);
  const batchSize = tf.util.sizeFromShape(x.shape) / inSize;
  const as2D = reshape({inputs: {x}, attrs: {shape: [batchSize, inSize]}, backend});
  const reduced = reduce(as2D, x.dtype, "max", backend);
  const result = reshape({inputs: {x: reduced}, attrs: {shape: outShape}, backend});
  for (const intermediate of [as2D, reduced]) {
    backend.disposeIntermediateTensorInfo(intermediate);
  }
  return result;
}
/**
* @license
* Copyright 2017 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Unpacked shader program that permutes the dimensions of input "A".
 *
 * Output dimension i takes its size from input dimension newDim[i]. The
 * shader reads A at the output coordinates reordered by `getSwitchedCoords`.
 *
 * @param {number[]} aShape Shape of the input tensor.
 * @param {number[]} newDim Permutation of the input axes.
 */
var TransposeProgram = function TransposeProgram2(aShape, newDim) {
  this.variableNames = ["A"];
  this.outputShape = Array.from({length: aShape.length}, (_, axis) => aShape[newDim[axis]]);
  this.rank = this.outputShape.length;
  const dtype = getCoordsDataType(this.rank);
  const switched = getSwitchedCoords(newDim);
  this.userCode = "\n void main() {\n " + dtype + " resRC = getOutputCoords();\n setOutput(getA(" + switched + "));\n }\n ";
};
/**
 * Builds the comma-separated GLSL argument list used by TransposeProgram to
 * read the input at permuted coordinates.
 *
 * Input axis newDim[i] is addressed by output component i (resRC.x, resRC.y,
 * ... resRC.v).
 *
 * @param {number[]} newDim Permutation of the input axes.
 * @returns {string} e.g. "resRC.y,resRC.x" for newDim = [1, 0].
 * @throws {Error} If the rank exceeds 6.
 */
function getSwitchedCoords(newDim) {
  const rank = newDim.length;
  if (rank > 6) {
    throw Error("Transpose for rank " + rank + " is not yet supported");
  }
  const originalOrder = ["resRC.x", "resRC.y", "resRC.z", "resRC.w", "resRC.u", "resRC.v"];
  const switchedCoords = new Array(rank);
  for (const [position, target] of newDim.entries()) {
    switchedCoords[target] = originalOrder[position];
  }
  return switchedCoords.join();
}
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Packed shader program that permutes the dimensions of input "A".
 *
 * Output dimension i has size aShape[newDim[i]]. Each output texel holds a
 * 2x2 block from the last two output dimensions:
 *   - result[0] is at rc;
 *   - result[1] is one step along the last dimension;
 *   - result[2] and result[3] are one step along the second-to-last dimension.
 * Steps that fall outside outputShape leave that channel at 0.
 *
 * @param aShape Shape of the input tensor.
 * @param newDim Permutation of the input axes.
 * @throws Error if the rank exceeds 6.
 */
var TransposePackedProgram = function() {
function TransposePackedProgram2(aShape, newDim) {
this.variableNames = ["A"];
this.packedInputs = true;
this.packedOutput = true;
var outputShape = new Array(aShape.length);
for (var i = 0; i < outputShape.length; i++) {
outputShape[i] = aShape[newDim[i]];
}
this.outputShape = outputShape;
this.rank = outputShape.length;
if (this.rank > 6) {
throw Error("Packed transpose for rank " + this.rank + " is not yet supported.");
}
var dtype = getCoordsDataType(this.rank);
var outputOrder = getVecChannels("rc", this.rank);
// Input coordinate newDim[i] is read from output coordinate i.
var switchedOrder = new Array(this.rank);
for (var i = 0; i < newDim.length; i++) {
switchedOrder[newDim[i]] = outputOrder[i];
}
// The last two (switched) input coordinates are passed to getChannel along with
// the fetched texel, presumably to pick the element inside the packed input texel.
var innerDims = "vec2(" + switchedOrder.slice(-2).join() + ")";
// GLSL condition that advances the last output coordinate and checks it is in bounds.
var nextColumn = "++" + outputOrder[this.rank - 1] + " < " + outputShape[this.rank - 1];
var getc = "getChannel(getA(" + switchedOrder.join() + "), " + innerDims + ")";
// NOTE(review): outputOrder[this.rank - 2] is used below, so this appears to
// assume rank >= 2. TODO: confirm that rank-1 inputs never reach this program.
this.userCode = "\n void main() {\n " + dtype + " rc = getOutputCoords();\n vec4 result = vec4(0.);\n result[0] = " + getc + ";\n if(" + nextColumn + ") {\n result[1] = " + getc + ";\n }\n --" + outputOrder[this.rank - 1] + ";\n if(++" + outputOrder[this.rank - 2] + " < " + outputShape[this.rank - 2] + ") {\n result[2] = " + getc + ";\n if(" + nextColumn + ") {\n result[3] = " + getc + ";\n }\n }\n setOutput(result);\n }\n ";
}
return TransposePackedProgram2;
}();
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Transposes `x` on the GPU according to `perm`.
 *
 * Uses the packed transpose program when the WEBGL_PACK_ARRAY_OPERATIONS
 * flag is on; otherwise it uses the unpacked one.
 *
 * @param {Object} x Input tensor info.
 * @param {number[]} perm Axis permutation.
 * @param {Object} backend The WebGL backend.
 * @returns {Object} Transposed tensor info with dtype `x.dtype`.
 */
function transposeImpl$1(x, perm, backend) {
  const usePacked = tf.env().getBool("WEBGL_PACK_ARRAY_OPERATIONS");
  const ProgramCtor = usePacked ? TransposePackedProgram : TransposeProgram;
  return backend.runWebGLProgram(new ProgramCtor(x.shape, perm), [x], x.dtype);
}
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * WebGL registration for the Max reduction kernel.
 *
 * `kernelFunc` reduces `inputs.x` with max over `attrs.reductionIndices`. When
 * `attrs.keepDims` is set, the reduced dimensions are kept in the output shape.
 *
 * If the reduction axes are not already the inner-most dimensions, the input
 * is first transposed so that they are. This happens on the CPU when the
 * backend chooses CPU execution for `x`; otherwise a WebGL transpose is used.
 *
 * The transposed copy is an intermediate. It is released in a `finally`
 * block, so it is not leaked when a later step throws.
 */
var maxConfig = {
  kernelName: tf.Max,
  backendName: "webgl",
  kernelFunc: function({inputs, attrs, backend}) {
    const {x} = inputs;
    const {reductionIndices, keepDims} = attrs;
    const webglBackend = backend;
    const xRank = x.shape.length;
    const origAxes = tf.util.parseAxisParam(reductionIndices, x.shape);
    // Non-null when the reduced axes must first be moved to the end.
    const permutedAxes = tf.backend_util.getAxesPermutation(origAxes, xRank);
    const maxInputIsTransposed = permutedAxes != null;
    const shouldExecuteOnCPU = webglBackend.shouldExecuteOnCPU([x]);

    // Allocates a tensor info with x's dtype and attaches CPU-computed values to it.
    const makeCPUTensorInfo = (shape, values) => {
      const info = webglBackend.makeTensorInfo(shape, x.dtype);
      webglBackend.texData.get(info.dataId).values = values;
      return info;
    };

    let maxInput = x;
    if (maxInputIsTransposed) {
      if (shouldExecuteOnCPU) {
        const xValues = webglBackend.texData.get(x.dataId).values;
        const newShape = new Array(xRank);
        for (let i = 0; i < xRank; i++) {
          newShape[i] = x.shape[permutedAxes[i]];
        }
        maxInput = makeCPUTensorInfo(newShape, transposeImplCPU(xValues, x.shape, x.dtype, permutedAxes, newShape));
      } else {
        maxInput = transposeImpl$1(x, permutedAxes, webglBackend);
      }
    }

    try {
      const axes = maxInputIsTransposed ? tf.backend_util.getInnerMostAxes(origAxes.length, xRank) : origAxes;
      tf.backend_util.assertAxesAreInnerMostDims("max", axes, xRank);
      const [maxOutShape, reduceShape] = tf.backend_util.computeOutAndReduceShapes(maxInput.shape, axes);
      // With keepDims, compute the final (expanded) shape up front instead of
      // reshaping afterwards.
      const outShape = keepDims ? tf.backend_util.expandShapeToKeepDim(maxOutShape, origAxes) : maxOutShape;
      if (shouldExecuteOnCPU) {
        const inValues = webglBackend.texData.get(maxInput.dataId).values;
        const outValues = maxImplCPU(inValues, tf.util.sizeFromShape(reduceShape), outShape, x.dtype);
        return makeCPUTensorInfo(outShape, outValues);
      }
      return maxImpl$1(maxInput, reduceShape, outShape, webglBackend);
    } finally {
      // Only the transposed copy is ours to free; x belongs to the caller.
      if (maxInputIsTransposed) {
        webglBackend.disposeIntermediateTensorInfo(maxInput);
      }
    }
  }
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * WebGL kernel for 2-D max pooling with dilation fixed at 1.
 *
 * If the pooling window is 1x1 and the output shape equals the input shape,
 * the input is returned through `identity`. Otherwise a "max" Pool2DProgram
 * is run.
 *
 * @param {{inputs: {x: Object}, backend: Object, attrs: Object}} args
 *     `attrs` supplies filterSize, strides, pad and dimRoundingMode.
 * @returns {Object} Pooled tensor info with dtype `x.dtype`.
 */
function maxPool(args) {
  const {inputs, backend, attrs} = args;
  const {x} = inputs;
  assertNotComplex(x, "maxPool");
  const {filterSize, strides, pad, dimRoundingMode} = attrs;
  const dilations = 1;
  const stridesOrDilationsAreOne = tf.backend_util.eitherStridesOrDilationsAreOne(strides, dilations);
  tf.util.assert(stridesOrDilationsAreOne, () => "Error in maxPool: Either strides or dilations must be 1. " + ("Got strides " + strides + " and dilations '" + dilations + "'"));
  const convInfo = tf.backend_util.computePool2DInfo(x.shape, filterSize, strides, dilations, pad, dimRoundingMode);
  const isNoOp = convInfo.filterWidth === 1 && convInfo.filterHeight === 1 && tf.util.arraysEqual(convInfo.inShape, convInfo.outShape);
  if (isNoOp) {
    return identity({inputs: {x}, backend});
  }
  return backend.runWebGLProgram(new Pool2DProgram(convInfo, "max", false), [x], x.dtype);
}
/** Kernel registration: `maxPool` implements tf.MaxPool on the "webgl" backend. */
var maxPoolConfig = {kernelName: tf.MaxPool, backendName: "webgl", kernelFunc: maxPool};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * WebGL kernel for the gradient of 2-D max pooling.
 *
 * The kernel runs two programs:
 *   1. The max Pool2DProgram over `input`, with getPositions = true.
 *   2. MaxPool2DBackpropProgram over `dy` and the result of step 1.
 * The step-1 result is disposed before returning.
 *
 * @param {{inputs: {dy: Object, input: Object, output: Object}, backend: Object, attrs: Object}} args
 * @returns {Object} Gradient tensor info with dtype `input.dtype`.
 */
function maxPoolBackprop(args) {
  const {inputs, backend, attrs} = args;
  const {dy, input, output} = inputs;
  assertNotComplex([input, output], "maxPoolBackprop");
  const {filterSize, strides, pad, dimRoundingMode} = attrs;
  const convInfo = tf.backend_util.computePool2DInfo(input.shape, filterSize, strides, 1, pad, dimRoundingMode);
  const positions = backend.runWebGLProgram(new Pool2DProgram(convInfo, "max", true), [input], input.dtype);
  const dx = backend.runWebGLProgram(new MaxPool2DBackpropProgram(convInfo), [dy, positions], input.dtype);
  backend.disposeIntermediateTensorInfo(positions);
  return dx;
}
/** Kernel registration: `maxPoolBackprop` implements tf.MaxPoolBackprop on the "webgl" backend. */
var maxPoolBackpropConfig = {kernelName: tf.MaxPoolBackprop, backendName: "webgl", kernelFunc: maxPoolBackprop};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Runs max pooling twice over `x`:
 *   - once for the pooled values;
 *   - once with the extra Pool2DProgram flags (true, true, includeBatchInIndex)
 *     for the index output.
 * Both runs request "float32" outputs.
 *
 * @param {Object} x Input tensor info.
 * @param {boolean} includeBatchInIndex Forwarded to the index program.
 * @param {Object} convInfo Pooling geometry from computePool2DInfo.
 * @param {Object} backend The WebGL backend.
 * @returns {Object[]} [pooled values, indexes].
 */
function maxPoolWithArgmaxImpl(x, includeBatchInIndex, convInfo, backend) {
  const runOnX = (program) => backend.runWebGLProgram(program, [x], "float32");
  const poolOutput = runOnX(new Pool2DProgram(convInfo, "max", false));
  const indexOutput = runOnX(new Pool2DProgram(convInfo, "max", true, true, includeBatchInIndex));
  return [poolOutput, indexOutput];
}
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * WebGL registration for MaxPoolWithArgmax.
 *
 * `kernelFunc` requires a rank-4 input, pools it with dilations [1, 1] and
 * returns [pooled values, indexes] from `maxPoolWithArgmaxImpl`.
 */
var maxPoolWithArgmaxConfig = {
  kernelName: tf.MaxPoolWithArgmax,
  backendName: "webgl",
  kernelFunc: function({inputs, attrs, backend}) {
    const {x} = inputs;
    const {filterSize, strides, pad, includeBatchInIndex} = attrs;
    const isRank4 = x.shape.length === 4;
    tf.util.assert(isRank4, () => "Error in maxPool: input must be rank 4 but got rank " + x.shape.length + ".");
    const dilations = [1, 1];
    const stridesOrDilationsAreOne = tf.backend_util.eitherStridesOrDilationsAreOne(strides, dilations);
    tf.util.assert(stridesOrDilationsAreOne, () => "Error in maxPool: Either strides or dilations must be 1. " + ("Got strides " + strides + " and dilations '" + dilations + "'"));
    const convInfo = tf.backend_util.computePool2DInfo(x.shape, filterSize, strides, dilations, pad);
    const [result, indexes] = maxPoolWithArgmaxImpl(x, includeBatchInIndex, convInfo, backend);
    return [result, indexes];
  }
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * GPU mean-reduction over the trailing dimensions of `x` described by
 * `reduceShape`.
 *
 * `x` is viewed as [batchSize, inSize] (inSize = product of reduceShape) and
 * reduced with `reduce(..., "mean", ...)`. The output dtype is always
 * "float32", regardless of `x.dtype`. The result is reshaped to `outShape`,
 * and both intermediates are disposed.
 *
 * @param {Object} x Input tensor info.
 * @param {number[]} reduceShape Shape of the reduced (inner-most) dimensions.
 * @param {number[]} outShape Shape of the returned tensor.
 * @param {Object} backend The WebGL backend.
 * @returns {Object} Float32 tensor info of shape `outShape`.
 */
function meanImpl(x, reduceShape, outShape, backend) {
  const inSize = tf.util.sizeFromShape(reduceShape);
  const batchSize = tf.util.sizeFromShape(x.shape) / inSize;
  const as2D = reshape({inputs: {x}, attrs: {shape: [batchSize, inSize]}, backend});
  const reduced = reduce(as2D, "float32", "mean", backend);
  const result = reshape({inputs: {x: reduced}, attrs: {shape: outShape}, backend});
  for (const intermediate of [as2D, reduced]) {
    backend.disposeIntermediateTensorInfo(intermediate);
  }
  return result;
}
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * WebGL registration for the Mean reduction kernel.
 *
 * `kernelFunc` averages `inputs.x` over `attrs.axis`. When `attrs.keepDims`
 * is set, the reduced dimensions are kept in the output shape.
 *
 * If the reduction axes are not already the inner-most dimensions, the input
 * is first transposed so that they are. This happens on the CPU when the
 * backend chooses CPU execution for `x`; otherwise a WebGL transpose is used.
 * Unlike Max, the reduction itself always goes through the WebGL `meanImpl`.
 *
 * Transposed intermediates are released in a `finally` block, so they are not
 * leaked when a later step throws. The inner-most-axes assertion is labelled
 * "mean" (it previously said "sum"), so its error message names this op.
 */
var meanConfig = {
  kernelName: tf.Mean,
  backendName: "webgl",
  kernelFunc: function({inputs, attrs, backend}) {
    const {x} = inputs;
    const {keepDims, axis} = attrs;
    const webglBackend = backend;
    const xRank = x.shape.length;
    const origAxes = tf.util.parseAxisParam(axis, x.shape);
    // Non-null when the reduced axes must first be moved to the end.
    const permutedAxes = tf.backend_util.getAxesPermutation(origAxes, xRank);
    const meanInputIsTransposed = permutedAxes != null;
    const shouldExecuteOnCPU = webglBackend.shouldExecuteOnCPU([x]);
    const intermediates = [];
    let meanInput = x;
    if (meanInputIsTransposed) {
      if (shouldExecuteOnCPU) {
        const xValues = webglBackend.texData.get(x.dataId).values;
        const newShape = new Array(xRank);
        for (let i = 0; i < xRank; i++) {
          newShape[i] = x.shape[permutedAxes[i]];
        }
        const transposedValues = transposeImplCPU(xValues, x.shape, x.dtype, permutedAxes, newShape);
        meanInput = webglBackend.makeTensorInfo(newShape, x.dtype);
        webglBackend.texData.get(meanInput.dataId).values = transposedValues;
      } else {
        meanInput = transposeImpl$1(x, permutedAxes, webglBackend);
      }
      intermediates.push(meanInput);
    }

    try {
      const axes = meanInputIsTransposed ? tf.backend_util.getInnerMostAxes(origAxes.length, xRank) : origAxes;
      tf.backend_util.assertAxesAreInnerMostDims("mean", axes, xRank);
      const [meanOutShape, reduceShape] = tf.backend_util.computeOutAndReduceShapes(meanInput.shape, axes);
      // With keepDims, compute the final (expanded) shape up front instead of
      // reshaping afterwards.
      const outShape = keepDims ? tf.backend_util.expandShapeToKeepDim(meanOutShape, origAxes) : meanOutShape;
      return meanImpl(meanInput, reduceShape, outShape, webglBackend);
    } finally {
      for (const intermediate of intermediates) {
        webglBackend.disposeIntermediateTensorInfo(intermediate);
      }
    }
  }
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Unpacked shader program that mirror-pads input "x".
 *
 * Output dimension i has size paddings[i][0] + xShape[i] + paddings[i][1].
 * Let start = paddings[i][0] and end = start + xShape[i]. An output coordinate
 * c in [start, end) reads x at c - start. A coordinate outside that range is
 * first mirrored back into it:
 *   c <  start  ->  start * 2 - c - offset
 *   c >= end    ->  (end - 1) * 2 - c + offset
 * offset is 0 for mode "reflect", so the edge element is not repeated. For any
 * other mode offset is 1, so the edge element is repeated.
 *
 * @param xShape Shape of the input tensor.
 * @param paddings Per-dimension [before, after] padding amounts.
 * @param mode "reflect" selects offset 0; every other value selects offset 1.
 */
var MirrorPadProgram = function() {
function MirrorPadProgram2(xShape, paddings, mode) {
this.variableNames = ["x"];
this.outputShape = paddings.map(function(p, i) {
return p[0] + xShape[i] + p[1];
});
var rank = xShape.length;
var dtype = getCoordsDataType(rank);
var start = paddings.map(function(p) {
return p[0];
}).join(",");
var end = paddings.map(function(p, i) {
return p[0] + xShape[i];
}).join(",");
// NOTE(review): only coords[0..3] are named here, so a rank > 4 input would
// produce a getX call with too few arguments. Presumably callers limit rank to
// <= 4; confirm.
var unpackedCoords = ["coords[0]", "coords[1]", "coords[2]", "coords[3]"].slice(0, rank);
var offset = mode === "reflect" ? 0 : 1;
// Rank 1: the output coordinate is a scalar int, so a scalar shader is emitted
// and the constructor returns early.
if (rank === 1) {
this.userCode = "\n int start = " + start + ";\n int end = " + end + ";\n\n void main() {\n int outC = getOutputCoords();\n if (outC < start) {\n outC = start * 2 - outC - " + offset + ";\n } else if(outC >= end) {\n outC = (end - 1) * 2 - outC + " + offset + ";\n }\n setOutput(getX(outC - start));\n }\n ";
return;
}
this.userCode = "\n " + dtype + " start = " + dtype + "(" + start + ");\n " + dtype + " end = " + dtype + "(" + end + ");\n\n void main() {\n " + dtype + " outC = getOutputCoords();\n for (int i = 0; i < " + rank + "; i++) {\n if (outC[i] < start[i]) {\n outC[i] = start[i] * 2 - outC[i] - " + offset + ";\n } else if(outC[i] >= end[i]) {\n outC[i] = (end[i] - 1) * 2 - outC[i] + " + offset + ";\n }\n }\n " + dtype + " coords = outC - start;\n setOutput(getX(" + unpackedCoords + "));\n }\n ";
}
return MirrorPadProgram2;
}();
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Packed shader program that mirror-pads input "x".
 *
 * Output shape and the mirroring rule match the unpacked program:
 *   - c <  start  maps to  start * 2 - c - offset
 *   - c >= end    maps to  (end - 1) * 2 - c + offset
 *   - the read index is the mirrored c minus start
 *   - offset is 0 for "reflect" and 1 otherwise
 *
 * Each output texel is filled from up to four source reads:
 *   - result[0] is at rc;
 *   - result[1] is one step along the last dimension;
 *   - for rank > 1 only, result[2] and result[3] are one step along the
 *     second-to-last dimension.
 * Steps that fall outside outputShape are skipped and left at 0.
 *
 * @param xShape Shape of the input tensor.
 * @param paddings Per-dimension [before, after] padding amounts.
 * @param mode "reflect" selects offset 0; every other value selects offset 1.
 */
var MirrorPadPackedProgram = function() {
function MirrorPadPackedProgram2(xShape, paddings, mode) {
this.variableNames = ["x"];
this.packedInputs = true;
this.packedOutput = true;
this.outputShape = paddings.map(function(p, i) {
return p[0] + xShape[i] + p[1];
});
var rank = xShape.length;
var dtype = getCoordsDataType(rank);
var start = paddings.map(function(p) {
return p[0];
}).join(",");
var end = paddings.map(function(p, i) {
return p[0] + xShape[i];
}).join(",");
var coords2 = getChannels("rc", rank);
var source = getChannels("source", rank);
// GLSL bounds check for the last output coordinate after it has been advanced.
var cLimit = coords2[rank - 1] + " < " + this.outputShape[rank - 1];
var innerDims = rank === 1 ? "source" : "vec2(" + source.slice(-2).join() + ")";
var offset = mode === "reflect" ? 0 : 1;
var mainLoop = "";
if (rank === 1) {
// Scalar coordinates: mirror with an if/else. padSetup is re-emitted inside
// the bounds-checked block, which re-declares `source` in that nested GLSL
// scope after rc has been advanced.
var padSetup = "\n " + dtype + " source = rc;\n if (source < start) {\n source = start * 2 - source - " + offset + ";\n } else if (source >= end) {\n source = (end - 1) * 2 - source + " + offset + ";\n }\n source -= start;\n ";
mainLoop = "\n " + dtype + " rc = outputLoc;\n " + padSetup + "\n result[0] = getChannel(getX(" + source.join() + "), " + innerDims + ");\n " + coords2[rank - 1] + " += 1;\n if(" + cLimit + ") {\n " + padSetup + "\n result[1] = getChannel(getX(" + source.join() + "), " + innerDims + ");\n }\n ";
} else {
// Vector coordinates: mirror every component without branching.
// lt and gte are 0/1 masks, and orig selects the components already in range.
var padSetup = "\n " + dtype + " source = rc;\n " + dtype + " lt = " + dtype + "(lessThan(source, start));\n " + dtype + " gte = " + dtype + "(greaterThanEqual(source, end));\n " + dtype + " orig = 1 - (lt + gte);\n source = orig * source +\n lt * (start * 2 - source - " + offset + ") +\n gte * ((end - 1) * 2 - source + " + offset + ");\n source -= start;\n ";
mainLoop = "\n " + dtype + " rc = outputLoc;\n " + padSetup + "\n result[0] = getChannel(getX(" + source.join() + "), " + innerDims + ");\n " + coords2[rank - 1] + " += 1;\n if(" + cLimit + ") {\n " + padSetup + "\n result[1] = getChannel(getX(" + source.join() + "), " + innerDims + ");\n }\n rc = outputLoc;\n " + coords2[rank - 2] + " += 1;\n if(" + coords2[rank - 2] + " < " + this.outputShape[rank - 2] + ") {\n " + padSetup + "\n result[2] = getChannel(getX(" + source.join() + "), " + innerDims + ");\n " + coords2[rank - 1] + " += 1;\n if(" + cLimit + ") {\n " + padSetup + "\n result[3] = getChannel(getX(" + source.join() + "), " + innerDims + ");\n }\n }\n ";
}
this.userCode = "\n const " + dtype + " start = " + dtype + "(" + start + ");\n const " + dtype + " end = " + dtype + "(" + end + ");\n\n void main() {\n " + dtype + " outputLoc = getOutputCoords();\n vec4 result = vec4(0.);\n " + mainLoop + "\n setOutput(result);\n }\n ";
}
return MirrorPadPackedProgram2;
}();
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * WebGL kernel for MirrorPad.
 *
 * Chooses the packed or unpacked mirror-pad program based on the
 * WEBGL_PACK_ARRAY_OPERATIONS flag and runs it on `inputs.x`.
 *
 * @returns {Object} Padded tensor info with dtype `x.dtype`.
 */
var mirrorPadKernelFunc = function({inputs, backend, attrs}) {
  const {x} = inputs;
  const {paddings, mode} = attrs;
  const usePacked = tf.env().getBool("WEBGL_PACK_ARRAY_OPERATIONS");
  const ProgramCtor = usePacked ? MirrorPadPackedProgram : MirrorPadProgram;
  return backend.runWebGLProgram(new ProgramCtor(x.shape, paddings, mode), [x], x.dtype);
};
/** Kernel registration: `mirrorPadKernelFunc` implements tf.MirrorPad on the "webgl" backend. */
var mirrorPadConfig = {kernelName: tf.MirrorPad, backendName: "webgl", kernelFunc: mirrorPadKernelFunc};
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * GLSL function bodies for complex multiplication (a * b), where
 * a = areal + i*aimag and b = breal + i*bimag:
 *   REAL: Re(ab) = areal*breal - aimag*bimag
 *   IMAG: Im(ab) = areal*bimag + aimag*breal
 */
var COMPLEX_MULTIPLY = {REAL: "return areal * breal - aimag * bimag;", IMAG: "return areal * bimag + aimag * breal;"};
/**
 * Unpacked shader program for a binary op on complex operands.
 *
 * The operands are split into four float inputs: AReal, AImag, BReal and
 * BImag. `op` is a GLSL function body over (areal, aimag, breal, bimag); its
 * result becomes each output element. The output shape is the broadcast of
 * aShape and bShape.
 *
 * @param {string} op GLSL body, e.g. COMPLEX_MULTIPLY.REAL.
 * @param {number[]} aShape Shape of operand A.
 * @param {number[]} bShape Shape of operand B.
 */
var BinaryOpComplexProgram = function BinaryOpComplexProgram2(op, aShape, bShape) {
  this.variableNames = ["AReal", "AImag", "BReal", "BImag"];
  this.outputShape = tf.backend_util.assertAndGetBroadcastShape(aShape, bShape);
  this.userCode = "\n float binaryOpComplex(\n float areal, float aimag, float breal, float bimag) {\n " + op + "\n }\n\n void main() {\n float areal = getARealAtOutCoords();\n float aimag = getAImagAtOutCoords();\n float breal = getBRealAtOutCoords();\n float bimag = getBImagAtOutCoords();\n setOutput(binaryOpComplex(areal, aimag, breal, bimag));\n }\n ";
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/** GLSL body for element-wise real multiplication. */
var MUL = "return a * b;";
/**
 * WebGL kernel for element-wise multiplication of `inputs.a` and `inputs.b`.
 *
 * The kernel takes one of three paths:
 *   - complex64: the real and imaginary parts of a and b are fed to two
 *     BinaryOpComplexProgram runs (COMPLEX_MULTIPLY.REAL and .IMAG), and the
 *     results are combined with `complex`;
 *   - CPU: when the backend chooses CPU execution for [a, b],
 *     `multiplyImplCPU` computes the values on the host;
 *   - GPU otherwise: the packed or unpacked binary program is chosen by
 *     WEBGL_PACK_BINARY_OPERATIONS.
 * Non-complex outputs use dtype upcastType(a.dtype, b.dtype).
 *
 * @param {{inputs: {a: Object, b: Object}, backend: Object}} args
 * @returns {Object} Tensor info holding a * b.
 */
function multiply(args) {
  const {inputs, backend} = args;
  const {a, b} = inputs;
  const dtype = tf.backend_util.upcastType(a.dtype, b.dtype);

  if (a.dtype === "complex64") {
    const aParts = backend.texData.get(a.dataId).complexTensorInfos;
    const bParts = backend.texData.get(b.dataId).complexTensorInfos;
    const realProgram = new BinaryOpComplexProgram(COMPLEX_MULTIPLY.REAL, a.shape, b.shape);
    const imagProgram = new BinaryOpComplexProgram(COMPLEX_MULTIPLY.IMAG, a.shape, b.shape);
    // Each component is passed with the shape of the complex tensor it belongs to.
    const asInput = (part, shape) => ({dataId: part.dataId, dtype: part.dtype, shape});
    const complexInputs = [
      asInput(aParts.real, a.shape),
      asInput(aParts.imag, a.shape),
      asInput(bParts.real, b.shape),
      asInput(bParts.imag, b.shape),
    ];
    const realPart = backend.runWebGLProgram(realProgram, complexInputs, "float32");
    const imagPart = backend.runWebGLProgram(imagProgram, complexInputs, "float32");
    const complexOutput = complex({inputs: {real: realPart, imag: imagPart}, backend});
    backend.disposeIntermediateTensorInfo(realPart);
    backend.disposeIntermediateTensorInfo(imagPart);
    return complexOutput;
  }

  if (backend.shouldExecuteOnCPU([a, b])) {
    const aValues = backend.texData.get(a.dataId).values;
    const bValues = backend.texData.get(b.dataId).values;
    const [outValues, outShape] = multiplyImplCPU(a.shape, b.shape, aValues, bValues, dtype);
    const out = backend.makeTensorInfo(outShape, dtype);
    backend.texData.get(out.dataId).values = outValues;
    return out;
  }

  const usePacked = tf.env().getBool("WEBGL_PACK_BINARY_OPERATIONS");
  const ProgramCtor = usePacked ? BinaryOpPackedProgram : BinaryOpProgram;
  return backend.runWebGLProgram(new ProgramCtor(MUL, a.shape, b.shape), [a, b], dtype);
}
var multiplyConfig = {
kernelName: tf.Multiply,
backendName: "webgl",
kernelFunc: multiply
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * WebGL registration for tf.NonMaxSuppressionV3.
 * Synchronously downloads boxes and scores with readSync and runs the shared CPU
 * implementation; a warning recommending the async variant is emitted on every call.
 */
var nonMaxSuppressionV3Config = {
  kernelName: tf.NonMaxSuppressionV3,
  backendName: "webgl",
  kernelFunc: function({inputs, backend, attrs}) {
    tf.backend_util.warn("tf.nonMaxSuppression() in webgl locks the UI thread. Call tf.nonMaxSuppressionAsync() instead");
    const {boxes, scores} = inputs;
    const {maxOutputSize, iouThreshold, scoreThreshold} = attrs;
    const boxesVals = backend.readSync(boxes.dataId);
    const scoresVals = backend.readSync(scores.dataId);
    return tf.kernel_impls.nonMaxSuppressionV3Impl(boxesVals, scoresVals, maxOutputSize, iouThreshold, scoreThreshold);
  }
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * WebGL registration for tf.NonMaxSuppressionV4 (padded variant).
 * Downloads boxes and scores synchronously, runs the shared CPU implementation and
 * returns [selectedIndices, validOutputs]. A UI-blocking warning is emitted on every call.
 */
var nonMaxSuppressionV4Impl = tf.kernel_impls.nonMaxSuppressionV4Impl;
var nonMaxSuppressionV4Config = {
  kernelName: tf.NonMaxSuppressionV4,
  backendName: "webgl",
  kernelFunc: function({inputs, backend, attrs}) {
    tf.backend_util.warn("tf.nonMaxSuppression() in webgl locks the UI thread. Call tf.nonMaxSuppressionAsync() instead");
    const {boxes, scores} = inputs;
    const {maxOutputSize, iouThreshold, scoreThreshold, padToMaxOutputSize} = attrs;
    const boxesVals = backend.readSync(boxes.dataId);
    const scoresVals = backend.readSync(scores.dataId);
    const result = nonMaxSuppressionV4Impl(boxesVals, scoresVals, maxOutputSize, iouThreshold, scoreThreshold, padToMaxOutputSize);
    return [result.selectedIndices, result.validOutputs];
  }
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * WebGL registration for tf.NonMaxSuppressionV5 (soft-NMS variant).
 * Downloads boxes and scores synchronously, runs the shared CPU implementation and
 * returns [selectedIndices, selectedScores]. A UI-blocking warning is emitted on every call.
 */
var nonMaxSuppressionV5Impl = tf.kernel_impls.nonMaxSuppressionV5Impl;
var nonMaxSuppressionV5Config = {
  kernelName: tf.NonMaxSuppressionV5,
  backendName: "webgl",
  kernelFunc: function({inputs, backend, attrs}) {
    tf.backend_util.warn("tf.nonMaxSuppression() in webgl locks the UI thread. Call tf.nonMaxSuppressionAsync() instead");
    const {boxes, scores} = inputs;
    const {maxOutputSize, iouThreshold, scoreThreshold, softNmsSigma} = attrs;
    const boxesVals = backend.readSync(boxes.dataId);
    const scoresVals = backend.readSync(scores.dataId);
    const {selectedIndices, selectedScores} = nonMaxSuppressionV5Impl(boxesVals, scoresVals, maxOutputSize, iouThreshold, scoreThreshold, softNmsSigma);
    return [selectedIndices, selectedScores];
  }
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * GLSL program used by the WebGL RotateWithOffset kernel.
 *
 * Each output pixel maps back to a source pixel by rotating its offset from the
 * rotation center with the sin/cos factors below, rounding to the nearest pixel. The
 * source pixel is sampled when in bounds; otherwise the fill value is written.
 * The output has the same shape as the input.
 */
var RotateProgram = function() {
function RotateProgram2(imageShape, radians, fillValue, center) {
this.variableNames = ["Image"];
this.outputShape = [];
// imageShape[1] / imageShape[2] are read as height / width. The shader indexes a
// 4-component coordinate (coords[0..3]), so a rank-4 image is expected.
var imageHeight = imageShape[1];
var imageWidth = imageShape[2];
// Trig factors (and the center below) are baked into the shader source as
// fixed-point literals with 3 decimals, so the generated source differs per angle/center.
var sinFactor = Math.sin(radians).toFixed(3);
var cosFactor = Math.cos(radians).toFixed(3);
// Overwrites the placeholder assigned above: the output keeps the input's shape.
this.outputShape = imageShape;
var _a = tf.backend_util.getImageCenter(center, imageHeight, imageWidth), centerX = _a[0], centerY = _a[1];
var centerXString = centerX.toFixed(3);
var centerYString = centerY.toFixed(3);
// Declares `outputValue` with the fill value: one scalar for all channels, or a
// per-channel value taken from a vec3 indexed by the channel coordinate.
var fillSnippet = "";
if (typeof fillValue === "number") {
fillSnippet = "float outputValue = " + fillValue.toFixed(2) + ";";
} else {
// NOTE(review): vec3 assumes fillValue has exactly 3 entries (one per channel) — confirm callers.
fillSnippet = "\n vec3 fill = vec3(" + fillValue.join(",") + ");\n float outputValue = fill[coords[3]];";
}
// x = coords[2] (column), y = coords[1] (row); out-of-bounds source coordinates keep the fill value.
this.userCode = "\n void main() {\n ivec4 coords = getOutputCoords();\n int x = coords[2];\n int y = coords[1];\n float coordXFloat = (float(x) - " + centerXString + ") * " + cosFactor + " - (float(y) - " + centerYString + ") * " + sinFactor + ";\n float coordYFloat = (float(x) - " + centerXString + ") * " + sinFactor + " + (float(y) - " + centerYString + ") * " + cosFactor + ";\n int coordX = int(round(coordXFloat + " + centerXString + "));\n int coordY = int(round(coordYFloat + " + centerYString + "));\n " + fillSnippet + "\n if(coordX >= 0 && coordX < " + imageWidth + " && coordY >= 0 && coordY < " + imageHeight + ") {\n outputValue = getImage(coords[0], coordY, coordX, coords[3]);\n }\n setOutput(outputValue);\n }\n ";
}
return RotateProgram2;
}();
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * WebGL registration for tf.RotateWithOffset.
 * Builds a RotateProgram from the image shape and the rotation attrs and runs it on the
 * image, keeping the image's dtype.
 */
var rotateWithOffsetConfig = {
  kernelName: tf.RotateWithOffset,
  backendName: "webgl",
  kernelFunc: function({inputs, attrs, backend}) {
    const {image} = inputs;
    const {radians, fillValue, center} = attrs;
    const rotateProgram = new RotateProgram(image.shape, radians, fillValue, center);
    return backend.runWebGLProgram(rotateProgram, [image], image.dtype);
  }
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * WebGL unary kernel for tf.Sin. The GLSL snippet is prefixed with the shared
 * CHECK_NAN_SNIPPET_UNARY (defined elsewhere in this bundle) before calling sin(x).
 */
var SIN = `${CHECK_NAN_SNIPPET_UNARY}\n return sin(x);\n`;
var sin = unaryKernelFunc(SIN);
var sinConfig = {kernelName: tf.Sin, backendName: "webgl", kernelFunc: sin};
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/** WebGL unary kernel for tf.Square: every element is multiplied by itself. */
var SQUARE = "return x * x;";
var square = unaryKernelFunc(SQUARE);
var squareConfig = {kernelName: tf.Square, backendName: "webgl", kernelFunc: square};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * WebGL binary kernel for tf.SquaredDifference: (a - b) * (a - b) per element.
 * The same GLSL snippet serves both the packed and the unpacked program variant.
 */
var SQUARED_DIFFERENCE = "return (a - b) * (a - b);";
var squaredDifference = binaryKernelFunc({
  opSnippet: SQUARED_DIFFERENCE,
  packedOpSnippet: SQUARED_DIFFERENCE
});
var squaredDifferenceConfig = {kernelName: tf.SquaredDifference, backendName: "webgl", kernelFunc: squaredDifference};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * WebGL binary kernel for tf.Sub (a - b).
 * The same snippet is used packed and unpacked. Complex inputs are enabled, and subImplCPU
 * is provided as the CPU implementation handed to binaryKernelFunc.
 */
var SUB = "return a - b;";
var subKernelFunc = binaryKernelFunc({opSnippet: SUB, packedOpSnippet: SUB, supportsComplex: true, cpuKernelImpl: subImplCPU});
var subConfig = {kernelName: tf.Sub, backendName: "webgl", kernelFunc: subKernelFunc};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/** WebGL unary kernel for tf.Tan (no NaN-check prefix, unlike the Sin kernel). */
var TAN = "return tan(x);";
var tan = unaryKernelFunc(TAN);
var tanConfig = {kernelName: tf.Tan, backendName: "webgl", kernelFunc: tan};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * WebGL registration for tf.Transpose.
 * The output shape is x.shape permuted by `perm`. When the backend elects CPU execution,
 * transposeImplCPU runs on the cached values and the result is stored on a fresh tensor;
 * otherwise the GPU implementation transposeImpl$1 is used.
 */
var transposeConfig = {
  kernelName: tf.Transpose,
  backendName: "webgl",
  kernelFunc: function({inputs, attrs, backend}) {
    const {x} = inputs;
    const {perm} = attrs;
    const newShape = Array.from({length: x.shape.length}, (_, axis) => x.shape[perm[axis]]);
    if (!backend.shouldExecuteOnCPU([x])) {
      return transposeImpl$1(x, perm, backend);
    }
    const {values} = backend.texData.get(x.dataId);
    const outValues = transposeImplCPU(values, x.shape, x.dtype, perm, newShape);
    const out = backend.makeTensorInfo(newShape, x.dtype);
    backend.texData.get(out.dataId).values = outValues;
    return out;
  }
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the License);
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an AS IS BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * WebGL kernel for tf.Unique.
 *
 * Not GPU-accelerated: x is downloaded synchronously (a console warning says so) and
 * uniqueImplCPU computes the result.
 *
 * @param {{inputs: {x: object}, attrs: {axis: number}, backend: object}} args
 * @returns {object[]} [unique values (x.dtype), int32 indices of length indices.length]
 */
function unique(args) {
  const {inputs: {x}, attrs: {axis}, backend} = args;
  assertNotComplex(x, "unique");
  console.warn("WARNING: ", "UI might be locked temporarily as data is being downloaded");
  const values = backend.readSync(x.dataId);
  const {outputValues, outputShape, indices} = uniqueImplCPU(values, axis, x.shape, x.dtype);
  const uniqueTensor = backend.makeTensorInfo(outputShape, x.dtype, outputValues);
  const indicesTensor = backend.makeTensorInfo([indices.length], "int32", indices);
  return [uniqueTensor, indicesTensor];
}
var uniqueConfig = {kernelName: tf.Unique, backendName: "webgl", kernelFunc: unique};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
// Every WebGL kernel config defined in this module, registered with tfjs-core in list order.
var kernelConfigs = [
  addConfig, atan2Config, avgPoolConfig, avgPoolBackpropConfig, batchNormConfig,
  castConfig, complexConfig, concatConfig, cosConfig, divConfig,
  fftConfig, flipLeftRightConfig, fromPixelsConfig, identityConfig, ifftConfig,
  imagConfig, maxConfig, maxPoolConfig, maxPoolBackpropConfig, maxPoolWithArgmaxConfig,
  meanConfig, mirrorPadConfig, multiplyConfig, nonMaxSuppressionV3Config, nonMaxSuppressionV4Config,
  nonMaxSuppressionV5Config, notEqualConfig, realConfig, reshapeConfig, rotateWithOffsetConfig,
  sinConfig, squareConfig, subConfig, squaredDifferenceConfig, tanConfig,
  transposeConfig, uniqueConfig
];
for (const kernelConfig of kernelConfigs) {
  tf.registerKernel(kernelConfig);
}
// Public surface of the bundled tfjs-backend-webgl module; `version` is re-exported as version_webgl.
Object.assign(exports, {
  GPGPUContext,
  MathBackendWebGL,
  forceHalfFloat,
  gpgpu_util,
  setWebGLContext,
  version_webgl: version,
  webgl,
  webgl_util
});
});
// node_modules/@tensorflow/tfjs/dist/tf.node.js
var require_tf_node = __commonJS((exports) => {
  /**
   * @license
   * Copyright 2020 Google LLC. All Rights Reserved.
   * Licensed under the Apache License, Version 2.0 (the "License");
   * you may not use this file except in compliance with the License.
   * You may obtain a copy of the License at
   *
   * http://www.apache.org/licenses/LICENSE-2.0
   *
   * Unless required by applicable law or agreed to in writing, software
   * distributed under the License is distributed on an "AS IS" BASIS,
   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   * See the License for the specific language governing permissions and
   * limitations under the License.
   * =============================================================================
   */
  "use strict";
  Object.defineProperty(exports, "__esModule", {value: true});
  // Umbrella "@tensorflow/tfjs" module. The sub-packages are loaded in this fixed order;
  // the cpu and webgl backend modules are only read for their version strings below.
  const tfjsCore = require_tf_core_node();
  const tfjsLayers = require_tf_layers_node();
  const tfjsConverter = require_tf_converter_node();
  const tfjsData = require_tf_data_node();
  const tfjsBackendCpu = require_tf_backend_cpu_node();
  const tfjsBackendWebgl = require_tf_backend_webgl_node();
  /** @license See the LICENSE file. */
  const tfjsVersion = "2.7.0";
  /**
   * @license
   * Copyright 2018 Google LLC. All Rights Reserved.
   * Licensed under the Apache License, Version 2.0 (the "License");
   * you may not use this file except in compliance with the License.
   * You may obtain a copy of the License at
   *
   * http://www.apache.org/licenses/LICENSE-2.0
   *
   * Unless required by applicable law or agreed to in writing, software
   * distributed under the License is distributed on an "AS IS" BASIS,
   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   * See the License for the specific language governing permissions and
   * limitations under the License.
   * =============================================================================
   */
  // Version of every bundled sub-package, published as `exports.version`.
  const packageVersions = {
    "tfjs-core": tfjsCore.version_core,
    "tfjs-backend-cpu": tfjsBackendCpu.version_cpu,
    "tfjs-backend-webgl": tfjsBackendWebgl.version_webgl,
    "tfjs-data": tfjsData.version_data,
    "tfjs-layers": tfjsLayers.version_layers,
    "tfjs-converter": tfjsConverter.version_converter,
    tfjs: tfjsVersion
  };
  // Mirrors every named export of `source` (except "default") onto `exports` as a live,
  // enumerable getter.
  const reexportAll = (source) => {
    for (const key of Object.keys(source)) {
      if (key === "default") continue;
      Object.defineProperty(exports, key, {enumerable: true, get: () => source[key]});
    }
  };
  for (const source of [tfjsCore, tfjsLayers, tfjsConverter]) {
    reexportAll(source);
  }
  exports.data = tfjsData;
  exports.version = packageVersions;
});
// src/facemesh/blazeface.js
var require_blazeface = __commonJS((exports) => {
  const tf = require_tf_node();
  // Rows of the per-detection landmark block: getBoundingBoxes reshapes the landmark
  // columns of a prediction row to [NUM_LANDMARKS, -1].
  const NUM_LANDMARKS = 6;

  /**
   * Builds the anchor centers that the raw detector outputs are decoded against.
   *
   * Two grid levels are generated, with strides inputSize / 16 and inputSize / 8 and 2 resp. 6
   * anchors per grid cell. Every anchor of a cell sits at that cell's center.
   *
   * @param {number} inputSize - edge length in pixels of the square detector input.
   * @returns {number[][]} [x, y] anchor centers in input-pixel units, ordered by level, then grid
   *   row, then grid column.
   */
  function generateAnchors(inputSize) {
    const spec = {strides: [inputSize / 16, inputSize / 8], anchors: [2, 6]};
    const anchors = [];
    for (let i = 0; i < spec.strides.length; i++) {
      const stride = spec.strides[i];
      // Equals ceil(inputSize / stride) when both values are integers.
      const gridRows = Math.floor((inputSize + stride - 1) / stride);
      const gridCols = Math.floor((inputSize + stride - 1) / stride);
      const anchorsNum = spec.anchors[i];
      for (let gridY = 0; gridY < gridRows; gridY++) {
        const anchorY = stride * (gridY + 0.5);
        for (let gridX = 0; gridX < gridCols; gridX++) {
          const anchorX = stride * (gridX + 0.5);
          for (let n = 0; n < anchorsNum; n++) anchors.push([anchorX, anchorY]);
        }
      }
    }
    return anchors;
  }

  /**
   * Releases the tensors held by a box created with createBox / scaleBox.
   *
   * `startEndTensor` is only disposed when it actually is a tensor. Boxes returned by
   * getBoundingBoxes are built from plain nested arrays, and calling `.dispose()` on an array
   * would throw.
   *
   * @param {{startEndTensor: *, startPoint: tf.Tensor, endPoint: tf.Tensor}} box
   */
  const disposeBox = (box) => {
    const {startEndTensor} = box;
    if (startEndTensor && typeof startEndTensor.dispose === "function") startEndTensor.dispose();
    box.startPoint.dispose();
    box.endPoint.dispose();
  };

  /**
   * Wraps [N, 4] box rows laid out as [x1, y1, x2, y2] as a box object.
   * `startEndTensor` is stored exactly as given (tensor or nested array). `startPoint` and
   * `endPoint` are new [N, 2] tensors sliced from columns 0-1 and 2-3.
   */
  const createBox = (startEndTensor) => ({
    startEndTensor,
    startPoint: tf.slice(startEndTensor, [0, 0], [-1, 2]),
    endPoint: tf.slice(startEndTensor, [0, 2], [-1, 2])
  });

  /**
   * Multiplies both corners of a box by `factors` (e.g. [scaleX, scaleY]) and returns a new box.
   * The intermediate tensors are not disposed here; call inside tf.tidy.
   */
  const scaleBox = (box, factors) => {
    const starts = tf.mul(box.startPoint, factors);
    const ends = tf.mul(box.endPoint, factors);
    const newCoordinates = tf.concat2d([starts, ends], 1);
    return createBox(newCoordinates);
  };

  /**
   * Decodes raw prediction rows into [N, 4] [x1, y1, x2, y2] boxes in input-pixel units.
   * Columns 1-2 of each row are an offset added to the row's anchor to give the box center, and
   * columns 3-4 are the box size.
   *
   * @param {tf.Tensor2D} boxOutputs - raw predictions, one row per anchor.
   * @param {tf.Tensor2D} anchors - [N, 2] anchor centers from generateAnchors.
   * @param {tf.Tensor1D} inputSize - [width, height] of the detector input.
   * @returns {tf.Tensor2D}
   */
  function decodeBounds(boxOutputs, anchors, inputSize) {
    const boxStarts = tf.slice(boxOutputs, [0, 1], [-1, 2]);
    const centers = tf.add(boxStarts, anchors);
    const boxSizes = tf.slice(boxOutputs, [0, 3], [-1, 2]);
    const boxSizesNormalized = tf.div(boxSizes, inputSize);
    const centersNormalized = tf.div(centers, inputSize);
    const halfBoxSize = tf.div(boxSizesNormalized, 2);
    const starts = tf.sub(centersNormalized, halfBoxSize);
    const ends = tf.add(centersNormalized, halfBoxSize);
    const startNormalized = tf.mul(starts, inputSize);
    const endNormalized = tf.mul(ends, inputSize);
    const concatAxis = 1;
    return tf.concat2d([startNormalized, endNormalized], concatAxis);
  }

  /**
   * Scales a detection's box by `scaleFactor` and returns its squeezed [x1, y1, x2, y2] tensor.
   * Accepts either an annotated detection (with a `box` field) or a bare box.
   * Intermediates are released by tf.tidy; the caller owns the returned tensor.
   */
  function scaleBoxFromPrediction(face, scaleFactor) {
    return tf.tidy(() => {
      const box = face.box ? face.box : face;
      return scaleBox(box, scaleFactor).startEndTensor.squeeze();
    });
  }

  /** Face detector wrapping a loaded graph model plus its anchor set and NMS settings. */
  class BlazeFaceModel {
    /**
     * @param {object} model - loaded graph model exposing `predict`.
     * @param {object} config - reads config.detector.{inputSize, maxFaces, iouThreshold, scoreThreshold}.
     */
    constructor(model, config) {
      this.blazeFaceModel = model;
      this.width = config.detector.inputSize;
      this.height = config.detector.inputSize;
      this.maxFaces = config.detector.maxFaces;
      this.anchorsData = generateAnchors(config.detector.inputSize);
      this.anchors = tf.tensor2d(this.anchorsData);
      this.inputSize = tf.tensor1d([this.width, this.height]);
      this.iouThreshold = config.detector.iouThreshold;
      this.scaleFaces = 0.8;
      this.scoreThreshold = config.detector.scoreThreshold;
    }

    /**
     * Runs the detector on one image and returns the detections kept by non-max suppression.
     *
     * @param {tf.Tensor4D} inputImage - image batch; shape[1] is read as height, shape[2] as width.
     *   NOTE(review): normalization assumes 0..255 pixel values — confirm with callers.
     * @returns {Promise<null|{boxes: object[], scaleFactor: number[]}>} null when the input is
     *   missing, disposed, not rank 4 or has an empty spatial dimension. Otherwise:
     *   - `boxes`: one {box, landmarks, probability, anchor} entry per kept detection, in
     *     detector-input pixel units. `box.startEndTensor` is a plain nested array.
     *   - `scaleFactor`: [imageWidth / width, imageHeight / height].
     *   The caller must dispose box.startPoint, box.endPoint, landmarks and probability.
     */
    async getBoundingBoxes(inputImage) {
      if (!inputImage || inputImage.isDisposedInternal || inputImage.shape.length !== 4 || inputImage.shape[1] < 1 || inputImage.shape[2] < 1) return null;
      const [detectedOutputs, boxes, scores] = tf.tidy(() => {
        // resizeBilinear takes the target size as [newHeight, newWidth].
        const resizedImage = inputImage.resizeBilinear([this.height, this.width]);
        // (x / 255 - 0.5) * 2 maps 0..255 to [-1, 1].
        const normalizedImage = tf.mul(tf.sub(resizedImage.div(255), 0.5), 2);
        const batchedPrediction = this.blazeFaceModel.predict(normalizedImage);
        let prediction;
        if (Array.isArray(batchedPrediction)) {
          // Multi-output model: order the outputs by size and stitch them into a single
          // [anchors, features] tensor. The 512-row part goes first, presumably matching the first
          // anchor level of generateAnchors (16x16 cells x 2 anchors).
          const sorted = batchedPrediction.sort((a, b) => a.size - b.size);
          const concat384 = tf.concat([sorted[0], sorted[2]], 2);
          const concat512 = tf.concat([sorted[1], sorted[3]], 2);
          const concat = tf.concat([concat512, concat384], 1);
          prediction = concat.squeeze(0);
        } else {
          prediction = batchedPrediction.squeeze();
        }
        const decodedBounds = decodeBounds(prediction, this.anchors, this.inputSize);
        // Column 0 of each prediction row is the face logit.
        const logits = tf.slice(prediction, [0, 0], [-1, 1]);
        const scoresOut = tf.sigmoid(logits).squeeze();
        return [prediction, decodedBounds, scoresOut];
      });
      const boxIndicesTensor = await tf.image.nonMaxSuppressionAsync(boxes, scores, this.maxFaces, this.iouThreshold, this.scoreThreshold);
      const boxIndices = await boxIndicesTensor.array();
      boxIndicesTensor.dispose();
      const boundingBoxesMap = boxIndices.map((boxIndex) => tf.slice(boxes, [boxIndex, 0], [1, -1]));
      const boundingBoxes = await Promise.all(boundingBoxesMap.map(async (boundingBox) => {
        const vals = await boundingBox.array();
        boundingBox.dispose();
        return vals;
      }));
      const annotatedBoxes = [];
      for (let i = 0; i < boundingBoxes.length; i++) {
        // Built from the downloaded [[x1, y1, x2, y2]] array, so box.startEndTensor is not a tensor.
        const box = createBox(boundingBoxes[i]);
        const boxIndex = boxIndices[i];
        const anchor = this.anchorsData[boxIndex];
        // Landmark values start at column NUM_LANDMARKS - 1 of the prediction row.
        const sliced = tf.slice(detectedOutputs, [boxIndex, NUM_LANDMARKS - 1], [1, -1]);
        const squeezed = sliced.squeeze();
        const landmarks = squeezed.reshape([NUM_LANDMARKS, -1]);
        const probability = tf.slice(scores, [boxIndex], [1]);
        annotatedBoxes.push({box, landmarks, probability, anchor});
        sliced.dispose();
        squeezed.dispose();
      }
      detectedOutputs.dispose();
      boxes.dispose();
      scores.dispose();
      return {
        boxes: annotatedBoxes,
        scaleFactor: [inputImage.shape[2] / this.width, inputImage.shape[1] / this.height]
      };
    }

    /**
     * Detects faces and downloads them as plain values scaled to the original image.
     * All per-face tensors returned by getBoundingBoxes are released here.
     *
     * @param {tf.Tensor4D} input - see getBoundingBoxes.
     * @returns {Promise<Array<{topLeft: number[], bottomRight: number[], landmarks: number[][], probability: number[]}>>}
     *   An empty array when getBoundingBoxes rejects the input.
     */
    async estimateFaces(input) {
      const detection = await this.getBoundingBoxes(input);
      // getBoundingBoxes returns null for unusable input; report "no faces" instead of
      // throwing a TypeError while destructuring.
      if (!detection) return [];
      const {boxes, scaleFactor} = detection;
      return Promise.all(boxes.map(async (face) => {
        const scaledBox = scaleBoxFromPrediction(face, scaleFactor);
        const [landmarkData, boxData, probabilityData] = await Promise.all([face.landmarks, scaledBox, face.probability].map(async (d) => d.array()));
        const anchor = face.anchor;
        const [scaleFactorX, scaleFactorY] = scaleFactor;
        // Landmarks are offsets from the detection's anchor center.
        const scaledLandmarks = landmarkData.map((landmark) => [
          (landmark[0] + anchor[0]) * scaleFactorX,
          (landmark[1] + anchor[1]) * scaleFactorY
        ]);
        const normalizedFace = {
          topLeft: boxData.slice(0, 2),
          bottomRight: boxData.slice(2),
          landmarks: scaledLandmarks,
          probability: probabilityData
        };
        disposeBox(face.box);
        face.landmarks.dispose();
        face.probability.dispose();
        scaledBox.dispose();
        return normalizedFace;
      }));
    }
  }

  /**
   * Loads the detector graph model from config.detector.modelPath and wraps it.
   * TF Hub loading is enabled when the path contains "tfhub.dev".
   */
  async function load(config) {
    const blazeface = await tf.loadGraphModel(config.detector.modelPath, {fromTFHub: config.detector.modelPath.includes("tfhub.dev")});
    const model = new BlazeFaceModel(blazeface, config);
    return model;
  }

  exports.load = load;
  exports.BlazeFaceModel = BlazeFaceModel;
  exports.disposeBox = disposeBox;
});
// src/facemesh/keypoints.js
var require_keypoints = __commonJS((exports) => {
// Named groups of face-mesh landmark indices, keyed by facial region (silhouette, lips,
// eyes, eyebrows, iris, nose, cheeks). Each value lists integer indices into a per-face
// array of mesh points. The largest index is 477, used by the two *EyeIris groups
// (468-472 and 473-477).
// NOTE(review): presumably a 468-point mesh followed by 10 iris points — confirm against
// the model output before relying on that count.
exports.MESH_ANNOTATIONS = {
// Outline of the face, listed as a sequence of indices (presumably in contour order — confirm).
silhouette: [
10,
338,
297,
332,
284,
251,
389,
356,
454,
323,
361,
288,
397,
365,
379,
378,
400,
377,
152,
148,
176,
149,
150,
136,
172,
58,
132,
93,
234,
127,
162,
21,
54,
103,
67,
109
],
lipsUpperOuter: [61, 185, 40, 39, 37, 0, 267, 269, 270, 409, 291],
lipsLowerOuter: [146, 91, 181, 84, 17, 314, 405, 321, 375, 291],
lipsUpperInner: [78, 191, 80, 81, 82, 13, 312, 311, 310, 415, 308],
lipsLowerInner: [78, 95, 88, 178, 87, 14, 317, 402, 318, 324, 308],
rightEyeUpper0: [246, 161, 160, 159, 158, 157, 173],
rightEyeLower0: [33, 7, 163, 144, 145, 153, 154, 155, 133],
rightEyeUpper1: [247, 30, 29, 27, 28, 56, 190],
rightEyeLower1: [130, 25, 110, 24, 23, 22, 26, 112, 243],
rightEyeUpper2: [113, 225, 224, 223, 222, 221, 189],
rightEyeLower2: [226, 31, 228, 229, 230, 231, 232, 233, 244],
rightEyeLower3: [143, 111, 117, 118, 119, 120, 121, 128, 245],
rightEyebrowUpper: [156, 70, 63, 105, 66, 107, 55, 193],
rightEyebrowLower: [35, 124, 46, 53, 52, 65],
rightEyeIris: [473, 474, 475, 476, 477],
leftEyeUpper0: [466, 388, 387, 386, 385, 384, 398],
leftEyeLower0: [263, 249, 390, 373, 374, 380, 381, 382, 362],
leftEyeUpper1: [467, 260, 259, 257, 258, 286, 414],
leftEyeLower1: [359, 255, 339, 254, 253, 252, 256, 341, 463],
leftEyeUpper2: [342, 445, 444, 443, 442, 441, 413],
leftEyeLower2: [446, 261, 448, 449, 450, 451, 452, 453, 464],
leftEyeLower3: [372, 340, 346, 347, 348, 349, 350, 357, 465],
leftEyebrowUpper: [383, 300, 293, 334, 296, 336, 285, 417],
leftEyebrowLower: [265, 353, 276, 283, 282, 295],
leftEyeIris: [468, 469, 470, 471, 472],
midwayBetweenEyes: [168],
noseTip: [1],
noseBottom: [2],
noseRightCorner: [98],
noseLeftCorner: [327],
rightCheek: [205],
leftCheek: [425]
};
// Pairs a region-name suffix with a list of indices.
// - Each `key` is the tail of a left/right eye or eyebrow group in MESH_ANNOTATIONS
//   (e.g. "EyeUpper0" <-> rightEyeUpper0 / leftEyeUpper0).
// - Each `indices` list has the same length as that group.
// - Together the lists cover 0-70 exactly once.
// NOTE(review): these presumably address points of a separate eye-region (iris) model
// output — confirm with the code that consumes this map.
exports.MESH_TO_IRIS_INDICES_MAP = [
{key: "EyeUpper0", indices: [9, 10, 11, 12, 13, 14, 15]},
{key: "EyeUpper1", indices: [25, 26, 27, 28, 29, 30, 31]},
{key: "EyeUpper2", indices: [41, 42, 43, 44, 45, 46, 47]},
{key: "EyeLower0", indices: [0, 1, 2, 3, 4, 5, 6, 7, 8]},
{key: "EyeLower1", indices: [16, 17, 18, 19, 20, 21, 22, 23, 24]},
{key: "EyeLower2", indices: [32, 33, 34, 35, 36, 37, 38, 39, 40]},
{key: "EyeLower3", indices: [54, 55, 56, 57, 58, 59, 60, 61, 62]},
{key: "EyebrowUpper", indices: [63, 64, 65, 66, 67, 68, 69, 70]},
{key: "EyebrowLower", indices: [48, 49, 50, 51, 52, 53]}
];
});
// src/facemesh/box.js
var require_box = __commonJS((exports) => {
  const tf = require_tf_node();

  // Boxes in this module are {startPoint, endPoint[, landmarks]} with points stored as [x, y].

  /**
   * Multiplies both corners of a box by per-axis factors.
   * @param {{startPoint: number[], endPoint: number[]}} box
   * @param {number[]} factor - [xFactor, yFactor]
   * @returns {{startPoint: number[], endPoint: number[]}} new box; `landmarks` is not carried over.
   */
  function scaleBoxCoordinates(box, factor) {
    const scalePoint = (point) => [point[0] * factor[0], point[1] * factor[1]];
    return {startPoint: scalePoint(box.startPoint), endPoint: scalePoint(box.endPoint)};
  }

  /** @returns {number[]} absolute [width, height] of the box. */
  function getBoxSize({startPoint, endPoint}) {
    const width = Math.abs(endPoint[0] - startPoint[0]);
    const height = Math.abs(endPoint[1] - startPoint[1]);
    return [width, height];
  }

  /** @returns {number[]} [x, y] midpoint between the two corners. */
  function getBoxCenter({startPoint, endPoint}) {
    const centerX = startPoint[0] + (endPoint[0] - startPoint[0]) / 2;
    const centerY = startPoint[1] + (endPoint[1] - startPoint[1]) / 2;
    return [centerX, centerY];
  }

  /**
   * Crops `box` out of batch item 0 of `image` and resizes it to `cropSize`.
   * The image shape is read as [batch, height, width, ...]. The box is converted to the
   * normalized [y1, x1, y2, x2] form that tf.image.cropAndResize expects.
   */
  function cutBoxFromImageAndResize(box, image, cropSize) {
    const imageHeight = image.shape[1];
    const imageWidth = image.shape[2];
    const normalizedBox = [
      box.startPoint[1] / imageHeight,
      box.startPoint[0] / imageWidth,
      box.endPoint[1] / imageHeight,
      box.endPoint[0] / imageWidth
    ];
    return tf.image.cropAndResize(image, [normalizedBox], [0], cropSize);
  }

  /** Scales the box about its center by `factor` (default 1.5), keeping its landmarks. */
  function enlargeBox(box, factor = 1.5) {
    const [centerX, centerY] = getBoxCenter(box);
    const [width, height] = getBoxSize(box);
    const halfWidth = factor * width / 2;
    const halfHeight = factor * height / 2;
    return {
      startPoint: [centerX - halfWidth, centerY - halfHeight],
      endPoint: [centerX + halfWidth, centerY + halfHeight],
      landmarks: box.landmarks
    };
  }

  /** Returns a square box with the same center whose edge is the longer side of `box`; keeps landmarks. */
  function squarifyBox(box) {
    const [centerX, centerY] = getBoxCenter(box);
    const halfEdge = Math.max(...getBoxSize(box)) / 2;
    return {
      startPoint: [centerX - halfEdge, centerY - halfEdge],
      endPoint: [centerX + halfEdge, centerY + halfEdge],
      landmarks: box.landmarks
    };
  }

  Object.assign(exports, {
    scaleBoxCoordinates,
    getBoxSize,
    getBoxCenter,
    cutBoxFromImageAndResize,
    enlargeBox,
    squarifyBox
  });
});
// src/facemesh/util.js
var require_util2 = __commonJS((exports) => {
exports.IDENTITY_MATRIX = [[1, 0, 0], [0, 1, 0], [0, 0, 1]];
function normalizeRadians(angle) {
return angle - 2 * Math.PI * Math.floor((angle + Math.PI) / (2 * Math.PI));
}
exports.normalizeRadians = normalizeRadians;
/**
 * Angle of the vector point1 -> point2, measured as PI/2 - atan2(-dy, dx) and wrapped with
 * normalizeRadians.
 * @param {number[]} point1 - [x, y]
 * @param {number[]} point2 - [x, y]
 * @returns {number} radians
 */
function computeRotation(point1, point2) {
  const dx = point2[0] - point1[0];
  const dy = point2[1] - point1[1];
  return normalizeRadians(Math.PI / 2 - Math.atan2(-dy, dx));
}
exports.computeRotation = computeRotation;
/** Converts radians to degrees. */
function radToDegrees(rad) {
  const degreesTimesPi = rad * 180;
  return degreesTimesPi / Math.PI;
}
exports.radToDegrees = radToDegrees;
/** 3x3 homogeneous translation by (x, y): identity with the offset in the last column. */
function buildTranslationMatrix(x, y) {
  const firstRow = [1, 0, x];
  const secondRow = [0, 1, y];
  const lastRow = [0, 0, 1];
  return [firstRow, secondRow, lastRow];
}
/**
 * Inner product of two numeric vectors, summed left to right over v1's positions
 * (v2 is expected to be at least as long).
 */
function dot(v1, v2) {
  let sum = 0;
  for (let idx = 0; idx < v1.length; idx += 1) {
    sum += v1[idx] * v2[idx];
  }
  return sum;
}
exports.dot = dot;
function getColumnFrom2DArr(arr, columnIndex) {
const column = [];
for (let i = 0; i < arr.length; i++) {
column.push(arr[i][columnIndex]);
}
return column;
}
exports.getColumnFrom2DArr = getColumnFrom2DArr;
function multiplyTransformMatrices(mat1, mat2) {
const product = [];
const size = mat1.length;
for (let row = 0; row < size; row++) {
product.push([]);
for (let col = 0; col < size; col++) {
product[row].push(dot(mat1[row], getColumnFrom2DArr(mat2, col)));
}
}
return product;
}
function buildRotationMatrix(rotation, center) {
const cosA = Math.cos(rotation);
const sinA = Math.sin(rotation);
const rotationMatrix = [[cosA, -sinA, 0], [sinA, cosA, 0], [0, 0, 1]];
const translationMatrix = buildTranslationMatrix(center[0], center[1]);
const translationTimesRotation = multiplyTransformMatrices(translationMatrix, rotationMatrix);
const negativeTranslationMatrix = buildTranslationMatrix(-center[0], -center[1]);
return multiplyTransformMatrices(translationTimesRotation, negativeTranslationMatrix);
}
exports.buildRotationMatrix = buildRotationMatrix;
function invertTransformMatrix(matrix) {
const rotationComponent = [[matrix[0][0], matrix[1][0]], [matrix[0][1], matrix[1][1]]];
const translationComponent = [matrix[0][2], matrix[1][2]];
const invertedTranslation = [
-dot(rotationComponent[0], translationComponent),
-dot(rotationComponent[1], translationComponent)
];
return [
rotationComponent[0].concat(invertedTranslation[0]),
rotationComponent[1].concat(invertedTranslation[1]),
[0, 0, 1]
];
}
exports.invertTransformMatrix = invertTransformMatrix;
function rotatePoint(homogeneousCoordinate, rotationMatrix) {
return [
dot(homogeneousCoordinate, rotationMatrix[0]),
dot(homogeneousCoordinate, rotationMatrix[1])
];
}
exports.rotatePoint = rotatePoint;
function xyDistanceBetweenPoints(a, b) {
return Math.sqrt((a[0] - b[0]) ** 2 + (a[1] - b[1]) ** 2);
}
exports.xyDistanceBetweenPoints = xyDistanceBetweenPoints;
});
// src/facemesh/pipeline.js
var require_pipeline = __commonJS((exports) => {
const tf = require_tf_node();
const bounding = require_box();
const keypoints = require_keypoints();
const util = require_util2();
// Landmark lists at least this long are treated as coming from the mesh model rather than
// from the bounding-box detector.
const LANDMARKS_COUNT = 468;
// A tracked region is replaced by a new detection when their IoU falls below this value.
const UPDATE_REGION_OF_INTEREST_IOU_THRESHOLD = 0.25;
// Landmark pairs (mouth, point between the eyes / nose) used to estimate face rotation.
const MESH_MOUTH_INDEX = 13;
const MESH_KEYPOINTS_LINE_OF_SYMMETRY_INDICES = [MESH_MOUTH_INDEX, keypoints.MESH_ANNOTATIONS["midwayBetweenEyes"][0]];
const BLAZEFACE_MOUTH_INDEX = 3;
const BLAZEFACE_NOSE_INDEX = 2;
const BLAZEFACE_KEYPOINTS_LINE_OF_SYMMETRY_INDICES = [BLAZEFACE_MOUTH_INDEX, BLAZEFACE_NOSE_INDEX];
// First and last landmark index of an eye outline; used as the eye-corner pair.
const outlineEnds = (outline) => [outline[0], outline[outline.length - 1]];
const LEFT_EYE_OUTLINE = keypoints.MESH_ANNOTATIONS["leftEyeLower0"];
const LEFT_EYE_BOUNDS = outlineEnds(LEFT_EYE_OUTLINE);
const RIGHT_EYE_OUTLINE = keypoints.MESH_ANNOTATIONS["rightEyeLower0"];
const RIGHT_EYE_BOUNDS = outlineEnds(RIGHT_EYE_OUTLINE);
// Positions inside the EyeUpper0 / EyeLower0 annotation lists whose z is used for the iris depth.
const IRIS_UPPER_CENTER_INDEX = 3;
const IRIS_LOWER_CENTER_INDEX = 4;
// Iris-model output: IRIS_NUM_COORDINATES [x, y, z] points per eye; points from IRIS_IRIS_INDEX on are the iris.
const IRIS_IRIS_INDEX = 71;
const IRIS_NUM_COORDINATES = 76;
/**
 * Overwrites mesh eye landmarks in `rawCoords` (in place) with the matching iris-model points.
 * x and y are copied from `newCoords`; z becomes the mean of the new and the previous z.
 * @param {Array<number[]>} rawCoords - mesh landmarks, mutated
 * @param {Array<number[]>} newCoords - iris-model landmarks
 * @param {string} prefix - prepended to each MESH_TO_IRIS_INDICES_MAP key (e.g. "left"/"right")
 * @param {string[]} [keys] - restrict the update to these keys; every key when null/undefined
 */
function replaceRawCoordinates(rawCoords, newCoords, prefix, keys) {
  const replaceAll = keys == null;
  for (const {key, indices} of keypoints.MESH_TO_IRIS_INDICES_MAP) {
    if (!replaceAll && !keys.includes(key)) {
      continue;
    }
    const meshIndices = keypoints.MESH_ANNOTATIONS[`${prefix}${key}`];
    indices.forEach((irisIndex, j) => {
      const meshIndex = meshIndices[j];
      const [x, y, z] = newCoords[irisIndex];
      rawCoords[meshIndex] = [x, y, (z + rawCoords[meshIndex][2]) / 2];
    });
  }
}
class Pipeline {
constructor(boundingBoxDetector, meshDetector, irisModel, config) {
this.regionsOfInterest = [];
this.runsWithoutFaceDetector = 0;
this.boundingBoxDetector = boundingBoxDetector;
this.meshDetector = meshDetector;
this.irisModel = irisModel;
this.meshWidth = config.mesh.inputSize;
this.meshHeight = config.mesh.inputSize;
this.irisSize = config.iris.inputSize;
this.irisEnlarge = config.iris.enlargeFactor;
}
transformRawCoords(rawCoords, box, angle, rotationMatrix) {
const boxSize = bounding.getBoxSize({startPoint: box.startPoint, endPoint: box.endPoint});
const scaleFactor = [boxSize[0] / this.meshWidth, boxSize[1] / this.meshHeight];
const coordsScaled = rawCoords.map((coord) => [
scaleFactor[0] * (coord[0] - this.meshWidth / 2),
scaleFactor[1] * (coord[1] - this.meshHeight / 2),
coord[2]
]);
const coordsRotationMatrix = util.buildRotationMatrix(angle, [0, 0]);
const coordsRotated = coordsScaled.map((coord) => [...util.rotatePoint(coord, coordsRotationMatrix), coord[2]]);
const inverseRotationMatrix = util.invertTransformMatrix(rotationMatrix);
const boxCenter = [...bounding.getBoxCenter({startPoint: box.startPoint, endPoint: box.endPoint}), 1];
const originalBoxCenter = [
util.dot(boxCenter, inverseRotationMatrix[0]),
util.dot(boxCenter, inverseRotationMatrix[1])
];
return coordsRotated.map((coord) => [
coord[0] + originalBoxCenter[0],
coord[1] + originalBoxCenter[1],
coord[2]
]);
}
getLeftToRightEyeDepthDifference(rawCoords) {
const leftEyeZ = rawCoords[LEFT_EYE_BOUNDS[0]][2];
const rightEyeZ = rawCoords[RIGHT_EYE_BOUNDS[0]][2];
return leftEyeZ - rightEyeZ;
}
getEyeBox(rawCoords, face, eyeInnerCornerIndex, eyeOuterCornerIndex, flip = false) {
const box = bounding.squarifyBox(bounding.enlargeBox(this.calculateLandmarksBoundingBox([rawCoords[eyeInnerCornerIndex], rawCoords[eyeOuterCornerIndex]]), this.irisEnlarge));
const boxSize = bounding.getBoxSize(box);
let crop = tf.image.cropAndResize(face, [[
box.startPoint[1] / this.meshHeight,
box.startPoint[0] / this.meshWidth,
box.endPoint[1] / this.meshHeight,
box.endPoint[0] / this.meshWidth
]], [0], [this.irisSize, this.irisSize]);
if (flip) {
crop = tf.image.flipLeftRight(crop);
}
return {box, boxSize, crop};
}
getEyeCoords(eyeData, eyeBox, eyeBoxSize, flip = false) {
const eyeRawCoords = [];
for (let i = 0; i < IRIS_NUM_COORDINATES; i++) {
const x = eyeData[i * 3];
const y = eyeData[i * 3 + 1];
const z = eyeData[i * 3 + 2];
eyeRawCoords.push([
(flip ? 1 - x / this.irisSize : x / this.irisSize) * eyeBoxSize[0] + eyeBox.startPoint[0],
y / this.irisSize * eyeBoxSize[1] + eyeBox.startPoint[1],
z
]);
}
return {rawCoords: eyeRawCoords, iris: eyeRawCoords.slice(IRIS_IRIS_INDEX)};
}
getAdjustedIrisCoords(rawCoords, irisCoords, direction) {
const upperCenterZ = rawCoords[keypoints.MESH_ANNOTATIONS[`${direction}EyeUpper0`][IRIS_UPPER_CENTER_INDEX]][2];
const lowerCenterZ = rawCoords[keypoints.MESH_ANNOTATIONS[`${direction}EyeLower0`][IRIS_LOWER_CENTER_INDEX]][2];
const averageZ = (upperCenterZ + lowerCenterZ) / 2;
return irisCoords.map((coord, i) => {
let z = averageZ;
if (i === 2) {
z = upperCenterZ;
} else if (i === 4) {
z = lowerCenterZ;
}
return [coord[0], coord[1], z];
});
}
async predict(input, config) {
this.skipFrames = config.detector.skipFrames;
this.maxFaces = config.detector.maxFaces;
this.runsWithoutFaceDetector++;
if (this.shouldUpdateRegionsOfInterest()) {
const detector = await this.boundingBoxDetector.getBoundingBoxes(input);
if (detector.boxes.length === 0) {
this.regionsOfInterest = [];
return null;
}
const scaledBoxes = detector.boxes.map((prediction) => {
const startPoint = prediction.box.startPoint.squeeze();
const endPoint = prediction.box.endPoint.squeeze();
const predictionBox = {
startPoint: startPoint.arraySync(),
endPoint: endPoint.arraySync()
};
startPoint.dispose();
endPoint.dispose();
const scaledBox = bounding.scaleBoxCoordinates(predictionBox, detector.scaleFactor);
const enlargedBox = bounding.enlargeBox(scaledBox);
const landmarks = prediction.landmarks.arraySync();
prediction.box.startPoint.dispose();
prediction.box.endPoint.dispose();
prediction.landmarks.dispose();
prediction.probability.dispose();
return {...enlargedBox, landmarks};
});
this.updateRegionsOfInterest(scaledBoxes);
this.runsWithoutFaceDetector = 0;
}
const results = tf.tidy(() => this.regionsOfInterest.map((box, i) => {
let angle = 0;
const boxLandmarksFromMeshModel = box.landmarks.length >= LANDMARKS_COUNT;
let [indexOfMouth, indexOfForehead] = MESH_KEYPOINTS_LINE_OF_SYMMETRY_INDICES;
if (boxLandmarksFromMeshModel === false) {
[indexOfMouth, indexOfForehead] = BLAZEFACE_KEYPOINTS_LINE_OF_SYMMETRY_INDICES;
}
angle = util.computeRotation(box.landmarks[indexOfMouth], box.landmarks[indexOfForehead]);
const faceCenter = bounding.getBoxCenter({startPoint: box.startPoint, endPoint: box.endPoint});
const faceCenterNormalized = [faceCenter[0] / input.shape[2], faceCenter[1] / input.shape[1]];
let rotatedImage = input;
let rotationMatrix = util.IDENTITY_MATRIX;
if (angle !== 0) {
rotatedImage = tf.image.rotateWithOffset(input, angle, 0, faceCenterNormalized);
rotationMatrix = util.buildRotationMatrix(-angle, faceCenter);
}
const boxCPU = {startPoint: box.startPoint, endPoint: box.endPoint};
const face = bounding.cutBoxFromImageAndResize(boxCPU, rotatedImage, [this.meshHeight, this.meshWidth]).div(255);
const [, flag, coords] = this.meshDetector.predict(face);
const coordsReshaped = tf.reshape(coords, [-1, 3]);
let rawCoords = coordsReshaped.arraySync();
if (config.iris.enabled) {
const {box: leftEyeBox, boxSize: leftEyeBoxSize, crop: leftEyeCrop} = this.getEyeBox(rawCoords, face, LEFT_EYE_BOUNDS[0], LEFT_EYE_BOUNDS[1], true);
const {box: rightEyeBox, boxSize: rightEyeBoxSize, crop: rightEyeCrop} = this.getEyeBox(rawCoords, face, RIGHT_EYE_BOUNDS[0], RIGHT_EYE_BOUNDS[1]);
const eyePredictions = this.irisModel.predict(tf.concat([leftEyeCrop, rightEyeCrop]));
const eyePredictionsData = eyePredictions.dataSync();
eyePredictions.dispose();
const leftEyeData = eyePredictionsData.slice(0, IRIS_NUM_COORDINATES * 3);
const {rawCoords: leftEyeRawCoords, iris: leftIrisRawCoords} = this.getEyeCoords(leftEyeData, leftEyeBox, leftEyeBoxSize, true);
const rightEyeData = eyePredictionsData.slice(IRIS_NUM_COORDINATES * 3);
const {rawCoords: rightEyeRawCoords, iris: rightIrisRawCoords} = this.getEyeCoords(rightEyeData, rightEyeBox, rightEyeBoxSize);
const leftToRightEyeDepthDifference = this.getLeftToRightEyeDepthDifference(rawCoords);
if (Math.abs(leftToRightEyeDepthDifference) < 30) {
replaceRawCoordinates(rawCoords, leftEyeRawCoords, "left");
replaceRawCoordinates(rawCoords, rightEyeRawCoords, "right");
} else if (leftToRightEyeDepthDifference < 1) {
replaceRawCoordinates(rawCoords, leftEyeRawCoords, "left", ["EyeUpper0", "EyeLower0"]);
} else {
replaceRawCoordinates(rawCoords, rightEyeRawCoords, "right", ["EyeUpper0", "EyeLower0"]);
}
const adjustedLeftIrisCoords = this.getAdjustedIrisCoords(rawCoords, leftIrisRawCoords, "left");
const adjustedRightIrisCoords = this.getAdjustedIrisCoords(rawCoords, rightIrisRawCoords, "right");
rawCoords = rawCoords.concat(adjustedLeftIrisCoords).concat(adjustedRightIrisCoords);
}
const transformedCoordsData = this.transformRawCoords(rawCoords, box, angle, rotationMatrix);
tf.dispose(rawCoords);
const landmarksBox = bounding.enlargeBox(this.calculateLandmarksBoundingBox(transformedCoordsData));
const confidence = flag.squeeze();
tf.dispose(flag);
if (config.mesh.enabled) {
const transformedCoords = tf.tensor2d(transformedCoordsData);
this.regionsOfInterest[i] = {...landmarksBox, landmarks: transformedCoords.arraySync()};
const prediction2 = {
coords: transformedCoords,
box: landmarksBox,
confidence,
image: face
};
return prediction2;
}
const prediction = {
coords: null,
box: landmarksBox,
confidence,
image: face
};
return prediction;
}));
return results;
}
updateRegionsOfInterest(boxes) {
for (let i = 0; i < boxes.length; i++) {
const box = boxes[i];
const previousBox = this.regionsOfInterest[i];
let iou = 0;
if (previousBox && previousBox.startPoint) {
const [boxStartX, boxStartY] = box.startPoint;
const [boxEndX, boxEndY] = box.endPoint;
const [previousBoxStartX, previousBoxStartY] = previousBox.startPoint;
const [previousBoxEndX, previousBoxEndY] = previousBox.endPoint;
const xStartMax = Math.max(boxStartX, previousBoxStartX);
const yStartMax = Math.max(boxStartY, previousBoxStartY);
const xEndMin = Math.min(boxEndX, previousBoxEndX);
const yEndMin = Math.min(boxEndY, previousBoxEndY);
const intersection = (xEndMin - xStartMax) * (yEndMin - yStartMax);
const boxArea = (boxEndX - boxStartX) * (boxEndY - boxStartY);
const previousBoxArea = (previousBoxEndX - previousBoxStartX) * (previousBoxEndY - boxStartY);
iou = intersection / (boxArea + previousBoxArea - intersection);
}
if (iou < UPDATE_REGION_OF_INTEREST_IOU_THRESHOLD) {
this.regionsOfInterest[i] = box;
}
}
this.regionsOfInterest = this.regionsOfInterest.slice(0, boxes.length);
}
clearRegionOfInterest(index) {
if (this.regionsOfInterest[index] != null) {
this.regionsOfInterest = [
...this.regionsOfInterest.slice(0, index),
...this.regionsOfInterest.slice(index + 1)
];
}
}
shouldUpdateRegionsOfInterest() {
if (this.regionsOfInterest.length === 0)
return true;
return this.regionsOfInterest.length !== this.maxFaces && this.runsWithoutFaceDetector >= this.skipFrames;
}
calculateLandmarksBoundingBox(landmarks) {
const xs = landmarks.map((d) => d[0]);
const ys = landmarks.map((d) => d[1]);
const startPoint = [Math.min(...xs), Math.min(...ys)];
const endPoint = [Math.max(...xs), Math.max(...ys)];
return {startPoint, endPoint, landmarks};
}
}
exports.Pipeline = Pipeline;
});
// src/facemesh/uvcoords.js
var require_uvcoords = __commonJS((exports) => {
// UV_COORDS: static table of 468 [u, v] pairs; every value in it lies between 0 and 1.
// NOTE(review): presumably entry i is the texture coordinate of face-mesh landmark i
// (468 matches the mesh landmark count used elsewhere in this bundle) — confirm with consumers.
exports.UV_COORDS = [
[0.499976992607117, 0.652534008026123],
[0.500025987625122, 0.547487020492554],
[0.499974012374878, 0.602371990680695],
[0.482113003730774, 0.471979022026062],
[0.500150978565216, 0.527155995368958],
[0.499909996986389, 0.498252987861633],
[0.499523013830185, 0.40106201171875],
[0.289712011814117, 0.380764007568359],
[0.499954998493195, 0.312398016452789],
[0.499987006187439, 0.269918978214264],
[0.500023007392883, 0.107050001621246],
[0.500023007392883, 0.666234016418457],
[0.5000159740448, 0.679224014282227],
[0.500023007392883, 0.692348003387451],
[0.499976992607117, 0.695277988910675],
[0.499976992607117, 0.70593398809433],
[0.499976992607117, 0.719385027885437],
[0.499976992607117, 0.737019002437592],
[0.499967992305756, 0.781370997428894],
[0.499816000461578, 0.562981009483337],
[0.473773002624512, 0.573909997940063],
[0.104906998574734, 0.254140973091125],
[0.365929991006851, 0.409575998783112],
[0.338757991790771, 0.41302502155304],
[0.311120003461838, 0.409460008144379],
[0.274657994508743, 0.389131009578705],
[0.393361985683441, 0.403706014156342],
[0.345234006643295, 0.344011008739471],
[0.370094001293182, 0.346076011657715],
[0.319321990013123, 0.347265005111694],
[0.297903001308441, 0.353591024875641],
[0.24779200553894, 0.410809993743896],
[0.396889001131058, 0.842755019664764],
[0.280097991228104, 0.375599980354309],
[0.106310002505779, 0.399955987930298],
[0.2099249958992, 0.391353011131287],
[0.355807989835739, 0.534406006336212],
[0.471751004457474, 0.65040397644043],
[0.474155008792877, 0.680191993713379],
[0.439785003662109, 0.657229006290436],
[0.414617002010345, 0.66654098033905],
[0.450374007225037, 0.680860996246338],
[0.428770989179611, 0.682690978050232],
[0.374971002340317, 0.727805018424988],
[0.486716985702515, 0.547628998756409],
[0.485300987958908, 0.527395009994507],
[0.257764995098114, 0.314490020275116],
[0.401223003864288, 0.455172002315521],
[0.429818987846375, 0.548614978790283],
[0.421351999044418, 0.533740997314453],
[0.276895999908447, 0.532056987285614],
[0.483370006084442, 0.499586999416351],
[0.33721199631691, 0.282882988452911],
[0.296391993761063, 0.293242990970612],
[0.169294998049736, 0.193813979625702],
[0.447580009698868, 0.302609980106354],
[0.392390012741089, 0.353887975215912],
[0.354490011930466, 0.696784019470215],
[0.067304998636246, 0.730105042457581],
[0.442739009857178, 0.572826027870178],
[0.457098007202148, 0.584792017936707],
[0.381974011659622, 0.694710969924927],
[0.392388999462128, 0.694203019142151],
[0.277076005935669, 0.271932005882263],
[0.422551989555359, 0.563233017921448],
[0.385919004678726, 0.281364023685455],
[0.383103013038635, 0.255840003490448],
[0.331431001424789, 0.119714021682739],
[0.229923993349075, 0.232002973556519],
[0.364500999450684, 0.189113974571228],
[0.229622006416321, 0.299540996551514],
[0.173287004232407, 0.278747975826263],
[0.472878992557526, 0.666198015213013],
[0.446828007698059, 0.668527007102966],
[0.422762006521225, 0.673889994621277],
[0.445307999849319, 0.580065965652466],
[0.388103008270264, 0.693961024284363],
[0.403039008378983, 0.706539988517761],
[0.403629004955292, 0.693953037261963],
[0.460041999816895, 0.557139039039612],
[0.431158006191254, 0.692366003990173],
[0.452181994915009, 0.692366003990173],
[0.475387006998062, 0.692366003990173],
[0.465828001499176, 0.779190003871918],
[0.472328990697861, 0.736225962638855],
[0.473087012767792, 0.717857003211975],
[0.473122000694275, 0.704625964164734],
[0.473033010959625, 0.695277988910675],
[0.427942007780075, 0.695277988910675],
[0.426479011774063, 0.703539967536926],
[0.423162013292313, 0.711845993995667],
[0.4183090031147, 0.720062971115112],
[0.390094995498657, 0.639572978019714],
[0.013953999616206, 0.560034036636353],
[0.499913990497589, 0.58014702796936],
[0.413199990987778, 0.69539999961853],
[0.409626007080078, 0.701822996139526],
[0.468080013990402, 0.601534962654114],
[0.422728985548019, 0.585985004901886],
[0.463079988956451, 0.593783974647522],
[0.37211999297142, 0.47341400384903],
[0.334562003612518, 0.496073007583618],
[0.411671012639999, 0.546965003013611],
[0.242175996303558, 0.14767599105835],
[0.290776997804642, 0.201445996761322],
[0.327338010072708, 0.256527006626129],
[0.399509996175766, 0.748921036720276],
[0.441727995872498, 0.261676013469696],
[0.429764986038208, 0.187834024429321],
[0.412198007106781, 0.108901023864746],
[0.288955003023148, 0.398952007293701],
[0.218936994671822, 0.435410976409912],
[0.41278201341629, 0.398970007896423],
[0.257135003805161, 0.355440020561218],
[0.427684992551804, 0.437960982322693],
[0.448339998722076, 0.536936044692993],
[0.178560003638268, 0.45755398273468],
[0.247308000922203, 0.457193970680237],
[0.286267012357712, 0.467674970626831],
[0.332827985286713, 0.460712015628815],
[0.368755996227264, 0.447206974029541],
[0.398963987827301, 0.432654976844788],
[0.476410001516342, 0.405806005001068],
[0.189241006970406, 0.523923993110657],
[0.228962004184723, 0.348950982093811],
[0.490725994110107, 0.562400996685028],
[0.404670000076294, 0.485132992267609],
[0.019469000399113, 0.401564002037048],
[0.426243007183075, 0.420431017875671],
[0.396993011236191, 0.548797011375427],
[0.266469985246658, 0.376977026462555],
[0.439121007919312, 0.51895797252655],
[0.032313998788595, 0.644356966018677],
[0.419054001569748, 0.387154996395111],
[0.462783008813858, 0.505746960639954],
[0.238978996872902, 0.779744982719421],
[0.198220998048782, 0.831938028335571],
[0.107550002634525, 0.540755033493042],
[0.183610007166862, 0.740257024765015],
[0.134409993886948, 0.333683013916016],
[0.385764002799988, 0.883153975009918],
[0.490967005491257, 0.579378008842468],
[0.382384985685349, 0.508572995662689],
[0.174399003386497, 0.397670984268188],
[0.318785011768341, 0.39623498916626],
[0.343364000320435, 0.400596976280212],
[0.396100014448166, 0.710216999053955],
[0.187885001301765, 0.588537991046906],
[0.430987000465393, 0.944064974784851],
[0.318993002176285, 0.898285031318665],
[0.266247987747192, 0.869701027870178],
[0.500023007392883, 0.190576016902924],
[0.499976992607117, 0.954452991485596],
[0.366169989109039, 0.398822009563446],
[0.393207013607025, 0.39553701877594],
[0.410373002290726, 0.391080021858215],
[0.194993004202843, 0.342101991176605],
[0.388664990663528, 0.362284004688263],
[0.365961998701096, 0.355970978736877],
[0.343364000320435, 0.355356991291046],
[0.318785011768341, 0.35834002494812],
[0.301414996385574, 0.363156020641327],
[0.058132998645306, 0.319076001644135],
[0.301414996385574, 0.387449026107788],
[0.499987989664078, 0.618434011936188],
[0.415838003158569, 0.624195992946625],
[0.445681989192963, 0.566076993942261],
[0.465844005346298, 0.620640993118286],
[0.49992299079895, 0.351523995399475],
[0.288718998432159, 0.819945991039276],
[0.335278987884521, 0.852819979190826],
[0.440512001514435, 0.902418971061707],
[0.128294005990028, 0.791940987110138],
[0.408771991729736, 0.373893976211548],
[0.455606997013092, 0.451801002025604],
[0.499877005815506, 0.908990025520325],
[0.375436991453171, 0.924192011356354],
[0.11421000212431, 0.615022003650665],
[0.448662012815475, 0.695277988910675],
[0.4480200111866, 0.704632043838501],
[0.447111994028091, 0.715808033943176],
[0.444831997156143, 0.730794012546539],
[0.430011987686157, 0.766808986663818],
[0.406787008047104, 0.685672998428345],
[0.400738000869751, 0.681069016456604],
[0.392399996519089, 0.677703022956848],
[0.367855995893478, 0.663918972015381],
[0.247923001646996, 0.601333022117615],
[0.452769994735718, 0.420849978923798],
[0.43639200925827, 0.359887003898621],
[0.416164010763168, 0.368713974952698],
[0.413385987281799, 0.692366003990173],
[0.228018000721931, 0.683571994304657],
[0.468268007040024, 0.352671027183533],
[0.411361992359161, 0.804327011108398],
[0.499989002943039, 0.469825029373169],
[0.479153990745544, 0.442654013633728],
[0.499974012374878, 0.439637005329132],
[0.432112008333206, 0.493588984012604],
[0.499886006116867, 0.866917014122009],
[0.49991300702095, 0.821729004383087],
[0.456548988819122, 0.819200992584229],
[0.344549000263214, 0.745438992977142],
[0.37890899181366, 0.574010014533997],
[0.374292999505997, 0.780184984207153],
[0.319687992334366, 0.570737957954407],
[0.357154995203018, 0.604269981384277],
[0.295284003019333, 0.621580958366394],
[0.447750002145767, 0.862477004528046],
[0.410986006259918, 0.508723020553589],
[0.31395098567009, 0.775308012962341],
[0.354128003120422, 0.812552988529205],
[0.324548006057739, 0.703992962837219],
[0.189096003770828, 0.646299958229065],
[0.279776990413666, 0.71465802192688],
[0.1338230073452, 0.682700991630554],
[0.336768001317978, 0.644733011722565],
[0.429883986711502, 0.466521978378296],
[0.455527991056442, 0.548622965812683],
[0.437114000320435, 0.558896005153656],
[0.467287987470627, 0.529924988746643],
[0.414712011814117, 0.335219979286194],
[0.37704598903656, 0.322777986526489],
[0.344107985496521, 0.320150971412659],
[0.312875986099243, 0.32233202457428],
[0.283526003360748, 0.333190023899078],
[0.241245999932289, 0.382785975933075],
[0.102986000478268, 0.468762993812561],
[0.267612010240555, 0.424560010433197],
[0.297879010438919, 0.433175981044769],
[0.333433985710144, 0.433878004550934],
[0.366427004337311, 0.426115989685059],
[0.396012008190155, 0.416696012020111],
[0.420121014118195, 0.41022801399231],
[0.007561000064015, 0.480777025222778],
[0.432949006557465, 0.569517970085144],
[0.458638995885849, 0.479089021682739],
[0.473466008901596, 0.545744001865387],
[0.476087987422943, 0.563830018043518],
[0.468472003936768, 0.555056989192963],
[0.433990985155106, 0.582361996173859],
[0.483518004417419, 0.562983989715576],
[0.482482999563217, 0.57784903049469],
[0.42645001411438, 0.389798998832703],
[0.438998997211456, 0.39649498462677],
[0.450067013502121, 0.400434017181396],
[0.289712011814117, 0.368252992630005],
[0.276670008897781, 0.363372981548309],
[0.517862021923065, 0.471948027610779],
[0.710287988185883, 0.380764007568359],
[0.526226997375488, 0.573909997940063],
[0.895093023777008, 0.254140973091125],
[0.634069979190826, 0.409575998783112],
[0.661242008209229, 0.41302502155304],
[0.688880026340485, 0.409460008144379],
[0.725341975688934, 0.389131009578705],
[0.606630027294159, 0.40370500087738],
[0.654766023159027, 0.344011008739471],
[0.629905998706818, 0.346076011657715],
[0.680678009986877, 0.347265005111694],
[0.702096998691559, 0.353591024875641],
[0.75221198797226, 0.410804986953735],
[0.602918028831482, 0.842862963676453],
[0.719901978969574, 0.375599980354309],
[0.893692970275879, 0.399959981441498],
[0.790081977844238, 0.391354024410248],
[0.643998026847839, 0.534487962722778],
[0.528249025344849, 0.65040397644043],
[0.525849997997284, 0.680191040039062],
[0.560214996337891, 0.657229006290436],
[0.585384011268616, 0.66654098033905],
[0.549625992774963, 0.680860996246338],
[0.57122802734375, 0.682691991329193],
[0.624852001667023, 0.72809898853302],
[0.513050019741058, 0.547281980514526],
[0.51509702205658, 0.527251958847046],
[0.742246985435486, 0.314507007598877],
[0.598631024360657, 0.454979002475739],
[0.570338010787964, 0.548575043678284],
[0.578631997108459, 0.533622980117798],
[0.723087012767792, 0.532054007053375],
[0.516445994377136, 0.499638974666595],
[0.662801027297974, 0.282917976379395],
[0.70362401008606, 0.293271005153656],
[0.830704987049103, 0.193813979625702],
[0.552385985851288, 0.302568018436432],
[0.607609987258911, 0.353887975215912],
[0.645429015159607, 0.696707010269165],
[0.932694971561432, 0.730105042457581],
[0.557260990142822, 0.572826027870178],
[0.542901992797852, 0.584792017936707],
[0.6180260181427, 0.694710969924927],
[0.607590973377228, 0.694203019142151],
[0.722943007946014, 0.271963000297546],
[0.577413976192474, 0.563166975975037],
[0.614082992076874, 0.281386971473694],
[0.616907000541687, 0.255886018276215],
[0.668509006500244, 0.119913995265961],
[0.770092010498047, 0.232020974159241],
[0.635536015033722, 0.189248979091644],
[0.77039098739624, 0.299556016921997],
[0.826722025871277, 0.278755009174347],
[0.527121007442474, 0.666198015213013],
[0.553171992301941, 0.668527007102966],
[0.577238023281097, 0.673889994621277],
[0.554691970348358, 0.580065965652466],
[0.611896991729736, 0.693961024284363],
[0.59696102142334, 0.706539988517761],
[0.596370995044708, 0.693953037261963],
[0.539958000183105, 0.557139039039612],
[0.568841993808746, 0.692366003990173],
[0.547818005084991, 0.692366003990173],
[0.52461302280426, 0.692366003990173],
[0.534089982509613, 0.779141008853912],
[0.527670979499817, 0.736225962638855],
[0.526912987232208, 0.717857003211975],
[0.526877999305725, 0.704625964164734],
[0.526966989040375, 0.695277988910675],
[0.572058022022247, 0.695277988910675],
[0.573521018028259, 0.703539967536926],
[0.57683801651001, 0.711845993995667],
[0.581691026687622, 0.720062971115112],
[0.609944999217987, 0.639909982681274],
[0.986046016216278, 0.560034036636353],
[0.5867999792099, 0.69539999961853],
[0.590372025966644, 0.701822996139526],
[0.531915009021759, 0.601536989212036],
[0.577268004417419, 0.585934996604919],
[0.536915004253387, 0.593786001205444],
[0.627542972564697, 0.473352015018463],
[0.665585994720459, 0.495950996875763],
[0.588353991508484, 0.546862006187439],
[0.757824003696442, 0.14767599105835],
[0.709249973297119, 0.201507985591888],
[0.672684013843536, 0.256581008434296],
[0.600408971309662, 0.74900496006012],
[0.55826598405838, 0.261672019958496],
[0.570303976535797, 0.187870979309082],
[0.588165998458862, 0.109044015407562],
[0.711045026779175, 0.398952007293701],
[0.781069993972778, 0.435405015945435],
[0.587247014045715, 0.398931980133057],
[0.742869973182678, 0.355445981025696],
[0.572156012058258, 0.437651991844177],
[0.55186802148819, 0.536570012569427],
[0.821442008018494, 0.457556009292603],
[0.752701997756958, 0.457181990146637],
[0.71375697851181, 0.467626988887787],
[0.66711300611496, 0.460672974586487],
[0.631101012229919, 0.447153985500336],
[0.6008620262146, 0.432473003864288],
[0.523481011390686, 0.405627012252808],
[0.810747981071472, 0.523926019668579],
[0.771045982837677, 0.348959028720856],
[0.509127020835876, 0.562718033790588],
[0.595292985439301, 0.485023975372314],
[0.980530977249146, 0.401564002037048],
[0.573499977588654, 0.420000016689301],
[0.602994978427887, 0.548687994480133],
[0.733529984951019, 0.376977026462555],
[0.560611009597778, 0.519016981124878],
[0.967685997486115, 0.644356966018677],
[0.580985009670258, 0.387160003185272],
[0.537728011608124, 0.505385041236877],
[0.760966002941132, 0.779752969741821],
[0.801778972148895, 0.831938028335571],
[0.892440974712372, 0.54076099395752],
[0.816350996494293, 0.740260004997253],
[0.865594983100891, 0.333687007427216],
[0.614073991775513, 0.883246004581451],
[0.508952975273132, 0.579437971115112],
[0.617941975593567, 0.508316040039062],
[0.825608015060425, 0.397674977779388],
[0.681214988231659, 0.39623498916626],
[0.656635999679565, 0.400596976280212],
[0.603900015354156, 0.710216999053955],
[0.81208598613739, 0.588539004325867],
[0.56801301240921, 0.944564998149872],
[0.681007981300354, 0.898285031318665],
[0.733752012252808, 0.869701027870178],
[0.633830010890961, 0.398822009563446],
[0.606792986392975, 0.39553701877594],
[0.589659988880157, 0.391062021255493],
[0.805015981197357, 0.342108011245728],
[0.611334979534149, 0.362284004688263],
[0.634037971496582, 0.355970978736877],
[0.656635999679565, 0.355356991291046],
[0.681214988231659, 0.35834002494812],
[0.698584973812103, 0.363156020641327],
[0.941866993904114, 0.319076001644135],
[0.698584973812103, 0.387449026107788],
[0.584177017211914, 0.624107003211975],
[0.554318010807037, 0.566076993942261],
[0.534153997898102, 0.62064003944397],
[0.711217999458313, 0.819975018501282],
[0.664629995822906, 0.852871000766754],
[0.559099972248077, 0.902631998062134],
[0.871706008911133, 0.791940987110138],
[0.591234028339386, 0.373893976211548],
[0.544341027736664, 0.451583981513977],
[0.624562978744507, 0.924192011356354],
[0.88577002286911, 0.615028977394104],
[0.551338016986847, 0.695277988910675],
[0.551980018615723, 0.704632043838501],
[0.552887976169586, 0.715808033943176],
[0.555167973041534, 0.730794012546539],
[0.569944024085999, 0.767035007476807],
[0.593203008174896, 0.685675978660583],
[0.599261999130249, 0.681069016456604],
[0.607599973678589, 0.677703022956848],
[0.631937980651855, 0.663500010967255],
[0.752032995223999, 0.601315021514893],
[0.547226011753082, 0.420395016670227],
[0.563543975353241, 0.359827995300293],
[0.583841025829315, 0.368713974952698],
[0.586614012718201, 0.692366003990173],
[0.771915018558502, 0.683578014373779],
[0.531597018241882, 0.352482974529266],
[0.588370978832245, 0.804440975189209],
[0.52079701423645, 0.442565023899078],
[0.567984998226166, 0.493479013442993],
[0.543282985687256, 0.819254994392395],
[0.655317008495331, 0.745514988899231],
[0.621008992195129, 0.574018001556396],
[0.625559985637665, 0.78031200170517],
[0.680198013782501, 0.570719003677368],
[0.64276397228241, 0.604337990283966],
[0.704662978649139, 0.621529996395111],
[0.552012026309967, 0.862591981887817],
[0.589071989059448, 0.508637011051178],
[0.685944974422455, 0.775357007980347],
[0.645735025405884, 0.812640011310577],
[0.675342977046967, 0.703978002071381],
[0.810858011245728, 0.646304965019226],
[0.72012197971344, 0.714666962623596],
[0.866151988506317, 0.682704985141754],
[0.663187026977539, 0.644596993923187],
[0.570082008838654, 0.466325998306274],
[0.544561982154846, 0.548375964164734],
[0.562758982181549, 0.558784961700439],
[0.531987011432648, 0.530140042304993],
[0.585271000862122, 0.335177004337311],
[0.622952997684479, 0.32277899980545],
[0.655896008014679, 0.320163011550903],
[0.687132000923157, 0.322345972061157],
[0.716481983661652, 0.333200991153717],
[0.758756995201111, 0.382786989212036],
[0.897013008594513, 0.468769013881683],
[0.732392013072968, 0.424547016620636],
[0.70211398601532, 0.433162987232208],
[0.66652500629425, 0.433866024017334],
[0.633504986763, 0.426087975502014],
[0.603875994682312, 0.416586995124817],
[0.579657971858978, 0.409945011138916],
[0.992439985275269, 0.480777025222778],
[0.567192018032074, 0.569419980049133],
[0.54136598110199, 0.478899002075195],
[0.526564002037048, 0.546118021011353],
[0.523913025856018, 0.563830018043518],
[0.531529009342194, 0.555056989192963],
[0.566035985946655, 0.582329034805298],
[0.51631098985672, 0.563053965568542],
[0.5174720287323, 0.577877044677734],
[0.573594987392426, 0.389806985855103],
[0.560697972774506, 0.395331978797913],
[0.549755990505219, 0.399751007556915],
[0.710287988185883, 0.368252992630005],
[0.723330020904541, 0.363372981548309]
];
});
// src/facemesh/triangulation.js
var require_triangulation = __commonJS((exports) => {
__export(exports, {
default: () => triangulation_default
});
var triangulation_default = [
127,
34,
139,
11,
0,
37,
232,
231,
120,
72,
37,
39,
128,
121,
47,
232,
121,
128,
104,
69,
67,
175,
171,
148,
157,
154,
155,
118,
50,
101,
73,
39,
40,
9,
151,
108,
48,
115,
131,
194,
204,
211,
74,
40,
185,
80,
42,
183,
40,
92,
186,
230,
229,
118,
202,
212,
214,
83,
18,
17,
76,
61,
146,
160,
29,
30,
56,
157,
173,
106,
204,
194,
135,
214,
192,
203,
165,
98,
21,
71,
68,
51,
45,
4,
144,
24,
23,
77,
146,
91,
205,
50,
187,
201,
200,
18,
91,
106,
182,
90,
91,
181,
85,
84,
17,
206,
203,
36,
148,
171,
140,
92,
40,
39,
193,
189,
244,
159,
158,
28,
247,
246,
161,
236,
3,
196,
54,
68,
104,
193,
168,
8,
117,
228,
31,
189,
193,
55,
98,
97,
99,
126,
47,
100,
166,
79,
218,
155,
154,
26,
209,
49,
131,
135,
136,
150,
47,
126,
217,
223,
52,
53,
45,
51,
134,
211,
170,
140,
67,
69,
108,
43,
106,
91,
230,
119,
120,
226,
130,
247,
63,
53,
52,
238,
20,
242,
46,
70,
156,
78,
62,
96,
46,
53,
63,
143,
34,
227,
173,
155,
133,
123,
117,
111,
44,
125,
19,
236,
134,
51,
216,
206,
205,
154,
153,
22,
39,
37,
167,
200,
201,
208,
36,
142,
100,
57,
212,
202,
20,
60,
99,
28,
158,
157,
35,
226,
113,
160,
159,
27,
204,
202,
210,
113,
225,
46,
43,
202,
204,
62,
76,
77,
137,
123,
116,
41,
38,
72,
203,
129,
142,
64,
98,
240,
49,
102,
64,
41,
73,
74,
212,
216,
207,
42,
74,
184,
169,
170,
211,
170,
149,
176,
105,
66,
69,
122,
6,
168,
123,
147,
187,
96,
77,
90,
65,
55,
107,
89,
90,
180,
101,
100,
120,
63,
105,
104,
93,
137,
227,
15,
86,
85,
129,
102,
49,
14,
87,
86,
55,
8,
9,
100,
47,
121,
145,
23,
22,
88,
89,
179,
6,
122,
196,
88,
95,
96,
138,
172,
136,
215,
58,
172,
115,
48,
219,
42,
80,
81,
195,
3,
51,
43,
146,
61,
171,
175,
199,
81,
82,
38,
53,
46,
225,
144,
163,
110,
246,
33,
7,
52,
65,
66,
229,
228,
117,
34,
127,
234,
107,
108,
69,
109,
108,
151,
48,
64,
235,
62,
78,
191,
129,
209,
126,
111,
35,
143,
163,
161,
246,
117,
123,
50,
222,
65,
52,
19,
125,
141,
221,
55,
65,
3,
195,
197,
25,
7,
33,
220,
237,
44,
70,
71,
139,
122,
193,
245,
247,
130,
33,
71,
21,
162,
153,
158,
159,
170,
169,
150,
188,
174,
196,
216,
186,
92,
144,
160,
161,
2,
97,
167,
141,
125,
241,
164,
167,
37,
72,
38,
12,
145,
159,
160,
38,
82,
13,
63,
68,
71,
226,
35,
111,
158,
153,
154,
101,
50,
205,
206,
92,
165,
209,
198,
217,
165,
167,
97,
220,
115,
218,
133,
112,
243,
239,
238,
241,
214,
135,
169,
190,
173,
133,
171,
208,
32,
125,
44,
237,
86,
87,
178,
85,
86,
179,
84,
85,
180,
83,
84,
181,
201,
83,
182,
137,
93,
132,
76,
62,
183,
61,
76,
184,
57,
61,
185,
212,
57,
186,
214,
207,
187,
34,
143,
156,
79,
239,
237,
123,
137,
177,
44,
1,
4,
201,
194,
32,
64,
102,
129,
213,
215,
138,
59,
166,
219,
242,
99,
97,
2,
94,
141,
75,
59,
235,
24,
110,
228,
25,
130,
226,
23,
24,
229,
22,
23,
230,
26,
22,
231,
112,
26,
232,
189,
190,
243,
221,
56,
190,
28,
56,
221,
27,
28,
222,
29,
27,
223,
30,
29,
224,
247,
30,
225,
238,
79,
20,
166,
59,
75,
60,
75,
240,
147,
177,
215,
20,
79,
166,
187,
147,
213,
112,
233,
244,
233,
128,
245,
128,
114,
188,
114,
217,
174,
131,
115,
220,
217,
198,
236,
198,
131,
134,
177,
132,
58,
143,
35,
124,
110,
163,
7,
228,
110,
25,
356,
389,
368,
11,
302,
267,
452,
350,
349,
302,
303,
269,
357,
343,
277,
452,
453,
357,
333,
332,
297,
175,
152,
377,
384,
398,
382,
347,
348,
330,
303,
304,
270,
9,
336,
337,
278,
279,
360,
418,
262,
431,
304,
408,
409,
310,
415,
407,
270,
409,
410,
450,
348,
347,
422,
430,
434,
313,
314,
17,
306,
307,
375,
387,
388,
260,
286,
414,
398,
335,
406,
418,
364,
367,
416,
423,
358,
327,
251,
284,
298,
281,
5,
4,
373,
374,
253,
307,
320,
321,
425,
427,
411,
421,
313,
18,
321,
405,
406,
320,
404,
405,
315,
16,
17,
426,
425,
266,
377,
400,
369,
322,
391,
269,
417,
465,
464,
386,
257,
258,
466,
260,
388,
456,
399,
419,
284,
332,
333,
417,
285,
8,
346,
340,
261,
413,
441,
285,
327,
460,
328,
355,
371,
329,
392,
439,
438,
382,
341,
256,
429,
420,
360,
364,
394,
379,
277,
343,
437,
443,
444,
283,
275,
440,
363,
431,
262,
369,
297,
338,
337,
273,
375,
321,
450,
451,
349,
446,
342,
467,
293,
334,
282,
458,
461,
462,
276,
353,
383,
308,
324,
325,
276,
300,
293,
372,
345,
447,
382,
398,
362,
352,
345,
340,
274,
1,
19,
456,
248,
281,
436,
427,
425,
381,
256,
252,
269,
391,
393,
200,
199,
428,
266,
330,
329,
287,
273,
422,
250,
462,
328,
258,
286,
384,
265,
353,
342,
387,
259,
257,
424,
431,
430,
342,
353,
276,
273,
335,
424,
292,
325,
307,
366,
447,
345,
271,
303,
302,
423,
266,
371,
294,
455,
460,
279,
278,
294,
271,
272,
304,
432,
434,
427,
272,
407,
408,
394,
430,
431,
395,
369,
400,
334,
333,
299,
351,
417,
168,
352,
280,
411,
325,
319,
320,
295,
296,
336,
319,
403,
404,
330,
348,
349,
293,
298,
333,
323,
454,
447,
15,
16,
315,
358,
429,
279,
14,
15,
316,
285,
336,
9,
329,
349,
350,
374,
380,
252,
318,
402,
403,
6,
197,
419,
318,
319,
325,
367,
364,
365,
435,
367,
397,
344,
438,
439,
272,
271,
311,
195,
5,
281,
273,
287,
291,
396,
428,
199,
311,
271,
268,
283,
444,
445,
373,
254,
339,
263,
466,
249,
282,
334,
296,
449,
347,
346,
264,
447,
454,
336,
296,
299,
338,
10,
151,
278,
439,
455,
292,
407,
415,
358,
371,
355,
340,
345,
372,
390,
249,
466,
346,
347,
280,
442,
443,
282,
19,
94,
370,
441,
442,
295,
248,
419,
197,
263,
255,
359,
440,
275,
274,
300,
383,
368,
351,
412,
465,
263,
467,
466,
301,
368,
389,
380,
374,
386,
395,
378,
379,
412,
351,
419,
436,
426,
322,
373,
390,
388,
2,
164,
393,
370,
462,
461,
164,
0,
267,
302,
11,
12,
374,
373,
387,
268,
12,
13,
293,
300,
301,
446,
261,
340,
385,
384,
381,
330,
266,
425,
426,
423,
391,
429,
355,
437,
391,
327,
326,
440,
457,
438,
341,
382,
362,
459,
457,
461,
434,
430,
394,
414,
463,
362,
396,
369,
262,
354,
461,
457,
316,
403,
402,
315,
404,
403,
314,
405,
404,
313,
406,
405,
421,
418,
406,
366,
401,
361,
306,
408,
407,
291,
409,
408,
287,
410,
409,
432,
436,
410,
434,
416,
411,
264,
368,
383,
309,
438,
457,
352,
376,
401,
274,
275,
4,
421,
428,
262,
294,
327,
358,
433,
416,
367,
289,
455,
439,
462,
370,
326,
2,
326,
370,
305,
460,
455,
254,
449,
448,
255,
261,
446,
253,
450,
449,
252,
451,
450,
256,
452,
451,
341,
453,
452,
413,
464,
463,
441,
413,
414,
258,
442,
441,
257,
443,
442,
259,
444,
443,
260,
445,
444,
467,
342,
445,
459,
458,
250,
289,
392,
290,
290,
328,
460,
376,
433,
435,
250,
290,
392,
411,
416,
433,
341,
463,
464,
453,
464,
465,
357,
465,
412,
343,
412,
399,
360,
363,
440,
437,
399,
456,
420,
456,
363,
401,
435,
288,
372,
383,
353,
339,
255,
249,
448,
261,
255,
133,
243,
190,
133,
155,
112,
33,
246,
247,
33,
130,
25,
398,
384,
286,
362,
398,
414,
362,
463,
341,
263,
359,
467,
263,
249,
255,
466,
467,
260,
75,
60,
166,
238,
239,
79,
162,
127,
139,
72,
11,
37,
121,
232,
120,
73,
72,
39,
114,
128,
47,
233,
232,
128,
103,
104,
67,
152,
175,
148,
173,
157,
155,
119,
118,
101,
74,
73,
40,
107,
9,
108,
49,
48,
131,
32,
194,
211,
184,
74,
185,
191,
80,
183,
185,
40,
186,
119,
230,
118,
210,
202,
214,
84,
83,
17,
77,
76,
146,
161,
160,
30,
190,
56,
173,
182,
106,
194,
138,
135,
192,
129,
203,
98,
54,
21,
68,
5,
51,
4,
145,
144,
23,
90,
77,
91,
207,
205,
187,
83,
201,
18,
181,
91,
182,
180,
90,
181,
16,
85,
17,
205,
206,
36,
176,
148,
140,
165,
92,
39,
245,
193,
244,
27,
159,
28,
30,
247,
161,
174,
236,
196,
103,
54,
104,
55,
193,
8,
111,
117,
31,
221,
189,
55,
240,
98,
99,
142,
126,
100,
219,
166,
218,
112,
155,
26,
198,
209,
131,
169,
135,
150,
114,
47,
217,
224,
223,
53,
220,
45,
134,
32,
211,
140,
109,
67,
108,
146,
43,
91,
231,
230,
120,
113,
226,
247,
105,
63,
52,
241,
238,
242,
124,
46,
156,
95,
78,
96,
70,
46,
63,
116,
143,
227,
116,
123,
111,
1,
44,
19,
3,
236,
51,
207,
216,
205,
26,
154,
22,
165,
39,
167,
199,
200,
208,
101,
36,
100,
43,
57,
202,
242,
20,
99,
56,
28,
157,
124,
35,
113,
29,
160,
27,
211,
204,
210,
124,
113,
46,
106,
43,
204,
96,
62,
77,
227,
137,
116,
73,
41,
72,
36,
203,
142,
235,
64,
240,
48,
49,
64,
42,
41,
74,
214,
212,
207,
183,
42,
184,
210,
169,
211,
140,
170,
176,
104,
105,
69,
193,
122,
168,
50,
123,
187,
89,
96,
90,
66,
65,
107,
179,
89,
180,
119,
101,
120,
68,
63,
104,
234,
93,
227,
16,
15,
85,
209,
129,
49,
15,
14,
86,
107,
55,
9,
120,
100,
121,
153,
145,
22,
178,
88,
179,
197,
6,
196,
89,
88,
96,
135,
138,
136,
138,
215,
172,
218,
115,
219,
41,
42,
81,
5,
195,
51,
57,
43,
61,
208,
171,
199,
41,
81,
38,
224,
53,
225,
24,
144,
110,
105,
52,
66,
118,
229,
117,
227,
34,
234,
66,
107,
69,
10,
109,
151,
219,
48,
235,
183,
62,
191,
142,
129,
126,
116,
111,
143,
7,
163,
246,
118,
117,
50,
223,
222,
52,
94,
19,
141,
222,
221,
65,
196,
3,
197,
45,
220,
44,
156,
70,
139,
188,
122,
245,
139,
71,
162,
145,
153,
159,
149,
170,
150,
122,
188,
196,
206,
216,
92,
163,
144,
161,
164,
2,
167,
242,
141,
241,
0,
164,
37,
11,
72,
12,
144,
145,
160,
12,
38,
13,
70,
63,
71,
31,
226,
111,
157,
158,
154,
36,
101,
205,
203,
206,
165,
126,
209,
217,
98,
165,
97,
237,
220,
218,
237,
239,
241,
210,
214,
169,
140,
171,
32,
241,
125,
237,
179,
86,
178,
180,
85,
179,
181,
84,
180,
182,
83,
181,
194,
201,
182,
177,
137,
132,
184,
76,
183,
185,
61,
184,
186,
57,
185,
216,
212,
186,
192,
214,
187,
139,
34,
156,
218,
79,
237,
147,
123,
177,
45,
44,
4,
208,
201,
32,
98,
64,
129,
192,
213,
138,
235,
59,
219,
141,
242,
97,
97,
2,
141,
240,
75,
235,
229,
24,
228,
31,
25,
226,
230,
23,
229,
231,
22,
230,
232,
26,
231,
233,
112,
232,
244,
189,
243,
189,
221,
190,
222,
28,
221,
223,
27,
222,
224,
29,
223,
225,
30,
224,
113,
247,
225,
99,
60,
240,
213,
147,
215,
60,
20,
166,
192,
187,
213,
243,
112,
244,
244,
233,
245,
245,
128,
188,
188,
114,
174,
134,
131,
220,
174,
217,
236,
236,
198,
134,
215,
177,
58,
156,
143,
124,
25,
110,
7,
31,
228,
25,
264,
356,
368,
0,
11,
267,
451,
452,
349,
267,
302,
269,
350,
357,
277,
350,
452,
357,
299,
333,
297,
396,
175,
377,
381,
384,
382,
280,
347,
330,
269,
303,
270,
151,
9,
337,
344,
278,
360,
424,
418,
431,
270,
304,
409,
272,
310,
407,
322,
270,
410,
449,
450,
347,
432,
422,
434,
18,
313,
17,
291,
306,
375,
259,
387,
260,
424,
335,
418,
434,
364,
416,
391,
423,
327,
301,
251,
298,
275,
281,
4,
254,
373,
253,
375,
307,
321,
280,
425,
411,
200,
421,
18,
335,
321,
406,
321,
320,
405,
314,
315,
17,
423,
426,
266,
396,
377,
369,
270,
322,
269,
413,
417,
464,
385,
386,
258,
248,
456,
419,
298,
284,
333,
168,
417,
8,
448,
346,
261,
417,
413,
285,
326,
327,
328,
277,
355,
329,
309,
392,
438,
381,
382,
256,
279,
429,
360,
365,
364,
379,
355,
277,
437,
282,
443,
283,
281,
275,
363,
395,
431,
369,
299,
297,
337,
335,
273,
321,
348,
450,
349,
359,
446,
467,
283,
293,
282,
250,
458,
462,
300,
276,
383,
292,
308,
325,
283,
276,
293,
264,
372,
447,
346,
352,
340,
354,
274,
19,
363,
456,
281,
426,
436,
425,
380,
381,
252,
267,
269,
393,
421,
200,
428,
371,
266,
329,
432,
287,
422,
290,
250,
328,
385,
258,
384,
446,
265,
342,
386,
387,
257,
422,
424,
430,
445,
342,
276,
422,
273,
424,
306,
292,
307,
352,
366,
345,
268,
271,
302,
358,
423,
371,
327,
294,
460,
331,
279,
294,
303,
271,
304,
436,
432,
427,
304,
272,
408,
395,
394,
431,
378,
395,
400,
296,
334,
299,
6,
351,
168,
376,
352,
411,
307,
325,
320,
285,
295,
336,
320,
319,
404,
329,
330,
349,
334,
293,
333,
366,
323,
447,
316,
15,
315,
331,
358,
279,
317,
14,
316,
8,
285,
9,
277,
329,
350,
253,
374,
252,
319,
318,
403,
351,
6,
419,
324,
318,
325,
397,
367,
365,
288,
435,
397,
278,
344,
439,
310,
272,
311,
248,
195,
281,
375,
273,
291,
175,
396,
199,
312,
311,
268,
276,
283,
445,
390,
373,
339,
295,
282,
296,
448,
449,
346,
356,
264,
454,
337,
336,
299,
337,
338,
151,
294,
278,
455,
308,
292,
415,
429,
358,
355,
265,
340,
372,
388,
390,
466,
352,
346,
280,
295,
442,
282,
354,
19,
370,
285,
441,
295,
195,
248,
197,
457,
440,
274,
301,
300,
368,
417,
351,
465,
251,
301,
389,
385,
380,
386,
394,
395,
379,
399,
412,
419,
410,
436,
322,
387,
373,
388,
326,
2,
393,
354,
370,
461,
393,
164,
267,
268,
302,
12,
386,
374,
387,
312,
268,
13,
298,
293,
301,
265,
446,
340,
380,
385,
381,
280,
330,
425,
322,
426,
391,
420,
429,
437,
393,
391,
326,
344,
440,
438,
458,
459,
461,
364,
434,
394,
428,
396,
262,
274,
354,
457,
317,
316,
402,
316,
315,
403,
315,
314,
404,
314,
313,
405,
313,
421,
406,
323,
366,
361,
292,
306,
407,
306,
291,
408,
291,
287,
409,
287,
432,
410,
427,
434,
411,
372,
264,
383,
459,
309,
457,
366,
352,
401,
1,
274,
4,
418,
421,
262,
331,
294,
358,
435,
433,
367,
392,
289,
439,
328,
462,
326,
94,
2,
370,
289,
305,
455,
339,
254,
448,
359,
255,
446,
254,
253,
449,
253,
252,
450,
252,
256,
451,
256,
341,
452,
414,
413,
463,
286,
441,
414,
286,
258,
441,
258,
257,
442,
257,
259,
443,
259,
260,
444,
260,
467,
445,
309,
459,
250,
305,
289,
290,
305,
290,
460,
401,
376,
435,
309,
250,
392,
376,
411,
433,
453,
341,
464,
357,
453,
465,
343,
357,
412,
437,
343,
399,
344,
360,
440,
420,
437,
456,
360,
420,
363,
361,
401,
288,
265,
372,
353,
390,
339,
249,
339,
448,
255
];
});
// src/facemesh/facemesh.js
var require_facemesh = __commonJS((exports) => {
const tf = require_tf_node();
const blazeface = require_blazeface();
const keypoints = require_keypoints();
const pipe = require_pipeline();
const uv_coords = require_uvcoords();
const triangulation = require_triangulation().default;
class MediaPipeFaceMesh {
constructor(blazeFace, blazeMeshModel, irisModel, config) {
this.pipeline = new pipe.Pipeline(blazeFace, blazeMeshModel, irisModel, config);
if (config)
this.config = config;
}
async estimateFaces(input, config) {
if (config)
this.config = config;
const predictions = await this.pipeline.predict(input, config);
const results = [];
for (const prediction of predictions || []) {
if (prediction.isDisposedInternal)
continue;
const confidence = prediction.confidence.arraySync();
if (confidence >= this.config.detector.minConfidence) {
const mesh = prediction.coords ? prediction.coords.arraySync() : null;
const annotations = {};
if (mesh && mesh.length > 0) {
for (const key in keypoints.MESH_ANNOTATIONS) {
if (this.config.iris.enabled || key.includes("Iris") === false) {
annotations[key] = keypoints.MESH_ANNOTATIONS[key].map((index) => mesh[index]);
}
}
}
results.push({
confidence: confidence || 0,
box: prediction.box ? [prediction.box.startPoint[0], prediction.box.startPoint[1], prediction.box.endPoint[0] - prediction.box.startPoint[0], prediction.box.endPoint[1] - prediction.box.startPoint[1]] : 0,
mesh,
annotations,
image: prediction.image ? tf.clone(prediction.image) : null
});
}
if (prediction.confidence)
prediction.confidence.dispose();
if (prediction.coords)
prediction.coords.dispose();
if (prediction.image)
prediction.image.dispose();
}
return results;
}
}
async function load(config) {
const models = await Promise.all([
blazeface.load(config),
tf.loadGraphModel(config.mesh.modelPath, {fromTFHub: config.mesh.modelPath.includes("tfhub.dev")}),
tf.loadGraphModel(config.iris.modelPath, {fromTFHub: config.iris.modelPath.includes("tfhub.dev")})
]);
const faceMesh = new MediaPipeFaceMesh(models[0], models[1], models[2], config);
return faceMesh;
}
exports.load = load;
exports.MediaPipeFaceMesh = MediaPipeFaceMesh;
exports.uv_coords = uv_coords;
exports.triangulation = triangulation;
});
// src/ssrnet/ssrnet.js
var require_ssrnet = __commonJS((exports) => {
const tf = require_tf_node();
const models = {};
let last = {age: 0, gender: ""};
let frame = 0;
async function loadAge(config) {
if (!models.age)
models.age = await tf.loadGraphModel(config.face.age.modelPath);
return models.age;
}
async function loadGender(config) {
if (!models.gender)
models.gender = await tf.loadGraphModel(config.face.gender.modelPath);
return models.gender;
}
async function predict(image, config) {
if (frame < config.face.age.skipFrames) {
frame += 1;
return last;
}
frame = 0;
const resize = tf.image.resizeBilinear(image, [config.face.age.inputSize, config.face.age.inputSize], false);
const enhance = tf.mul(resize, [255]);
tf.dispose(resize);
const promises = [];
let ageT;
let genderT;
if (config.face.age.enabled)
promises.push(ageT = models.age.predict(enhance));
if (config.face.gender.enabled)
promises.push(genderT = models.gender.predict(enhance));
await Promise.all(promises);
const obj = {};
if (ageT) {
const data = await ageT.data();
obj.age = Math.trunc(10 * data[0]) / 10;
tf.dispose(ageT);
}
if (genderT) {
const data = await genderT.data();
const confidence = Math.trunc(Math.abs(1.9 * 100 * (data[0] - 0.5))) / 100;
if (confidence > config.face.gender.minConfidence) {
obj.gender = data[0] <= 0.5 ? "female" : "male";
obj.confidence = confidence;
}
tf.dispose(genderT);
}
tf.dispose(enhance);
last = obj;
return obj;
}
exports.predict = predict;
exports.loadAge = loadAge;
exports.loadGender = loadGender;
});
// src/emotion/emotion.js
var require_emotion = __commonJS((exports) => {
const tf = require_tf_node();
const annotations = ["angry", "discust", "fear", "happy", "sad", "surpise", "neutral"];
const models = {};
let last = [];
let frame = 0;
const multiplier = 1.5;
async function load(config) {
if (!models.emotion)
models.emotion = await tf.loadGraphModel(config.face.emotion.modelPath);
return models.emotion;
}
async function predict(image, config) {
if (frame < config.face.emotion.skipFrames) {
frame += 1;
return last;
}
frame = 0;
const resize = tf.image.resizeBilinear(image, [config.face.emotion.inputSize, config.face.emotion.inputSize], false);
const [red, green, blue] = tf.split(resize, 3, 3);
resize.dispose();
const redNorm = tf.mul(red, [0.2989]);
const greenNorm = tf.mul(green, [0.587]);
const blueNorm = tf.mul(blue, [0.114]);
red.dispose();
green.dispose();
blue.dispose();
const grayscale = tf.addN([redNorm, greenNorm, blueNorm]);
redNorm.dispose();
greenNorm.dispose();
blueNorm.dispose();
const obj = [];
if (config.face.emotion.enabled) {
const emotionT = await models.emotion.predict(grayscale);
const data = await emotionT.data();
for (let i = 0; i < data.length; i++) {
if (multiplier * data[i] > config.face.emotion.minConfidence)
obj.push({score: Math.min(0.99, Math.trunc(100 * multiplier * data[i]) / 100), emotion: annotations[i]});
}
obj.sort((a, b) => b.score - a.score);
tf.dispose(emotionT);
}
tf.dispose(grayscale);
last = obj;
return obj;
}
exports.predict = predict;
exports.load = load;
});
// src/posenet/modelBase.js
var require_modelBase = __commonJS((exports) => {
const tf = require_tf_node();
class BaseModel {
constructor(model, outputStride) {
this.model = model;
this.outputStride = outputStride;
const inputShape = this.model.inputs[0].shape;
tf.util.assert(inputShape[1] === -1 && inputShape[2] === -1, () => `Input shape [${inputShape[1]}, ${inputShape[2]}] must both be equal to or -1`);
}
predict(input) {
return tf.tidy(() => {
const asFloat = this.preprocessInput(input.toFloat());
const asBatch = asFloat.expandDims(0);
const results = this.model.predict(asBatch);
const results3d = results.map((y) => y.squeeze([0]));
const namedResults = this.nameOutputResults(results3d);
return {
heatmapScores: namedResults.heatmap.sigmoid(),
offsets: namedResults.offsets,
displacementFwd: namedResults.displacementFwd,
displacementBwd: namedResults.displacementBwd
};
});
}
dispose() {
this.model.dispose();
}
}
exports.BaseModel = BaseModel;
});
// src/posenet/modelMobileNet.js
var require_modelMobileNet = __commonJS((exports) => {
const tf = require_tf_node();
const modelBase = require_modelBase();
class MobileNet extends modelBase.BaseModel {
preprocessInput(input) {
return tf.tidy(() => tf.div(input, 127.5).sub(1));
}
nameOutputResults(results) {
const [offsets, heatmap, displacementFwd, displacementBwd] = results;
return {offsets, heatmap, displacementFwd, displacementBwd};
}
}
exports.MobileNet = MobileNet;
});
// src/posenet/heapSort.js
var require_heapSort = __commonJS((exports) => {
function half(k) {
return Math.floor(k / 2);
}
class MaxHeap {
constructor(maxSize, getElementValue) {
this.priorityQueue = new Array(maxSize);
this.numberOfElements = -1;
this.getElementValue = getElementValue;
}
enqueue(x) {
this.priorityQueue[++this.numberOfElements] = x;
this.swim(this.numberOfElements);
}
dequeue() {
const max = this.priorityQueue[0];
this.exchange(0, this.numberOfElements--);
this.sink(0);
this.priorityQueue[this.numberOfElements + 1] = null;
return max;
}
empty() {
return this.numberOfElements === -1;
}
size() {
return this.numberOfElements + 1;
}
all() {
return this.priorityQueue.slice(0, this.numberOfElements + 1);
}
max() {
return this.priorityQueue[0];
}
swim(k) {
while (k > 0 && this.less(half(k), k)) {
this.exchange(k, half(k));
k = half(k);
}
}
sink(k) {
while (2 * k <= this.numberOfElements) {
let j = 2 * k;
if (j < this.numberOfElements && this.less(j, j + 1))
j++;
if (!this.less(k, j))
break;
this.exchange(k, j);
k = j;
}
}
getValueAt(i) {
return this.getElementValue(this.priorityQueue[i]);
}
less(i, j) {
return this.getValueAt(i) < this.getValueAt(j);
}
exchange(i, j) {
const t = this.priorityQueue[i];
this.priorityQueue[i] = this.priorityQueue[j];
this.priorityQueue[j] = t;
}
}
exports.MaxHeap = MaxHeap;
});
// src/posenet/buildParts.js
var require_buildParts = __commonJS((exports) => {
const heapSort = require_heapSort();
function scoreIsMaximumInLocalWindow(keypointId, score, heatmapY, heatmapX, localMaximumRadius, scores) {
const [height, width] = scores.shape;
let localMaximum = true;
const yStart = Math.max(heatmapY - localMaximumRadius, 0);
const yEnd = Math.min(heatmapY + localMaximumRadius + 1, height);
for (let yCurrent = yStart; yCurrent < yEnd; ++yCurrent) {
const xStart = Math.max(heatmapX - localMaximumRadius, 0);
const xEnd = Math.min(heatmapX + localMaximumRadius + 1, width);
for (let xCurrent = xStart; xCurrent < xEnd; ++xCurrent) {
if (scores.get(yCurrent, xCurrent, keypointId) > score) {
localMaximum = false;
break;
}
}
if (!localMaximum) {
break;
}
}
return localMaximum;
}
function buildPartWithScoreQueue(scoreThreshold, localMaximumRadius, scores) {
const [height, width, numKeypoints] = scores.shape;
const queue = new heapSort.MaxHeap(height * width * numKeypoints, ({score}) => score);
for (let heatmapY = 0; heatmapY < height; ++heatmapY) {
for (let heatmapX = 0; heatmapX < width; ++heatmapX) {
for (let keypointId = 0; keypointId < numKeypoints; ++keypointId) {
const score = scores.get(heatmapY, heatmapX, keypointId);
if (score < scoreThreshold)
continue;
if (scoreIsMaximumInLocalWindow(keypointId, score, heatmapY, heatmapX, localMaximumRadius, scores)) {
queue.enqueue({score, part: {heatmapY, heatmapX, id: keypointId}});
}
}
}
}
return queue;
}
exports.buildPartWithScoreQueue = buildPartWithScoreQueue;
});
// src/posenet/keypoints.js
var require_keypoints2 = __commonJS((exports) => {
exports.partNames = [
"nose",
"leftEye",
"rightEye",
"leftEar",
"rightEar",
"leftShoulder",
"rightShoulder",
"leftElbow",
"rightElbow",
"leftWrist",
"rightWrist",
"leftHip",
"rightHip",
"leftKnee",
"rightKnee",
"leftAnkle",
"rightAnkle"
];
exports.NUM_KEYPOINTS = exports.partNames.length;
exports.partIds = exports.partNames.reduce((result, jointName, i) => {
result[jointName] = i;
return result;
}, {});
const connectedPartNames = [
["leftHip", "leftShoulder"],
["leftElbow", "leftShoulder"],
["leftElbow", "leftWrist"],
["leftHip", "leftKnee"],
["leftKnee", "leftAnkle"],
["rightHip", "rightShoulder"],
["rightElbow", "rightShoulder"],
["rightElbow", "rightWrist"],
["rightHip", "rightKnee"],
["rightKnee", "rightAnkle"],
["leftShoulder", "rightShoulder"],
["leftHip", "rightHip"]
];
exports.poseChain = [
["nose", "leftEye"],
["leftEye", "leftEar"],
["nose", "rightEye"],
["rightEye", "rightEar"],
["nose", "leftShoulder"],
["leftShoulder", "leftElbow"],
["leftElbow", "leftWrist"],
["leftShoulder", "leftHip"],
["leftHip", "leftKnee"],
["leftKnee", "leftAnkle"],
["nose", "rightShoulder"],
["rightShoulder", "rightElbow"],
["rightElbow", "rightWrist"],
["rightShoulder", "rightHip"],
["rightHip", "rightKnee"],
["rightKnee", "rightAnkle"]
];
exports.connectedPartIndices = connectedPartNames.map(([jointNameA, jointNameB]) => [exports.partIds[jointNameA], exports.partIds[jointNameB]]);
exports.partChannels = [
"left_face",
"right_face",
"right_upper_leg_front",
"right_lower_leg_back",
"right_upper_leg_back",
"left_lower_leg_front",
"left_upper_leg_front",
"left_upper_leg_back",
"left_lower_leg_back",
"right_feet",
"right_lower_leg_front",
"left_feet",
"torso_front",
"torso_back",
"right_upper_arm_front",
"right_upper_arm_back",
"right_lower_arm_back",
"left_lower_arm_front",
"left_upper_arm_front",
"left_upper_arm_back",
"left_lower_arm_back",
"right_hand",
"right_lower_arm_front",
"left_hand"
];
});
// src/posenet/vectors.js
var require_vectors = __commonJS((exports) => {
const kpt = require_keypoints2();
function getOffsetPoint(y, x, keypoint, offsets) {
return {
y: offsets.get(y, x, keypoint),
x: offsets.get(y, x, keypoint + kpt.NUM_KEYPOINTS)
};
}
exports.getOffsetPoint = getOffsetPoint;
function getImageCoords(part, outputStride, offsets) {
const {heatmapY, heatmapX, id: keypoint} = part;
const {y, x} = getOffsetPoint(heatmapY, heatmapX, keypoint, offsets);
return {
x: part.heatmapX * outputStride + x,
y: part.heatmapY * outputStride + y
};
}
exports.getImageCoords = getImageCoords;
function fillArray(element, size) {
const result = new Array(size);
for (let i = 0; i < size; i++) {
result[i] = element;
}
return result;
}
exports.fillArray = fillArray;
function clamp(a, min, max) {
if (a < min)
return min;
if (a > max)
return max;
return a;
}
exports.clamp = clamp;
function squaredDistance(y1, x1, y2, x2) {
const dy = y2 - y1;
const dx = x2 - x1;
return dy * dy + dx * dx;
}
exports.squaredDistance = squaredDistance;
function addVectors(a, b) {
return {x: a.x + b.x, y: a.y + b.y};
}
exports.addVectors = addVectors;
function clampVector(a, min, max) {
return {y: clamp(a.y, min, max), x: clamp(a.x, min, max)};
}
exports.clampVector = clampVector;
});
// src/posenet/decodePose.js
var require_decodePose = __commonJS((exports) => {
const keypoints = require_keypoints2();
const vectors = require_vectors();
// [parentId, childId] for each edge of keypoints.poseChain; the array index is the edge id.
const parentChildrenTuples = keypoints.poseChain.map(([parentJoinName, childJoinName]) => [keypoints.partIds[parentJoinName], keypoints.partIds[childJoinName]]);
// Despite the name, this holds the CHILD keypoint id of each edge.
const parentToChildEdges = parentChildrenTuples.map(([, childJointId]) => childJointId);
// Despite the name, this holds the PARENT keypoint id of each edge.
const childToParentEdges = parentChildrenTuples.map(([parentJointId]) => parentJointId);
/**
 * Reads the displacement vector of edge `edgeId` at integer heatmap cell `point`.
 * The channel axis holds all y components first, then all x components (numEdges = channels / 2).
 */
function getDisplacement(edgeId, point, displacements) {
const numEdges = displacements.shape[2] / 2;
return {
y: displacements.get(point.y, point.x, edgeId),
x: displacements.get(point.y, point.x, numEdges + edgeId)
};
}
/**
 * Maps an image-space point to the nearest heatmap cell (divide by outputStride, round),
 * clamped to [0, height - 1] x [0, width - 1].
 */
function getStridedIndexNearPoint(point, outputStride, height, width) {
return {
y: vectors.clamp(Math.round(point.y / outputStride), 0, height - 1),
x: vectors.clamp(Math.round(point.x / outputStride), 0, width - 1)
};
}
/**
 * Follows one pose-chain edge from an already-placed keypoint to its neighbour.
 * Jumps by the edge displacement read at the source's cell, then `offsetRefineStep` times
 * snaps to the nearest cell and sets the position to cell * outputStride + that cell's
 * offset for the target keypoint. The score is the target's heatmap score at the final cell.
 * @returns {{position: {x, y}, part: string, score: number}}
 */
function traverseToTargetKeypoint(edgeId, sourceKeypoint, targetKeypointId, scoresBuffer, offsets, outputStride, displacements, offsetRefineStep = 2) {
const [height, width] = scoresBuffer.shape;
const sourceKeypointIndices = getStridedIndexNearPoint(sourceKeypoint.position, outputStride, height, width);
const displacement = getDisplacement(edgeId, sourceKeypointIndices, displacements);
const displacedPoint = vectors.addVectors(sourceKeypoint.position, displacement);
let targetKeypoint = displacedPoint;
for (let i = 0; i < offsetRefineStep; i++) {
const targetKeypointIndices = getStridedIndexNearPoint(targetKeypoint, outputStride, height, width);
const offsetPoint = vectors.getOffsetPoint(targetKeypointIndices.y, targetKeypointIndices.x, targetKeypointId, offsets);
targetKeypoint = vectors.addVectors({
x: targetKeypointIndices.x * outputStride,
y: targetKeypointIndices.y * outputStride
}, {x: offsetPoint.x, y: offsetPoint.y});
}
const targetKeyPointIndices = getStridedIndexNearPoint(targetKeypoint, outputStride, height, width);
const score = scoresBuffer.get(targetKeyPointIndices.y, targetKeyPointIndices.x, targetKeypointId);
return {position: targetKeypoint, part: keypoints.partNames[targetKeypointId], score};
}
/**
 * Decodes one pose starting from a single root part.
 * The root is placed from its heatmap cell plus offset. Edges are then walked in reverse
 * order from child to parent (using displacementsBwd), and then in order from parent to child
 * (using displacementsFwd). Each keypoint is filled only the first time it becomes reachable.
 * @param root queue entry {score, part: {heatmapY, heatmapX, id}}
 * @param scores buffer whose shape[2] is the number of keypoints
 * @returns array of length scores.shape[2] indexed by keypoint id; any id not reached through
 *          poseChain remains an empty slot
 */
function decodePose(root, scores, offsets, outputStride, displacementsFwd, displacementsBwd) {
const numParts = scores.shape[2];
const numEdges = parentToChildEdges.length;
const instanceKeypoints = new Array(numParts);
const {part: rootPart, score: rootScore} = root;
const rootPoint = vectors.getImageCoords(rootPart, outputStride, offsets);
instanceKeypoints[rootPart.id] = {
score: rootScore,
part: keypoints.partNames[rootPart.id],
position: rootPoint
};
// Backward pass: from each edge's child (source) to its parent (target).
for (let edge = numEdges - 1; edge >= 0; --edge) {
const sourceKeypointId = parentToChildEdges[edge];
const targetKeypointId = childToParentEdges[edge];
if (instanceKeypoints[sourceKeypointId] && !instanceKeypoints[targetKeypointId]) {
instanceKeypoints[targetKeypointId] = traverseToTargetKeypoint(edge, instanceKeypoints[sourceKeypointId], targetKeypointId, scores, offsets, outputStride, displacementsBwd);
}
}
// Forward pass: from each edge's parent (source) to its child (target).
for (let edge = 0; edge < numEdges; ++edge) {
const sourceKeypointId = childToParentEdges[edge];
const targetKeypointId = parentToChildEdges[edge];
if (instanceKeypoints[sourceKeypointId] && !instanceKeypoints[targetKeypointId]) {
instanceKeypoints[targetKeypointId] = traverseToTargetKeypoint(edge, instanceKeypoints[sourceKeypointId], targetKeypointId, scores, offsets, outputStride, displacementsFwd);
}
}
return instanceKeypoints;
}
exports.decodePose = decodePose;
});
// src/posenet/decodeMultiple.js
var require_decodeMultiple = __commonJS((exports) => {
const buildParts = require_buildParts();
const decodePose = require_decodePose();
const vectors = require_vectors();
function withinNmsRadiusOfCorrespondingPoint(poses, squaredNmsRadius, {x, y}, keypointId) {
return poses.some(({keypoints}) => {
const correspondingKeypoint = keypoints[keypointId].position;
return vectors.squaredDistance(y, x, correspondingKeypoint.y, correspondingKeypoint.x) <= squaredNmsRadius;
});
}
function getInstanceScore(existingPoses, squaredNmsRadius, instanceKeypoints) {
const notOverlappedKeypointScores = instanceKeypoints.reduce((result, {position, score}, keypointId) => {
if (!withinNmsRadiusOfCorrespondingPoint(existingPoses, squaredNmsRadius, position, keypointId)) {
result += score;
}
return result;
}, 0);
return notOverlappedKeypointScores / instanceKeypoints.length;
}
const kLocalMaximumRadius = 1;
function decodeMultiplePoses(scoresBuffer, offsetsBuffer, displacementsFwdBuffer, displacementsBwdBuffer, outputStride, maxPoseDetections, scoreThreshold = 0.5, nmsRadius = 20) {
const poses = [];
const queue = buildParts.buildPartWithScoreQueue(scoreThreshold, kLocalMaximumRadius, scoresBuffer);
const squaredNmsRadius = nmsRadius * nmsRadius;
while (poses.length < maxPoseDetections && !queue.empty()) {
const root = queue.dequeue();
const rootImageCoords = vectors.getImageCoords(root.part, outputStride, offsetsBuffer);
if (withinNmsRadiusOfCorrespondingPoint(poses, squaredNmsRadius, rootImageCoords, root.part.id))
continue;
const keypoints = decodePose.decodePose(root, scoresBuffer, offsetsBuffer, outputStride, displacementsFwdBuffer, displacementsBwdBuffer);
const score = getInstanceScore(poses, squaredNmsRadius, keypoints);
poses.push({keypoints, score});
}
return poses;
}
exports.decodeMultiplePoses = decodeMultiplePoses;
});
// src/posenet/util.js
var require_util3 = __commonJS((exports) => {
  const kpt = require_keypoints2();
  const {NEGATIVE_INFINITY, POSITIVE_INFINITY} = Number;

  /** True when either of the two scores is below `minConfidence`. */
  function eitherPointDoesntMeetConfidence(a, b, minConfidence) {
    return a < minConfidence || b < minConfidence;
  }

  /**
   * Pairs of keypoints joined by kpt.connectedPartIndices, keeping only pairs where
   * both endpoints reach `minConfidence`.
   */
  function getAdjacentKeyPoints(keypoints, minConfidence) {
    return kpt.connectedPartIndices
      .filter(([leftJoint, rightJoint]) => !eitherPointDoesntMeetConfidence(keypoints[leftJoint].score, keypoints[rightJoint].score, minConfidence))
      .map(([leftJoint, rightJoint]) => [keypoints[leftJoint], keypoints[rightJoint]]);
  }

  /** Axis-aligned extent of all keypoint positions; infinities for an empty list. */
  function getBoundingBox(keypoints) {
    let maxX = NEGATIVE_INFINITY;
    let maxY = NEGATIVE_INFINITY;
    let minX = POSITIVE_INFINITY;
    let minY = POSITIVE_INFINITY;
    for (const {position: {x, y}} of keypoints) {
      maxX = Math.max(maxX, x);
      maxY = Math.max(maxY, y);
      minX = Math.min(minX, x);
      minY = Math.min(minY, y);
    }
    return {maxX, maxY, minX, minY};
  }

  /** Four corners of the keypoint bounding box, clockwise from the (minX, minY) corner. */
  function getBoundingBoxPoints(keypoints) {
    const {minX, minY, maxX, maxY} = getBoundingBox(keypoints);
    const topLeft = {x: minX, y: minY};
    const topRight = {x: maxX, y: minY};
    const bottomRight = {x: maxX, y: maxY};
    const bottomLeft = {x: minX, y: maxY};
    return [topLeft, topRight, bottomRight, bottomLeft];
  }

  /** Resolves to the result of `.buffer()` for every tensor, in input order. */
  async function toTensorBuffers3D(tensors) {
    const pending = tensors.map((tensor) => tensor.buffer());
    return Promise.all(pending);
  }

  /** Copy of `pose` with every keypoint position multiplied by (scaleX, scaleY). */
  function scalePose(pose, scaleY, scaleX) {
    const scaleKeypoint = ({score, part, position}) => {
      const scaledPosition = {x: position.x * scaleX, y: position.y * scaleY};
      return {score, part, position: scaledPosition};
    };
    return {score: pose.score, keypoints: pose.keypoints.map(scaleKeypoint)};
  }

  /** Drops axis 0 of `image` and bilinearly resizes it; the squeezed intermediate is disposed. */
  function resizeTo(image, [targetH, targetW]) {
    const squeezed = image.squeeze(0);
    const output = squeezed.resizeBilinear([targetH, targetW]);
    squeezed.dispose();
    return output;
  }

  /** Rescales poses from model-input resolution to the original [height, width] (no flipping is applied). */
  function scaleAndFlipPoses(poses, [height, width], [inputResolutionHeight, inputResolutionWidth]) {
    const scaleY = height / inputResolutionHeight;
    const scaleX = width / inputResolutionWidth;
    return poses.map((pose) => scalePose(pose, scaleY, scaleX));
  }

  Object.assign(exports, {
    getAdjacentKeyPoints,
    getBoundingBox,
    getBoundingBoxPoints,
    toTensorBuffers3D,
    scalePose,
    resizeTo,
    scaleAndFlipPoses
  });
});
// src/posenet/modelPoseNet.js
var require_modelPoseNet = __commonJS((exports) => {
const tf = require_tf_node();
const modelMobileNet = require_modelMobileNet();
const decodeMultiple = require_decodeMultiple();
const util = require_util3();
class PoseNet {
constructor(net) {
this.baseModel = net;
}
async estimatePoses(input, config) {
const outputStride = config.outputStride;
const height = input.shape[1];
const width = input.shape[2];
const resized = util.resizeTo(input, [config.inputResolution, config.inputResolution]);
const {heatmapScores, offsets, displacementFwd, displacementBwd} = this.baseModel.predict(resized);
const allTensorBuffers = await util.toTensorBuffers3D([heatmapScores, offsets, displacementFwd, displacementBwd]);
const scoresBuffer = allTensorBuffers[0];
const offsetsBuffer = allTensorBuffers[1];
const displacementsFwdBuffer = allTensorBuffers[2];
const displacementsBwdBuffer = allTensorBuffers[3];
const poses = await decodeMultiple.decodeMultiplePoses(scoresBuffer, offsetsBuffer, displacementsFwdBuffer, displacementsBwdBuffer, outputStride, config.maxDetections, config.scoreThreshold, config.nmsRadius);
const resultPoses = util.scaleAndFlipPoses(poses, [height, width], [config.inputResolution, config.inputResolution]);
heatmapScores.dispose();
offsets.dispose();
displacementFwd.dispose();
displacementBwd.dispose();
resized.dispose();
return resultPoses;
}
dispose() {
this.baseModel.dispose();
}
}
exports.PoseNet = PoseNet;
async function loadMobileNet(config) {
const graphModel = await tf.loadGraphModel(config.modelPath);
const mobilenet = new modelMobileNet.MobileNet(graphModel, config.outputStride);
return new PoseNet(mobilenet);
}
async function load(config) {
return loadMobileNet(config);
}
exports.load = load;
});
// src/posenet/posenet.js
var require_posenet = __commonJS((exports) => {
  // Public PoseNet surface: re-exports selected members of the posenet sub-modules.
  // Modules are resolved in the same order as before so their initialization order is unchanged.
  const modelMobileNet = require_modelMobileNet();
  const modelPoseNet = require_modelPoseNet();
  const decodeMultiple = require_decodeMultiple();
  const keypoints = require_keypoints2();
  const util = require_util3();
  Object.assign(exports, {
    load: modelPoseNet.load,
    PoseNet: modelPoseNet.PoseNet,
    MobileNet: modelMobileNet.MobileNet,
    decodeMultiplePoses: decodeMultiple.decodeMultiplePoses,
    partChannels: keypoints.partChannels,
    partIds: keypoints.partIds,
    partNames: keypoints.partNames,
    poseChain: keypoints.poseChain,
    getAdjacentKeyPoints: util.getAdjacentKeyPoints,
    getBoundingBox: util.getBoundingBox,
    getBoundingBoxPoints: util.getBoundingBoxPoints,
    scaleAndFlipPoses: util.scaleAndFlipPoses,
    scalePose: util.scalePose
  });
});
// src/handpose/box.js
var require_box2 = __commonJS((exports) => {
  const tf = require_tf_node();
  // Boxes are plain objects {startPoint: [x, y], endPoint: [x, y], palmLandmarks?: [[x, y], ...]}.

  /** Absolute [width, height] of a box. */
  function getBoxSize(box) {
    const [startX, startY] = box.startPoint;
    const [endX, endY] = box.endPoint;
    return [Math.abs(endX - startX), Math.abs(endY - startY)];
  }

  /** Midpoint [x, y] between the box corners. */
  function getBoxCenter(box) {
    const {startPoint: s, endPoint: e} = box;
    return [s[0] + (e[0] - s[0]) / 2, s[1] + (e[1] - s[1]) / 2];
  }

  /**
   * Crops `box` out of batch element 0 of `image` and resizes it to `cropSize` via
   * tf.image.cropAndResize. image.shape[1] is used as height and image.shape[2] as width
   * to normalize the [y1, x1, y2, x2] box expected by cropAndResize.
   */
  function cutBoxFromImageAndResize(box, image, cropSize) {
    const imageHeight = image.shape[1];
    const imageWidth = image.shape[2];
    const [x1, y1] = box.startPoint;
    const [x2, y2] = box.endPoint;
    const normalizedBox = [y1 / imageHeight, x1 / imageWidth, y2 / imageHeight, x2 / imageWidth];
    return tf.image.cropAndResize(image, [normalizedBox], [0], cropSize);
  }

  /** Multiplies x by factor[0] and y by factor[1] for both corners and every palm landmark. */
  function scaleBoxCoordinates(box, factor) {
    const [fx, fy] = factor;
    const scale = (point) => [point[0] * fx, point[1] * fy];
    return {
      startPoint: scale(box.startPoint),
      endPoint: scale(box.endPoint),
      palmLandmarks: box.palmLandmarks.map((point) => scale(point))
    };
  }

  /** Grows the box about its center by `factor` in both dimensions. */
  function enlargeBox(box, factor = 1.5) {
    const [centerX, centerY] = getBoxCenter(box);
    const [width, height] = getBoxSize(box);
    const halfWidth = factor * width / 2;
    const halfHeight = factor * height / 2;
    return {
      startPoint: [centerX - halfWidth, centerY - halfHeight],
      endPoint: [centerX + halfWidth, centerY + halfHeight],
      palmLandmarks: box.palmLandmarks
    };
  }

  /** Square box with the same center whose side equals the longer box edge. */
  function squarifyBox(box) {
    const [centerX, centerY] = getBoxCenter(box);
    const [width, height] = getBoxSize(box);
    const half = Math.max(width, height) / 2;
    return {
      startPoint: [centerX - half, centerY - half],
      endPoint: [centerX + half, centerY + half],
      palmLandmarks: box.palmLandmarks
    };
  }

  /** Translates the box by shiftFactor times its (signed) size. */
  function shiftBox(box, shiftFactor) {
    const dx = (box.endPoint[0] - box.startPoint[0]) * shiftFactor[0];
    const dy = (box.endPoint[1] - box.startPoint[1]) * shiftFactor[1];
    return {
      startPoint: [box.startPoint[0] + dx, box.startPoint[1] + dy],
      endPoint: [box.endPoint[0] + dx, box.endPoint[1] + dy],
      palmLandmarks: box.palmLandmarks
    };
  }

  Object.assign(exports, {
    getBoxSize,
    getBoxCenter,
    cutBoxFromImageAndResize,
    scaleBoxCoordinates,
    enlargeBox,
    squarifyBox,
    shiftBox
  });
});
// src/handpose/handdetector.js
var require_handdetector = __commonJS((exports) => {
const tf = require_tf_node();
const bounding = require_box2();
class HandDetector {
constructor(model, anchors, config) {
this.model = model;
this.width = config.inputSize;
this.height = config.inputSize;
this.anchors = anchors.map((anchor) => [anchor.x_center, anchor.y_center]);
this.anchorsTensor = tf.tensor2d(this.anchors);
this.inputSizeTensor = tf.tensor1d([config.inputSize, config.inputSize]);
this.doubleInputSizeTensor = tf.tensor1d([config.inputSize * 2, config.inputSize * 2]);
}
normalizeBoxes(boxes) {
return tf.tidy(() => {
const boxOffsets = tf.slice(boxes, [0, 0], [-1, 2]);
const boxSizes = tf.slice(boxes, [0, 2], [-1, 2]);
const boxCenterPoints = tf.add(tf.div(boxOffsets, this.inputSizeTensor), this.anchorsTensor);
const halfBoxSizes = tf.div(boxSizes, this.doubleInputSizeTensor);
const startPoints = tf.mul(tf.sub(boxCenterPoints, halfBoxSizes), this.inputSizeTensor);
const endPoints = tf.mul(tf.add(boxCenterPoints, halfBoxSizes), this.inputSizeTensor);
return tf.concat2d([startPoints, endPoints], 1);
});
}
normalizeLandmarks(rawPalmLandmarks, index) {
return tf.tidy(() => {
const landmarks = tf.add(tf.div(rawPalmLandmarks.reshape([-1, 7, 2]), this.inputSizeTensor), this.anchors[index]);
return tf.mul(landmarks, this.inputSizeTensor);
});
}
async getBoundingBoxes(input) {
const batchedPrediction = this.model.predict(input);
const prediction = batchedPrediction.squeeze();
const scores = tf.tidy(() => tf.sigmoid(tf.slice(prediction, [0, 0], [-1, 1])).squeeze());
const rawBoxes = tf.slice(prediction, [0, 1], [-1, 4]);
const boxes = this.normalizeBoxes(rawBoxes);
const boxesWithHandsTensor = await tf.image.nonMaxSuppressionAsync(boxes, scores, this.maxHands, this.iouThreshold, this.scoreThreshold);
const boxesWithHands = await boxesWithHandsTensor.array();
const toDispose = [batchedPrediction, boxesWithHandsTensor, prediction, boxes, rawBoxes, scores];
const detectedHands = tf.tidy(() => {
const detectedBoxes = [];
for (const i in boxesWithHands) {
const boxIndex = boxesWithHands[i];
const matchingBox = tf.slice(boxes, [boxIndex, 0], [1, -1]);
const rawPalmLandmarks = tf.slice(prediction, [boxIndex, 5], [1, 14]);
const palmLandmarks = tf.tidy(() => this.normalizeLandmarks(rawPalmLandmarks, boxIndex).reshape([-1, 2]));
detectedBoxes.push({boxes: matchingBox, palmLandmarks});
}
return detectedBoxes;
});
toDispose.forEach((tensor) => tensor.dispose());
return detectedHands;
}
async estimateHandBounds(input, config) {
this.iouThreshold = config.iouThreshold;
this.scoreThreshold = config.scoreThreshold;
this.maxHands = config.maxHands;
const resized = input.resizeBilinear([this.width, this.height]);
const divided = resized.div(255);
const normalized = divided.sub(0.5);
const image = normalized.mul(2);
resized.dispose();
divided.dispose();
normalized.dispose();
const predictions = await this.getBoundingBoxes(image);
image.dispose();
if (!predictions || predictions.length === 0)
return null;
const hands = [];
for (const i in predictions) {
const prediction = predictions[i];
const boundingBoxes = await prediction.boxes.array();
const startPoint = boundingBoxes[0].slice(0, 2);
const endPoint = boundingBoxes[0].slice(2, 4);
const palmLandmarks = await prediction.palmLandmarks.array();
prediction.boxes.dispose();
prediction.palmLandmarks.dispose();
hands.push(bounding.scaleBoxCoordinates({startPoint, endPoint, palmLandmarks}, [input.shape[2] / this.width, input.shape[1] / this.height]));
}
return hands;
}
}
exports.HandDetector = HandDetector;
});
// src/handpose/keypoints.js
var require_keypoints3 = __commonJS((exports) => {
exports.MESH_ANNOTATIONS = {
thumb: [1, 2, 3, 4],
indexFinger: [5, 6, 7, 8],
middleFinger: [9, 10, 11, 12],
ringFinger: [13, 14, 15, 16],
pinky: [17, 18, 19, 20],
palmBase: [0]
};
});
// src/handpose/util.js
var require_util4 = __commonJS((exports) => {
  /** Wraps `angle` (radians) into [-PI, PI) by subtracting whole turns. */
  function normalizeRadians(angle) {
    const fullTurn = 2 * Math.PI;
    const turns = Math.floor((angle + Math.PI) / fullTurn);
    return angle - fullTurn * turns;
  }

  /**
   * Normalized angle (radians) of the segment point1 -> point2, measured so that it is 0
   * when point2 lies directly above point1 (same x, smaller y).
   */
  function computeRotation(point1, point2) {
    const deltaX = point2[0] - point1[0];
    const deltaY = point2[1] - point1[1];
    return normalizeRadians(Math.PI / 2 - Math.atan2(-deltaY, deltaX));
  }

  /** 3x3 homogeneous translation by (x, y). */
  function buildTranslationMatrix(x, y) {
    return [[1, 0, x], [0, 1, y], [0, 0, 1]];
  }

  /** Sum of pairwise products over the length of v1 (missing v2 entries give NaN). */
  function dot(v1, v2) {
    let sum = 0;
    for (const [k, value] of v1.entries()) {
      sum += value * v2[k];
    }
    return sum;
  }

  /** Column `columnIndex` of a row-major 2D array. */
  function getColumnFrom2DArr(arr, columnIndex) {
    return Array.from(arr, (row) => row[columnIndex]);
  }

  /** Product mat1 x mat2 for square matrices of mat1's size. */
  function multiplyTransformMatrices(mat1, mat2) {
    const size = mat1.length;
    return Array.from({length: size}, (_, row) => Array.from({length: size}, (__, col) => dot(mat1[row], getColumnFrom2DArr(mat2, col))));
  }

  /** 3x3 homogeneous matrix rotating by `rotation` radians about the point `center`. */
  function buildRotationMatrix(rotation, center) {
    const cosA = Math.cos(rotation);
    const sinA = Math.sin(rotation);
    const rotationMatrix = [[cosA, -sinA, 0], [sinA, cosA, 0], [0, 0, 1]];
    const toCenter = buildTranslationMatrix(center[0], center[1]);
    const fromCenter = buildTranslationMatrix(-center[0], -center[1]);
    return multiplyTransformMatrices(multiplyTransformMatrices(toCenter, rotationMatrix), fromCenter);
  }

  /**
   * Inverts a 3x3 rotation+translation matrix by transposing its 2x2 part and mapping the
   * translation t to -R^T t. Only valid when the 2x2 part is a pure rotation (orthonormal).
   */
  function invertTransformMatrix(matrix) {
    const rotationComponent = [[matrix[0][0], matrix[1][0]], [matrix[0][1], matrix[1][1]]];
    const translationComponent = [matrix[0][2], matrix[1][2]];
    const [tx, ty] = rotationComponent.map((row) => -dot(row, translationComponent));
    return [[...rotationComponent[0], tx], [...rotationComponent[1], ty], [0, 0, 1]];
  }

  /** Applies the first two rows of a 3x3 matrix to a homogeneous point, giving [x, y]. */
  function rotatePoint(homogeneousCoordinate, rotationMatrix) {
    return rotationMatrix.slice(0, 2).map((row) => dot(homogeneousCoordinate, row));
  }

  Object.assign(exports, {
    normalizeRadians,
    computeRotation,
    dot,
    getColumnFrom2DArr,
    buildRotationMatrix,
    invertTransformMatrix,
    rotatePoint
  });
});
// src/handpose/pipeline.js
var require_pipeline2 = __commonJS((exports) => {
const tf = require_tf_node();
const bounding = require_box2();
const util = require_util4();
const UPDATE_REGION_OF_INTEREST_IOU_THRESHOLD = 0.8;
const PALM_BOX_SHIFT_VECTOR = [0, -0.4];
const HAND_BOX_SHIFT_VECTOR = [0, -0.1];
const HAND_BOX_ENLARGE_FACTOR = 1.65;
const PALM_LANDMARK_IDS = [0, 5, 9, 13, 17, 1, 2];
const PALM_LANDMARKS_INDEX_OF_PALM_BASE = 0;
const PALM_LANDMARKS_INDEX_OF_MIDDLE_FINGER_BASE = 2;
class HandPipeline {
constructor(boundingBoxDetector, meshDetector, config) {
this.regionsOfInterest = [];
this.runsWithoutHandDetector = 0;
this.boundingBoxDetector = boundingBoxDetector;
this.meshDetector = meshDetector;
this.meshWidth = config.inputSize;
this.meshHeight = config.inputSize;
this.enlargeFactor = config.enlargeFactor;
}
getBoxForPalmLandmarks(palmLandmarks, rotationMatrix) {
const rotatedPalmLandmarks = palmLandmarks.map((coord) => {
const homogeneousCoordinate = [...coord, 1];
return util.rotatePoint(homogeneousCoordinate, rotationMatrix);
});
const boxAroundPalm = this.calculateLandmarksBoundingBox(rotatedPalmLandmarks);
return bounding.enlargeBox(bounding.squarifyBox(bounding.shiftBox(boxAroundPalm, PALM_BOX_SHIFT_VECTOR)), this.enlargeFactor);
}
getBoxForHandLandmarks(landmarks) {
const boundingBox = this.calculateLandmarksBoundingBox(landmarks);
const boxAroundHand = bounding.enlargeBox(bounding.squarifyBox(bounding.shiftBox(boundingBox, HAND_BOX_SHIFT_VECTOR)), HAND_BOX_ENLARGE_FACTOR);
const palmLandmarks = [];
for (let i = 0; i < PALM_LANDMARK_IDS.length; i++) {
palmLandmarks.push(landmarks[PALM_LANDMARK_IDS[i]].slice(0, 2));
}
boxAroundHand.palmLandmarks = palmLandmarks;
return boxAroundHand;
}
transformRawCoords(rawCoords, box, angle, rotationMatrix) {
const boxSize = bounding.getBoxSize(box);
const scaleFactor = [boxSize[0] / this.meshWidth, boxSize[1] / this.meshHeight];
const coordsScaled = rawCoords.map((coord) => [
scaleFactor[0] * (coord[0] - this.meshWidth / 2),
scaleFactor[1] * (coord[1] - this.meshHeight / 2),
coord[2]
]);
const coordsRotationMatrix = util.buildRotationMatrix(angle, [0, 0]);
const coordsRotated = coordsScaled.map((coord) => {
const rotated = util.rotatePoint(coord, coordsRotationMatrix);
return [...rotated, coord[2]];
});
const inverseRotationMatrix = util.invertTransformMatrix(rotationMatrix);
const boxCenter = [...bounding.getBoxCenter(box), 1];
const originalBoxCenter = [
util.dot(boxCenter, inverseRotationMatrix[0]),
util.dot(boxCenter, inverseRotationMatrix[1])
];
return coordsRotated.map((coord) => [
coord[0] + originalBoxCenter[0],
coord[1] + originalBoxCenter[1],
coord[2]
]);
}
async estimateHands(image, config) {
this.skipFrames = config.skipFrames;
this.detectionConfidence = config.minConfidence;
this.maxHands = config.maxHands;
this.runsWithoutHandDetector++;
const useFreshBox = this.shouldUpdateRegionsOfInterest();
if (useFreshBox === true) {
const boundingBoxPredictions = await this.boundingBoxDetector.estimateHandBounds(image, config);
this.regionsOfInterest = [];
for (const i in boundingBoxPredictions) {
this.updateRegionsOfInterest(boundingBoxPredictions[i], true, i);
}
this.runsWithoutHandDetector = 0;
}
const hands = [];
if (!this.regionsOfInterest)
return hands;
for (const i in this.regionsOfInterest) {
const currentBox = this.regionsOfInterest[i][0];
if (!currentBox)
return hands;
const angle = util.computeRotation(currentBox.palmLandmarks[PALM_LANDMARKS_INDEX_OF_PALM_BASE], currentBox.palmLandmarks[PALM_LANDMARKS_INDEX_OF_MIDDLE_FINGER_BASE]);
const palmCenter = bounding.getBoxCenter(currentBox);
const palmCenterNormalized = [palmCenter[0] / image.shape[2], palmCenter[1] / image.shape[1]];
const rotatedImage = tf.image.rotateWithOffset(image, angle, 0, palmCenterNormalized);
const rotationMatrix = util.buildRotationMatrix(-angle, palmCenter);
const box = useFreshBox ? this.getBoxForPalmLandmarks(currentBox.palmLandmarks, rotationMatrix) : currentBox;
const croppedInput = bounding.cutBoxFromImageAndResize(box, rotatedImage, [this.meshWidth, this.meshHeight]);
const handImage = croppedInput.div(255);
croppedInput.dispose();
rotatedImage.dispose();
const prediction = this.meshDetector.predict(handImage);
const [flag, keypoints] = prediction;
handImage.dispose();
const flagValue = flag.dataSync()[0];
flag.dispose();
if (flagValue < config.minConfidence) {
keypoints.dispose();
this.regionsOfInterest[i] = [];
return hands;
}
const keypointsReshaped = tf.reshape(keypoints, [-1, 3]);
const rawCoords = await keypointsReshaped.array();
keypoints.dispose();
keypointsReshaped.dispose();
const coords = this.transformRawCoords(rawCoords, box, angle, rotationMatrix);
const nextBoundingBox = this.getBoxForHandLandmarks(coords);
this.updateRegionsOfInterest(nextBoundingBox, false, i);
const result = {
landmarks: coords,
confidence: flagValue,
box: {
topLeft: nextBoundingBox.startPoint,
bottomRight: nextBoundingBox.endPoint
}
};
hands.push(result);
}
return hands;
}
calculateLandmarksBoundingBox(landmarks) {
const xs = landmarks.map((d) => d[0]);
const ys = landmarks.map((d) => d[1]);
const startPoint = [Math.min(...xs), Math.min(...ys)];
const endPoint = [Math.max(...xs), Math.max(...ys)];
return {startPoint, endPoint};
}
updateRegionsOfInterest(box, forceUpdate, index) {
if (forceUpdate) {
this.regionsOfInterest[index] = [box];
} else {
const previousBox = this.regionsOfInterest[index][0];
let iou = 0;
if (previousBox != null && previousBox.startPoint != null) {
const [boxStartX, boxStartY] = box.startPoint;
const [boxEndX, boxEndY] = box.endPoint;
const [previousBoxStartX, previousBoxStartY] = previousBox.startPoint;
const [previousBoxEndX, previousBoxEndY] = previousBox.endPoint;
const xStartMax = Math.max(boxStartX, previousBoxStartX);
const yStartMax = Math.max(boxStartY, previousBoxStartY);
const xEndMin = Math.min(boxEndX, previousBoxEndX);
const yEndMin = Math.min(boxEndY, previousBoxEndY);
const intersection = (xEndMin - xStartMax) * (yEndMin - yStartMax);
const boxArea = (boxEndX - boxStartX) * (boxEndY - boxStartY);
const previousBoxArea = (previousBoxEndX - previousBoxStartX) * (previousBoxEndY - boxStartY);
iou = intersection / (boxArea + previousBoxArea - intersection);
}
this.regionsOfInterest[index][0] = iou > UPDATE_REGION_OF_INTEREST_IOU_THRESHOLD ? previousBox : box;
}
}
shouldUpdateRegionsOfInterest() {
return !this.regionsOfInterest || this.regionsOfInterest.length === 0 || this.runsWithoutHandDetector >= this.skipFrames;
}
}
exports.HandPipeline = HandPipeline;
});
// src/handpose/handpose.js
var require_handpose = __commonJS((exports) => {
const tf = require_tf_node();
const hand = require_handdetector();
const keypoints = require_keypoints3();
const pipe = require_pipeline2();
class HandPose {
constructor(pipeline) {
this.pipeline = pipeline;
}
async estimateHands(input, config) {
this.skipFrames = config.skipFrames;
this.detectionConfidence = config.minConfidence;
this.maxHands = config.maxHands;
const predictions = await this.pipeline.estimateHands(input, config);
const hands = [];
if (!predictions)
return hands;
for (const prediction of predictions) {
if (!prediction)
return [];
const annotations = {};
for (const key of Object.keys(keypoints.MESH_ANNOTATIONS)) {
annotations[key] = keypoints.MESH_ANNOTATIONS[key].map((index) => prediction.landmarks[index]);
}
hands.push({
confidence: prediction.confidence || 0,
box: prediction.box ? [prediction.box.topLeft[0], prediction.box.topLeft[1], prediction.box.bottomRight[0] - prediction.box.topLeft[0], prediction.box.bottomRight[1] - prediction.box.topLeft[1]] : 0,
landmarks: prediction.landmarks,
annotations
});
}
return hands;
}
}
exports.HandPose = HandPose;
async function loadAnchors(url) {
if (tf.env().features.IS_NODE) {
const fs = require("fs");
const data = await fs.readFileSync(url.replace("file://", ""));
return JSON.parse(data);
}
return tf.util.fetch(url).then((d) => d.json());
}
async function load(config) {
const [anchors, handDetectorModel, handPoseModel] = await Promise.all([
loadAnchors(config.detector.anchors),
tf.loadGraphModel(config.detector.modelPath, {fromTFHub: config.detector.modelPath.includes("tfhub.dev")}),
tf.loadGraphModel(config.skeleton.modelPath, {fromTFHub: config.skeleton.modelPath.includes("tfhub.dev")})
]);
const detector = new hand.HandDetector(handDetectorModel, anchors, config);
const pipeline = new pipe.HandPipeline(detector, handPoseModel, config);
const handpose = new HandPose(pipeline);
return handpose;
}
exports.load = load;
});
// src/imagefx.js
var require_imagefx = __commonJS((exports) => {
/**
 * Compiles and links a vertex+fragment shader pair and makes it the current program.
 * Fills `this.attribute` / `this.uniform` with locations for every attribute and uniform
 * declared in the sources (names found by a regex over `<qualifier> <type> <name>`).
 * Must be called with `new`. Throws an Error containing the GL info log on compile or link failure.
 */
const WebGLProgram = function(gl, vertexSource, fragmentSource) {
  // Adds each `<prefix> <type> <name>` declaration name in `source` as a key of `collection`.
  const _collect = function(source, prefix, collection) {
    const r = new RegExp("\\b" + prefix + " \\w+ (\\w+)", "ig");
    let match;
    while ((match = r.exec(source)) !== null) {
      collection[match[1]] = 0;
    }
  };
  const _compile = function(gl2, source, type) {
    const shader = gl2.createShader(type);
    gl2.shaderSource(shader, source);
    gl2.compileShader(shader);
    if (!gl2.getShaderParameter(shader, gl2.COMPILE_STATUS)) {
      // Error's second argument is an options object, so the log must be part of the message to be visible.
      throw new Error("Filter: GL compile failed: " + gl2.getShaderInfoLog(shader));
    }
    return shader;
  };
  this.uniform = {};
  this.attribute = {};
  const _vsh = _compile(gl, vertexSource, gl.VERTEX_SHADER);
  const _fsh = _compile(gl, fragmentSource, gl.FRAGMENT_SHADER);
  this.id = gl.createProgram();
  gl.attachShader(this.id, _vsh);
  gl.attachShader(this.id, _fsh);
  gl.linkProgram(this.id);
  if (!gl.getProgramParameter(this.id, gl.LINK_STATUS)) {
    throw new Error("Filter: GL link failed: " + gl.getProgramInfoLog(this.id));
  }
  gl.useProgram(this.id);
  _collect(vertexSource, "attribute", this.attribute);
  for (const a in this.attribute) {
    this.attribute[a] = gl.getAttribLocation(this.id, a);
  }
  _collect(vertexSource, "uniform", this.uniform);
  _collect(fragmentSource, "uniform", this.uniform);
  for (const u in this.uniform) {
    this.uniform[u] = gl.getUniformLocation(this.id, u);
  }
};
const WebGLImageFilter = function(params) {
if (!params)
params = {};
let _drawCount = 0;
let _sourceTexture = null;
let _lastInChain = false;
let _currentFramebufferIndex = -1;
let _tempFramebuffers = [null, null];
let _filterChain = [];
let _width = -1;
let _height = -1;
let _vertexBuffer = null;
let _currentProgram = null;
const _canvas = params.canvas || document.createElement("canvas");
const _shaderProgramCache = {};
const gl = _canvas.getContext("webgl") || _canvas.getContext("experimental-webgl");
if (!gl)
throw new Error("Filter: getContext() failed");
// Appends the filter registered under `name` in _filter, with its extra arguments, to the chain run by apply().
// NOTE(review): an unknown name is stored with an undefined func and only fails later, inside apply().
this.addFilter = function(name, ...args) {
  const filter = _filter[name];
  _filterChain.push({func: filter, args});
};
this.reset = function() {
_filterChain = [];
};
// Uploads `image` as the source texture and runs the filter chain, returning the output canvas.
// With an empty chain the image is drawn unchanged to the canvas with the identity shader.
this.apply = function(image) {
  _resize(image.width, image.height);
  _drawCount = 0;
  if (!_sourceTexture)
    _sourceTexture = gl.createTexture();
  gl.bindTexture(gl.TEXTURE_2D, _sourceTexture);
  gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_WRAP_S, gl.CLAMP_TO_EDGE);
  gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_WRAP_T, gl.CLAMP_TO_EDGE);
  gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_MIN_FILTER, gl.NEAREST);
  gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_MAG_FILTER, gl.NEAREST);
  gl.texImage2D(gl.TEXTURE_2D, 0, gl.RGBA, gl.RGBA, gl.UNSIGNED_BYTE, image);
  if (_filterChain.length === 0) {
    // _draw() targets the canvas only when _lastInChain is set; without this, the first
    // apply() with no filters would render into a temp framebuffer and leave the canvas untouched.
    _lastInChain = true;
    _compileShader(SHADER.FRAGMENT_IDENTITY);
    _draw();
    return _canvas;
  }
  for (let i = 0; i < _filterChain.length; i++) {
    _lastInChain = i === _filterChain.length - 1;
    const f = _filterChain[i];
    f.func.apply(this, f.args || []);
  }
  return _canvas;
};
// Resizes the canvas/viewport to width x height (no-op when unchanged). Creates the shared
// full-screen vertex buffer on first use and releases temp framebuffers sized for the old dimensions.
const _resize = function(width, height) {
  if (width === _width && height === _height) {
    return;
  }
  _canvas.width = _width = width;
  _canvas.height = _height = height;
  if (!_vertexBuffer) {
    // Two triangles covering clip space; each vertex is (pos.x, pos.y, uv.x, uv.y).
    const vertices = new Float32Array([
      -1, -1, 0, 1,
      1, -1, 1, 1,
      -1, 1, 0, 0,
      -1, 1, 0, 0,
      1, -1, 1, 1,
      1, 1, 1, 0
    ]);
    _vertexBuffer = gl.createBuffer();
    gl.bindBuffer(gl.ARRAY_BUFFER, _vertexBuffer);
    gl.bufferData(gl.ARRAY_BUFFER, vertices, gl.STATIC_DRAW);
    gl.pixelStorei(gl.UNPACK_PREMULTIPLY_ALPHA_WEBGL, true);
  }
  gl.viewport(0, 0, _width, _height);
  // Delete the previous-size framebuffers and textures instead of just dropping the references (GL objects are not garbage collected).
  for (const buffer of _tempFramebuffers) {
    if (buffer) {
      gl.deleteFramebuffer(buffer.fbo);
      gl.deleteTexture(buffer.texture);
    }
  }
  _tempFramebuffers = [null, null];
};
const _getTempFramebuffer = function(index) {
_tempFramebuffers[index] = _tempFramebuffers[index] || _createFramebufferTexture(_width, _height);
return _tempFramebuffers[index];
};
// Creates a framebuffer whose COLOR_ATTACHMENT0 is a new width x height RGBA texture
// (linear filtering, clamped edges) and returns both handles. Leaves no texture/framebuffer bound.
// The previous version also created and bound a renderbuffer that was never attached or deleted; it is gone.
const _createFramebufferTexture = function(width, height) {
  const fbo = gl.createFramebuffer();
  gl.bindFramebuffer(gl.FRAMEBUFFER, fbo);
  const texture = gl.createTexture();
  gl.bindTexture(gl.TEXTURE_2D, texture);
  gl.texImage2D(gl.TEXTURE_2D, 0, gl.RGBA, width, height, 0, gl.RGBA, gl.UNSIGNED_BYTE, null);
  gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_MAG_FILTER, gl.LINEAR);
  gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_MIN_FILTER, gl.LINEAR);
  gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_WRAP_S, gl.CLAMP_TO_EDGE);
  gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_WRAP_T, gl.CLAMP_TO_EDGE);
  gl.framebufferTexture2D(gl.FRAMEBUFFER, gl.COLOR_ATTACHMENT0, gl.TEXTURE_2D, texture, 0);
  gl.bindTexture(gl.TEXTURE_2D, null);
  gl.bindFramebuffer(gl.FRAMEBUFFER, null);
  return {fbo, texture};
};
// Renders one full-screen pass (6 vertices) with the current program (_currentProgram).
// Source: the uploaded image texture on the first pass after apply() resets _drawCount,
// otherwise the texture of the temp framebuffer written by the previous pass.
// Target: the canvas (framebuffer null) when this is the last filter in the chain and
// `flags` does not contain DRAW.INTERMEDIATE; otherwise the other temp framebuffer,
// alternating between indices 0 and 1.
const _draw = function(flags) {
  let source = null;
  let target = null;
  let flipY = false;
  if (_drawCount === 0) {
    source = _sourceTexture;
  } else {
    source = _getTempFramebuffer(_currentFramebufferIndex).texture;
  }
  _drawCount++;
  if (_lastInChain && !(flags & DRAW.INTERMEDIATE)) {
    target = null;
    // NOTE(review): flips Y on even pass counts when drawing to the canvas - presumably to undo the
    // vertical inversion accumulated through framebuffer textures; confirm.
    flipY = _drawCount % 2 === 0;
  } else {
    _currentFramebufferIndex = (_currentFramebufferIndex + 1) % 2;
    target = _getTempFramebuffer(_currentFramebufferIndex).fbo;
  }
  gl.bindTexture(gl.TEXTURE_2D, source);
  gl.bindFramebuffer(gl.FRAMEBUFFER, target);
  // The vertex shader multiplies pos.y by this uniform.
  gl.uniform1f(_currentProgram.uniform.flipY, flipY ? -1 : 1);
  gl.drawArrays(gl.TRIANGLES, 0, 6);
};
// Makes the program for `fragmentSource` (paired with SHADER.VERTEX_IDENTITY) current and returns it,
// compiling and caching it by source text on first use.
const _compileShader = function(fragmentSource) {
  if (_shaderProgramCache[fragmentSource]) {
    _currentProgram = _shaderProgramCache[fragmentSource];
    gl.useProgram(_currentProgram.id);
    // NOTE(review): cache hits do not re-specify the vertex attribute pointers; this relies on every
    // cached program using the same `pos`/`uv` attribute locations - confirm.
    return _currentProgram;
  }
  _currentProgram = new WebGLProgram(gl, SHADER.VERTEX_IDENTITY, fragmentSource);
  // Interleaved layout: 4 floats per vertex - pos (2 floats) at offset 0, uv (2 floats) at offset 2.
  const floatSize = Float32Array.BYTES_PER_ELEMENT;
  const vertSize = 4 * floatSize;
  gl.enableVertexAttribArray(_currentProgram.attribute.pos);
  gl.vertexAttribPointer(_currentProgram.attribute.pos, 2, gl.FLOAT, false, vertSize, 0 * floatSize);
  gl.enableVertexAttribArray(_currentProgram.attribute.uv);
  gl.vertexAttribPointer(_currentProgram.attribute.uv, 2, gl.FLOAT, false, vertSize, 2 * floatSize);
  _shaderProgramCache[fragmentSource] = _currentProgram;
  return _currentProgram;
};
let DRAW = {INTERMEDIATE: 1};
let SHADER = {};
SHADER.VERTEX_IDENTITY = [
"precision highp float;",
"attribute vec2 pos;",
"attribute vec2 uv;",
"varying vec2 vUv;",
"uniform float flipY;",
"void main(void) {",
"vUv = uv;",
"gl_Position = vec4(pos.x, pos.y*flipY, 0.0, 1.);",
"}"
].join("\n");
SHADER.FRAGMENT_IDENTITY = [
"precision highp float;",
"varying vec2 vUv;",
"uniform sampler2D texture;",
"void main(void) {",
"gl_FragColor = texture2D(texture, vUv);",
"}"
].join("\n");
let _filter = {};
_filter.colorMatrix = function(matrix) {
const m = new Float32Array(matrix);
m[4] /= 255;
m[9] /= 255;
m[14] /= 255;
m[19] /= 255;
const shader = m[18] === 1 && m[3] === 0 && m[8] === 0 && m[13] === 0 && m[15] === 0 && m[16] === 0 && m[17] === 0 && m[19] === 0 ? _filter.colorMatrix.SHADER.WITHOUT_ALPHA : _filter.colorMatrix.SHADER.WITH_ALPHA;
const program = _compileShader(shader);
gl.uniform1fv(program.uniform.m, m);
_draw();
};
_filter.colorMatrix.SHADER = {};
_filter.colorMatrix.SHADER.WITH_ALPHA = [
"precision highp float;",
"varying vec2 vUv;",
"uniform sampler2D texture;",
"uniform float m[20];",
"void main(void) {",
"vec4 c = texture2D(texture, vUv);",
"gl_FragColor.r = m[0] * c.r + m[1] * c.g + m[2] * c.b + m[3] * c.a + m[4];",
"gl_FragColor.g = m[5] * c.r + m[6] * c.g + m[7] * c.b + m[8] * c.a + m[9];",
"gl_FragColor.b = m[10] * c.r + m[11] * c.g + m[12] * c.b + m[13] * c.a + m[14];",
"gl_FragColor.a = m[15] * c.r + m[16] * c.g + m[17] * c.b + m[18] * c.a + m[19];",
"}"
].join("\n");
_filter.colorMatrix.SHADER.WITHOUT_ALPHA = [
"precision highp float;",
"varying vec2 vUv;",
"uniform sampler2D texture;",
"uniform float m[20];",
"void main(void) {",
"vec4 c = texture2D(texture, vUv);",
"gl_FragColor.r = m[0] * c.r + m[1] * c.g + m[2] * c.b + m[4];",
"gl_FragColor.g = m[5] * c.r + m[6] * c.g + m[7] * c.b + m[9];",
"gl_FragColor.b = m[10] * c.r + m[11] * c.g + m[12] * c.b + m[14];",
"gl_FragColor.a = c.a;",
"}"
].join("\n");
_filter.brightness = function(brightness) {
const b = (brightness || 0) + 1;
_filter.colorMatrix([
b,
0,
0,
0,
0,
0,
b,
0,
0,
0,
0,
0,
b,
0,
0,
0,
0,
0,
1,
0
]);
};
_filter.saturation = function(amount) {
const x = (amount || 0) * 2 / 3 + 1;
const y = (x - 1) * -0.5;
_filter.colorMatrix([
x,
y,
y,
0,
0,
y,
x,
y,
0,
0,
y,
y,
x,
0,
0,
0,
0,
0,
1,
0
]);
};
_filter.desaturate = function() {
_filter.saturation(-1);
};
_filter.contrast = function(amount) {
const v = (amount || 0) + 1;
const o = -128 * (v - 1);
_filter.colorMatrix([
v,
0,
0,
0,
o,
0,
v,
0,
0,
o,
0,
0,
v,
0,
o,
0,
0,
0,
1,
0
]);
};
_filter.negative = function() {
_filter.contrast(-2);
};
_filter.hue = function(rotation) {
rotation = (rotation || 0) / 180 * Math.PI;
const cos = Math.cos(rotation);
const sin = Math.sin(rotation);
const lumR = 0.213;
const lumG = 0.715;
const lumB = 0.072;
_filter.colorMatrix([
lumR + cos * (1 - lumR) + sin * -lumR,
lumG + cos * -lumG + sin * -lumG,
lumB + cos * -lumB + sin * (1 - lumB),
0,
0,
lumR + cos * -lumR + sin * 0.143,
lumG + cos * (1 - lumG) + sin * 0.14,
lumB + cos * -lumB + sin * -0.283,
0,
0,
lumR + cos * -lumR + sin * -(1 - lumR),
lumG + cos * -lumG + sin * lumG,
lumB + cos * (1 - lumB) + sin * lumB,
0,
0,
0,
0,
0,
1,
0
]);
};
_filter.desaturateLuminance = function() {
_filter.colorMatrix([
0.2764723,
0.929708,
0.0938197,
0,
-37.1,
0.2764723,
0.929708,
0.0938197,
0,
-37.1,
0.2764723,
0.929708,
0.0938197,
0,
-37.1,
0,
0,
0,
1,
0
]);
};
_filter.sepia = function() {
_filter.colorMatrix([
0.393,
0.7689999,
0.18899999,
0,
0,
0.349,
0.6859999,
0.16799999,
0,
0,
0.272,
0.5339999,
0.13099999,
0,
0,
0,
0,
0,
1,
0
]);
};
_filter.brownie = function() {
_filter.colorMatrix([
0.5997023498159715,
0.34553243048391263,
-0.2708298674538042,
0,
47.43192855600873,
-0.037703249837783157,
0.8609577587992641,
0.15059552388459913,
0,
-36.96841498319127,
0.24113635128153335,
-0.07441037908422492,
0.44972182064877153,
0,
-7.562075277591283,
0,
0,
0,
1,
0
]);
};
_filter.vintagePinhole = function() {
_filter.colorMatrix([
0.6279345635605994,
0.3202183420819367,
-0.03965408211312453,
0,
9.651285835294123,
0.02578397704808868,
0.6441188644374771,
0.03259127616149294,
0,
7.462829176470591,
0.0466055556782719,
-0.0851232987247891,
0.5241648018700465,
0,
5.159190588235296,
0,
0,
0,
1,
0
]);
};
_filter.kodachrome = function() {
_filter.colorMatrix([
1.1285582396593525,
-0.3967382283601348,
-0.03992559172921793,
0,
63.72958762196502,
-0.16404339962244616,
1.0835251566291304,
-0.05498805115633132,
0,
24.732407896706203,
-0.16786010706155763,
-0.5603416277695248,
1.6014850761964943,
0,
35.62982807460946,
0,
0,
0,
1,
0
]);
};
_filter.technicolor = function() {
_filter.colorMatrix([
1.9125277891456083,
-0.8545344976951645,
-0.09155508482755585,
0,
11.793603434377337,
-0.3087833385928097,
1.7658908555458428,
-0.10601743074722245,
0,
-70.35205161461398,
-0.231103377548616,
-0.7501899197440212,
1.847597816108189,
0,
30.950940869491138,
0,
0,
0,
1,
0
]);
};
_filter.polaroid = function() {
_filter.colorMatrix([
1.438,
-0.062,
-0.062,
0,
0,
-0.122,
1.378,
-0.122,
0,
0,
-0.016,
-0.016,
1.483,
0,
0,
0,
0,
0,
1,
0
]);
};
_filter.shiftToBGR = function() {
_filter.colorMatrix([
0,
0,
1,
0,
0,
0,
1,
0,
0,
0,
1,
0,
0,
0,
0,
0,
0,
0,
1,
0
]);
};
/**
 * Applies a 3x3 convolution kernel in a single draw pass.
 * @param {number[]} matrix - 9 weights, row by row: m[0..2] weight the row at
 *   vUv.y - px.y (x - 1, x, x + 1), m[3..5] the centre row and m[6..8] the row
 *   at vUv.y + px.y (see SHADER). The output alpha is the centre texel's alpha.
 */
_filter.convolution = function(matrix) {
  const m = new Float32Array(matrix);
  // One pixel as a fraction of the image size, used to step to neighbouring texels
  // (assumes _width/_height hold the current source dimensions).
  const pixelSizeX = 1 / _width;
  const pixelSizeY = 1 / _height;
  const program = _compileShader(_filter.convolution.SHADER);
  gl.uniform1fv(program.uniform.m, m);
  gl.uniform2f(program.uniform.px, pixelSizeX, pixelSizeY);
  _draw();
};
_filter.convolution.SHADER = [
  "precision highp float;",
  "varying vec2 vUv;",
  "uniform sampler2D texture;",
  "uniform vec2 px;",
  "uniform float m[9];",
  "void main(void) {",
  // cRC: R = row (1: y - px.y, 2: y, 3: y + px.y), C = column (1: x - px.x, 2: x, 3: x + px.x).
  "vec4 c11 = texture2D(texture, vUv - px);",
  "vec4 c12 = texture2D(texture, vec2(vUv.x, vUv.y - px.y));",
  "vec4 c13 = texture2D(texture, vec2(vUv.x + px.x, vUv.y - px.y));",
  "vec4 c21 = texture2D(texture, vec2(vUv.x - px.x, vUv.y) );",
  "vec4 c22 = texture2D(texture, vUv);",
  "vec4 c23 = texture2D(texture, vec2(vUv.x + px.x, vUv.y) );",
  "vec4 c31 = texture2D(texture, vec2(vUv.x - px.x, vUv.y + px.y) );",
  "vec4 c32 = texture2D(texture, vec2(vUv.x, vUv.y + px.y) );",
  "vec4 c33 = texture2D(texture, vUv + px );",
  "gl_FragColor = ",
  // m[2] must weight c13 (x + 1 in the first row). Weighting c22 here would apply
  // the top-right weight to the centre texel and skew every kernel.
  "c11 * m[0] + c12 * m[1] + c13 * m[2] +",
  "c21 * m[3] + c22 * m[4] + c23 * m[5] +",
  "c31 * m[6] + c32 * m[7] + c33 * m[8];",
  "gl_FragColor.a = c22.a;",
  "}"
].join("\n");
_filter.detectEdges = function() {
  // 4-neighbour Laplacian kernel (row by row, as laid out for convolution's m[0..8]).
  // The weights sum to 0, so a uniform region convolves to 0.
  const laplacian = [
    0, 1, 0,
    1, -4, 1,
    0, 1, 0,
  ];
  // Forward `this` unchanged, as the other convolution presets do.
  _filter.convolution.call(this, laplacian);
};
_filter.sobelX = function() {
  // Horizontal Sobel kernel: the x + 1 column minus the x - 1 column, with the
  // centre row weighted twice.
  const sobelHorizontal = [
    -1, 0, 1,
    -2, 0, 2,
    -1, 0, 1,
  ];
  _filter.convolution.call(this, sobelHorizontal);
};
_filter.sobelY = function() {
  // Vertical Sobel kernel: the y + px.y row minus the y - px.y row, with the
  // centre column weighted twice.
  const sobelVertical = [
    -1, -2, -1,
    0, 0, 0,
    1, 2, 1,
  ];
  _filter.convolution.call(this, sobelVertical);
};
/**
 * Sharpen: boosts the centre by 4 * amount and subtracts amount from each of
 * the four direct neighbours, so the weights always sum to 1.
 * @param {number} [amount] - strength; any falsy value (undefined, 0, NaN) means 1.
 */
_filter.sharpen = function(amount) {
  const strength = amount ? amount : 1;
  const neighbour = -1 * strength;
  const kernel = [
    0, neighbour, 0,
    neighbour, 1 + 4 * strength, neighbour,
    0, neighbour, 0,
  ];
  _filter.convolution.call(this, kernel);
};
/**
 * Emboss: a diagonal gradient kernel whose off-centre weights scale with size;
 * the centre weight stays 1.
 * @param {number} [size] - scale factor; any falsy value (undefined, 0, NaN) means 1.
 */
_filter.emboss = function(size) {
  const scale = size ? size : 1;
  const lo = -1 * scale;
  const hi = 1 * scale;
  const kernel = [
    -2 * scale, lo, 0,
    lo, 1, hi,
    0, hi, 2 * scale,
  ];
  _filter.convolution.call(this, kernel);
};
/**
 * Separable blur: runs the 15-tap SHADER twice, first along y (rendered with
 * DRAW.INTERMEDIATE), then along x (default _draw target).
 * @param {number} size - blur radius in pixels; taps are spaced size / 7 pixels
 *   apart, so the outermost taps (±7) sit at ±size pixels.
 */
_filter.blur = function(size) {
  const stepX = size / 7 / _width;
  const stepY = size / 7 / _height;
  const { uniform } = _compileShader(_filter.blur.SHADER);
  const pxLocation = uniform.px;
  // Vertical pass: offsets only along y.
  gl.uniform2f(pxLocation, 0, stepY);
  _draw(DRAW.INTERMEDIATE);
  // Horizontal pass: offsets only along x.
  gl.uniform2f(pxLocation, stepX, 0);
  _draw();
};
// 15 symmetric taps along direction `px`. The weights are Gaussian-shaped
// (sigma ~ 2.5 taps) and sum to ~1.
_filter.blur.SHADER = [
  "precision highp float;",
  "varying vec2 vUv;",
  "uniform sampler2D texture;",
  "uniform vec2 px;",
  "void main(void) {",
  "gl_FragColor = vec4(0.0);",
  "gl_FragColor += texture2D(texture, vUv + vec2(-7.0*px.x, -7.0*px.y))*0.0044299121055113265;",
  "gl_FragColor += texture2D(texture, vUv + vec2(-6.0*px.x, -6.0*px.y))*0.00895781211794;",
  "gl_FragColor += texture2D(texture, vUv + vec2(-5.0*px.x, -5.0*px.y))*0.0215963866053;",
  "gl_FragColor += texture2D(texture, vUv + vec2(-4.0*px.x, -4.0*px.y))*0.0443683338718;",
  "gl_FragColor += texture2D(texture, vUv + vec2(-3.0*px.x, -3.0*px.y))*0.0776744219933;",
  "gl_FragColor += texture2D(texture, vUv + vec2(-2.0*px.x, -2.0*px.y))*0.115876621105;",
  "gl_FragColor += texture2D(texture, vUv + vec2(-1.0*px.x, -1.0*px.y))*0.147308056121;",
  "gl_FragColor += texture2D(texture, vUv )*0.159576912161;",
  "gl_FragColor += texture2D(texture, vUv + vec2( 1.0*px.x, 1.0*px.y))*0.147308056121;",
  "gl_FragColor += texture2D(texture, vUv + vec2( 2.0*px.x, 2.0*px.y))*0.115876621105;",
  "gl_FragColor += texture2D(texture, vUv + vec2( 3.0*px.x, 3.0*px.y))*0.0776744219933;",
  "gl_FragColor += texture2D(texture, vUv + vec2( 4.0*px.x, 4.0*px.y))*0.0443683338718;",
  "gl_FragColor += texture2D(texture, vUv + vec2( 5.0*px.x, 5.0*px.y))*0.0215963866053;",
  "gl_FragColor += texture2D(texture, vUv + vec2( 6.0*px.x, 6.0*px.y))*0.00895781211794;",
  "gl_FragColor += texture2D(texture, vUv + vec2( 7.0*px.x, 7.0*px.y))*0.0044299121055113265;",
  "}"
].join("\n");
/**
 * Pixelate: every fragment samples the texel at the lower corner of its
 * size x size pixel cell.
 * @param {number} size - cell edge length in pixels.
 */
_filter.pixelate = function(size) {
  // Cell size as a fraction of the image width/height (the shader works in UV space).
  const cellWidth = size / _width;
  const cellHeight = size / _height;
  const { uniform } = _compileShader(_filter.pixelate.SHADER);
  gl.uniform2f(uniform.size, cellWidth, cellHeight);
  _draw();
};
_filter.pixelate.SHADER = [
  "precision highp float;",
  "varying vec2 vUv;",
  "uniform vec2 size;",
  "uniform sampler2D texture;",
  // Snap a coordinate down to a multiple of the cell size.
  "vec2 pixelate(vec2 coord, vec2 size) {",
  "return floor( coord / size ) * size;",
  "}",
  "void main(void) {",
  "gl_FragColor = vec4(0.0);",
  "vec2 coord = pixelate(vUv, size);",
  "gl_FragColor += texture2D(texture, coord);",
  "}"
].join("\n");
};
exports.Canvas = WebGLImageFilter;
});
// config.js
// Default configuration for Human. Exposed as this module's `default` export
// (an __export getter) and read by src/human.js as `defaults`; load(userConfig)
// and detect(input, userConfig) deep-merge user options over it with mergeDeep().
var require_config = __commonJS((exports) => {
__export(exports, {
default: () => config_default
});
var config_default = {
// tfjs backend name; detect() calls tf.setBackend() with it when tf.getBackend() differs.
backend: "webgl",
// When true, Human.log() forwards messages to console.log.
console: true,
// When true, detect() wraps its work in tf.engine().startScope()/endScope().
scoped: false,
// When false, detect() merges `override`, forcing the skipFrames values it lists to 0.
videoOptimized: true,
// Pre-processing done by Human.tfImage() through the WebGL filter canvas. It only
// runs when that canvas exists (browser with `document`) and the input is not a tensor.
filter: {
enabled: true,
// Target size in pixels; 0 keeps the source size. If only one is set, the
// other is scaled to preserve the aspect ratio.
width: 0,
height: 0,
// When true, the filtered canvas is returned as `canvas` in the detect() result.
return: true,
// Numeric adjustments: tfImage() adds the matching filter when the value is
// non-zero (brightness is added unconditionally).
brightness: 0,
contrast: 0,
sharpness: 0,
blur: 0,
saturation: 0,
hue: 0,
// Boolean presets: the matching filter is added when true ("vintage" uses the "brownie" filter).
negative: false,
sepia: false,
vintage: false,
kodachrome: false,
technicolor: false,
polaroid: false,
// Cell size for the pixelate filter; 0 disables it.
pixelate: 0
},
// Face pipeline. Passed whole to facemesh.load()/estimateFaces(); age, gender and
// emotion models are only loaded and run while face.enabled is also true.
// NOTE(review): relative modelPath values are presumably resolved against the page
// or working directory by the model loaders (not visible here) — confirm.
face: {
enabled: true,
// detector/mesh/iris settings are consumed inside the facemesh module (not visible here).
detector: {
modelPath: "../models/blazeface/back/model.json",
inputSize: 256,
maxFaces: 10,
skipFrames: 10,
minConfidence: 0.5,
iouThreshold: 0.3,
scoreThreshold: 0.7
},
mesh: {
enabled: true,
modelPath: "../models/facemesh/model.json",
inputSize: 192
},
iris: {
enabled: true,
modelPath: "../models/iris/model.json",
enlargeFactor: 2.3,
inputSize: 64
},
// ssrnet.predict() runs when either age or gender is enabled.
age: {
enabled: true,
modelPath: "../models/ssrnet-age/imdb/model.json",
inputSize: 64,
skipFrames: 10
},
gender: {
enabled: true,
minConfidence: 0.8,
modelPath: "../models/ssrnet-gender/imdb/model.json"
},
emotion: {
enabled: true,
inputSize: 64,
minConfidence: 0.5,
skipFrames: 10,
modelPath: "../models/emotion/model.json"
}
},
// Body pipeline, passed to posenet.load() and estimatePoses().
body: {
enabled: true,
modelPath: "../models/posenet/model.json",
inputResolution: 257,
outputStride: 16,
maxDetections: 10,
scoreThreshold: 0.7,
nmsRadius: 20
},
// Hand pipeline, passed to handpose.load() and estimateHands().
// NOTE(review): skipFrames semantics live in the model modules; detect() only forces it to 0.
hand: {
enabled: true,
inputSize: 256,
skipFrames: 10,
minConfidence: 0.5,
iouThreshold: 0.3,
scoreThreshold: 0.7,
enlargeFactor: 1.65,
maxHands: 10,
detector: {
anchors: "../models/handdetect/anchors.json",
modelPath: "../models/handdetect/model.json"
},
skeleton: {
modelPath: "../models/handskeleton/model.json"
}
}
};
});
// package.json
// Inlined copy of package.json, bundled into the library. In this file only
// `version` is read from it (Human's constructor sets this.version = app.version).
var require_package = __commonJS((exports, module) => {
module.exports = {
name: "@vladmandic/human",
version: "0.4.9",
description: "human: 3D Face Detection, Iris Tracking and Age & Gender Prediction",
sideEffects: false,
main: "dist/human.node.js",
module: "dist/human.esm.js",
browser: "dist/human.esm.js",
author: "Vladimir Mandic <mandic00@live.com>",
bugs: {
url: "https://github.com/vladmandic/human/issues"
},
homepage: "https://github.com/vladmandic/human#readme",
license: "MIT",
engines: {
node: ">=14.0.0"
},
repository: {
type: "git",
url: "git+https://github.com/vladmandic/human.git"
},
dependencies: {},
peerDependencies: {},
devDependencies: {
seedrandom: "^3.0.5",
"@tensorflow/tfjs": "^2.7.0",
"@tensorflow/tfjs-node": "^2.7.0",
"@vladmandic/pilogger": "^0.2.6",
dayjs: "^1.9.4",
esbuild: "^0.7.22",
eslint: "^7.12.1",
"eslint-config-airbnb-base": "^14.2.0",
"eslint-plugin-import": "^2.22.1",
"eslint-plugin-json": "^2.1.2",
"eslint-plugin-node": "^11.1.0",
"eslint-plugin-promise": "^4.2.1",
rimraf: "^3.0.2",
"simple-git": "^2.21.0"
},
scripts: {
start: "node --trace-warnings --unhandled-rejections=strict --trace-uncaught --no-deprecation src/node.js",
lint: "eslint src/*.js demo/*.js",
"build-iife": "esbuild --bundle --platform=browser --sourcemap --target=esnext --format=iife --external:fs --global-name=Human --metafile=dist/human.json --outfile=dist/human.js src/human.js",
"build-esm-bundle": "esbuild --bundle --platform=browser --sourcemap --target=esnext --format=esm --external:fs --metafile=dist/human.esm.json --outfile=dist/human.esm.js src/human.js",
"build-esm-nobundle": "esbuild --bundle --platform=browser --sourcemap --target=esnext --format=esm --external:@tensorflow --external:fs --metafile=dist/human.esm-nobundle.json --outfile=dist/human.esm-nobundle.js src/human.js",
"build-node": "esbuild --bundle --platform=node --sourcemap --target=esnext --format=cjs --metafile=dist/human.node.json --outfile=dist/human.node.js src/human.js",
// NOTE(review): build-node-nobundle writes its metafile to dist/human.node.json, the same
// path build-node uses, so running `build` overwrites build-node's metafile. It was likely
// meant to be dist/human.node-nobundle.json; fix this in package.json, not in this generated copy.
"build-node-nobundle": "esbuild --bundle --platform=node --sourcemap --target=esnext --format=cjs --external:@tensorflow --metafile=dist/human.node.json --outfile=dist/human.node-nobundle.js src/human.js",
build: "rimraf dist/* && npm run build-iife && npm run build-esm-bundle && npm run build-esm-nobundle && npm run build-node && npm run build-node-nobundle && ls -l dist/",
update: "npm update --depth 20 --force && npm dedupe && npm prune && npm audit",
changelog: "node changelog.js"
},
keywords: [
"tensorflowjs",
"face-detection",
"face-geometry",
"body-tracking",
"hand-tracking",
"iris-tracking",
"age-estimation",
"emotion-detection",
"gender-prediction",
"gesture-recognition"
]
};
});
// src/human.js
var require_human = __commonJS((exports) => {
// Export the Human class as this module's default. __export installs a getter,
// so the class declared further down is only looked up when the export is read.
__export(exports, {
default: () => Human
});
// Bundled dependencies. Each require_* call runs its module body once, on first
// call (see __commonJS), so the order of these lines is the initialisation order.
const tf = require_tf_node();
const facemesh = require_facemesh();
const ssrnet = require_ssrnet();
const emotion = require_emotion();
const posenet = require_posenet();
const handpose = require_handpose();
const fxImage = require_imagefx();
const defaults = require_config().default;
const app = require_package();
// True until the first detect() call has logged the startup banner, configuration and flags.
let first = true;
// Merged over the config by detect() when config.videoOptimized is false; forces these
// skipFrames values to 0. NOTE(review): what skipFrames does is decided inside the model
// modules (not visible here) — presumably 0 means "no frame skipping".
const override = {
face: {detector: {skipFrames: 0}, age: {skipFrames: 0}, emotion: {skipFrames: 0}},
hand: {skipFrames: 0}
};
/**
 * Millisecond timestamp used for performance measurements.
 * Uses performance.now() when available (fractional ms). Otherwise it falls back to
 * Node's process.hrtime.bigint() (nanoseconds from an arbitrary point), truncated
 * to whole milliseconds.
 * @returns {number}
 */
const now = () => {
  if (typeof performance !== "undefined")
    return performance.now();
  // Divide while still a BigInt so the nanosecond count never passes through a lossy
  // double, instead of calling parseInt() on a number (which round-trips through a string).
  return Number(process.hrtime.bigint() / 1000000n);
};
/**
 * Deep-merges objects left to right into a new object; later values win.
 * - Two plain objects under the same key are merged recursively into a fresh object.
 * - Two arrays under the same key are concatenated (earlier elements first).
 * - Any other value, including null, replaces the earlier one.
 * Falsy arguments (undefined, null) are skipped. No argument is mutated. Keys present
 * in only one input keep a reference to that input's nested value.
 * @param {...object} objects - objects to merge
 * @returns {object} the merged result
 */
function mergeDeep(...objects) {
  const isObject = (obj) => obj && typeof obj === "object";
  return objects.reduce((prev, obj) => {
    Object.keys(obj || {}).forEach((key) => {
      const pVal = prev[key];
      const oVal = obj[key];
      if (Array.isArray(pVal) && Array.isArray(oVal)) {
        // concat(oVal) appends oVal's elements as-is. Spreading them as arguments,
        // as in concat(...oVal), would flatten nested arrays one extra level.
        prev[key] = pVal.concat(oVal);
      } else if (isObject(pVal) && isObject(oVal)) {
        prev[key] = mergeDeep(pVal, oVal);
      } else {
        prev[key] = oVal;
      }
    });
    return prev;
  }, {});
}
/**
 * Pre-flight check for detect().
 * @param {*} input - value passed to detect(); under Node it must be a tf.Tensor.
 * @returns {string|null} an error message, or null when detection can proceed.
 */
function sanity(input) {
  if (!input) return "input is not defined";
  const nodeRequiresTensor = tf.ENV.flags.IS_NODE;
  if (nodeRequiresTensor && !(input instanceof tf.Tensor)) return "input must be a tensor";
  // tf.getBackend() throws when no backend is available.
  let problem = null;
  try {
    tf.getBackend();
  } catch {
    problem = "backend not loaded";
  }
  return problem;
}
/**
 * Face, body and hand detection front-end (this module's default export).
 * Holds the active configuration, lazily loads the enabled models and runs
 * them on an input in detect().
 */
class Human {
  constructor() {
    this.tf = tf;
    this.version = app.version;
    this.defaults = defaults;
    this.config = defaults;
    // The WebGL filter canvas needs a DOM, so it only exists in a browser with `document`.
    this.fx = tf.ENV.flags.IS_BROWSER && typeof document !== "undefined" ? new fxImage.Canvas() : null;
    this.state = "idle";
    // Tensor count recorded by the previous analyze() call.
    this.numTensors = 0;
    this.analyzeMemoryLeaks = false;
    this.models = {
      facemesh: null,
      posenet: null,
      handpose: null,
      iris: null,
      age: null,
      gender: null,
      emotion: null
    };
    this.facemesh = facemesh;
    this.ssrnet = ssrnet;
    this.emotion = emotion;
    this.posenet = posenet;
    this.handpose = handpose;
  }

  /** Writes msg to console.log when config.console is enabled. */
  log(...msg) {
    if (msg && this.config.console)
      console.log(...msg);
  }

  /**
   * Memory-leak probe. When analyzeMemoryLeaks is set, it logs msg followed by the
   * change in tf's live tensor count since the previous call, but only if that
   * change is non-zero.
   */
  analyze(...msg) {
    if (!this.analyzeMemoryLeaks)
      return;
    const current = tf.engine().state.numTensors;
    const previous = this.numTensors;
    this.numTensors = current;
    const leaked = current - previous;
    if (leaked !== 0)
      this.log(...msg, leaked);
  }

  /**
   * Loads each enabled model that is not loaded yet; already-loaded models are kept.
   * @param {object} [userConfig] - when given, this.config becomes defaults deep-merged with it
   */
  async load(userConfig) {
    if (userConfig)
      this.config = mergeDeep(defaults, userConfig);
    if (this.config.face.enabled && !this.models.facemesh) {
      this.log("Load model: Face");
      this.models.facemesh = await facemesh.load(this.config.face);
    }
    if (this.config.body.enabled && !this.models.posenet) {
      this.log("Load model: Body");
      this.models.posenet = await posenet.load(this.config.body);
    }
    if (this.config.hand.enabled && !this.models.handpose) {
      this.log("Load model: Hand");
      this.models.handpose = await handpose.load(this.config.hand);
    }
    if (this.config.face.enabled && this.config.face.age.enabled && !this.models.age) {
      this.log("Load model: Age");
      this.models.age = await ssrnet.loadAge(this.config);
    }
    if (this.config.face.enabled && this.config.face.gender.enabled && !this.models.gender) {
      this.log("Load model: Gender");
      this.models.gender = await ssrnet.loadGender(this.config);
    }
    if (this.config.face.enabled && this.config.face.emotion.enabled && !this.models.emotion) {
      this.log("Load model: Emotion");
      this.models.emotion = await emotion.load(this.config);
    }
  }

  /**
   * Converts an input into a float tensor with a leading batch dimension.
   * Non-tensor inputs are optionally resized and filtered through the WebGL filter
   * canvas first. Tensor inputs are cloned unchanged.
   * @param {*} input - tensor, or an image source accepted by drawImage / tf.browser.fromPixels
   * @returns {{tensor: object, canvas: (object|null|undefined)}} the new tensor (the caller
   *   disposes it) and, when config.filter.return is set, the filtered canvas
   *   (undefined when no filtering ran)
   */
  tfImage(input) {
    let filtered;
    if (this.fx && this.config.filter.enabled && !(input instanceof tf.Tensor)) {
      // Source size: intrinsic media size first, then a positive .shape entry. The
      // ternaries yield the number itself; `shape && shape[i] > 0` would yield a boolean.
      // NOTE(review): shape[1]/shape[2] are kept as width/height; for a
      // [batch, height, width, channels] layout they would be swapped — confirm.
      const shape = input.shape || [];
      const originalWidth = input.naturalWidth || input.videoWidth || input.width || (shape[1] > 0 ? shape[1] : undefined);
      const originalHeight = input.naturalHeight || input.videoHeight || input.height || (shape[2] > 0 ? shape[2] : undefined);
      // Target size: an explicit filter.width/height wins. When only one is set, the
      // other is scaled to preserve the aspect ratio.
      let targetWidth = originalWidth;
      if (this.config.filter.width > 0)
        targetWidth = this.config.filter.width;
      else if (this.config.filter.height > 0)
        targetWidth = originalWidth * (this.config.filter.height / originalHeight);
      let targetHeight = originalHeight;
      if (this.config.filter.height > 0)
        targetHeight = this.config.filter.height;
      else if (this.config.filter.width > 0)
        targetHeight = originalHeight * (this.config.filter.width / originalWidth);
      const offscreenCanvas = typeof OffscreenCanvas !== "undefined" ? new OffscreenCanvas(targetWidth, targetHeight) : document.createElement("canvas");
      offscreenCanvas.width = targetWidth;
      offscreenCanvas.height = targetHeight;
      const ctx = offscreenCanvas.getContext("2d");
      // putImageData copies pixels without scaling; drawImage scales to the target size.
      if (input instanceof ImageData)
        ctx.putImageData(input, 0, 0);
      else
        ctx.drawImage(input, 0, 0, originalWidth, originalHeight, 0, 0, offscreenCanvas.width, offscreenCanvas.height);
      this.fx.reset();
      // NOTE(review): brightness is added even when 0, unlike the other adjustments —
      // presumably so the chain always has at least one pass; confirm.
      this.fx.addFilter("brightness", this.config.filter.brightness);
      if (this.config.filter.contrast !== 0)
        this.fx.addFilter("contrast", this.config.filter.contrast);
      if (this.config.filter.sharpness !== 0)
        this.fx.addFilter("sharpen", this.config.filter.sharpness);
      if (this.config.filter.blur !== 0)
        this.fx.addFilter("blur", this.config.filter.blur);
      if (this.config.filter.saturation !== 0)
        this.fx.addFilter("saturation", this.config.filter.saturation);
      if (this.config.filter.hue !== 0)
        this.fx.addFilter("hue", this.config.filter.hue);
      if (this.config.filter.negative)
        this.fx.addFilter("negative");
      // Each preset is added at most once. Adding sepia twice would apply the matrix twice.
      if (this.config.filter.sepia)
        this.fx.addFilter("sepia");
      if (this.config.filter.vintage)
        this.fx.addFilter("brownie");
      if (this.config.filter.kodachrome)
        this.fx.addFilter("kodachrome");
      if (this.config.filter.technicolor)
        this.fx.addFilter("technicolor");
      if (this.config.filter.polaroid)
        this.fx.addFilter("polaroid");
      if (this.config.filter.pixelate !== 0)
        this.fx.addFilter("pixelate", this.config.filter.pixelate);
      filtered = this.fx.apply(offscreenCanvas);
    }
    let tensor;
    if (input instanceof tf.Tensor) {
      tensor = tf.clone(input);
    } else {
      const pixels = tf.browser.fromPixels(filtered || input);
      const casted = pixels.toFloat();
      tensor = casted.expandDims(0);
      pixels.dispose();
      casted.dispose();
    }
    return {tensor, canvas: this.config.filter.return ? filtered : null};
  }

  /**
   * Runs body, hand and face detection on input.
   * @param {*} input - image source, or a tf.Tensor (required under Node)
   * @param {object} [userConfig] - options deep-merged over the defaults for this call
   * @returns {Promise<object>} {face, body, hand, performance, canvas}, or {error} when
   *   the input fails validation. It rejects if backend setup, model loading or a model throws.
   */
  async detect(input, userConfig = {}) {
    this.state = "config";
    const perf = {};
    let timeStamp;
    this.config = mergeDeep(defaults, userConfig);
    if (!this.config.videoOptimized)
      this.config = mergeDeep(this.config, override);
    this.state = "check";
    const error = sanity(input);
    if (error) {
      this.log(error, input);
      return {error};
    }
    // The work is awaited directly in this async method. Wrapping it in
    // `new Promise(async (resolve) => ...)` would lose any exception thrown inside,
    // leaving the returned promise pending forever.
    const timeStart = now();
    timeStamp = now();
    if (tf.getBackend() !== this.config.backend) {
      this.state = "backend";
      this.log("Human library setting backend:", this.config.backend);
      await tf.setBackend(this.config.backend);
      await tf.ready();
    }
    perf.backend = Math.trunc(now() - timeStamp);
    if (first) {
      this.log("Human library starting");
      this.log("Configuration:", this.config);
      this.log("Flags:", tf.ENV.flags);
      first = false;
    }
    timeStamp = now();
    this.state = "load";
    await this.load();
    perf.load = Math.trunc(now() - timeStamp);
    if (this.config.scoped)
      tf.engine().startScope();
    this.analyze("Start Detect:");
    timeStamp = now();
    const image = this.tfImage(input);
    perf.image = Math.trunc(now() - timeStamp);
    const imageTensor = image.tensor;
    let poseRes;
    let handRes;
    const faceRes = [];
    try {
      this.state = "run:body";
      timeStamp = now();
      this.analyze("Start PoseNet");
      poseRes = this.config.body.enabled ? await this.models.posenet.estimatePoses(imageTensor, this.config.body) : [];
      this.analyze("End PoseNet:");
      perf.body = Math.trunc(now() - timeStamp);
      this.state = "run:hand";
      timeStamp = now();
      this.analyze("Start HandPose:");
      handRes = this.config.hand.enabled ? await this.models.handpose.estimateHands(imageTensor, this.config.hand) : [];
      this.analyze("End HandPose:");
      perf.hand = Math.trunc(now() - timeStamp);
      if (this.config.face.enabled) {
        this.state = "run:face";
        timeStamp = now();
        this.analyze("Start FaceMesh:");
        const faces = await this.models.facemesh.estimateFaces(imageTensor, this.config.face);
        perf.face = Math.trunc(now() - timeStamp);
        for (const face of faces) {
          if (!face.image || face.image.isDisposedInternal) {
            this.log("face object is disposed:", face.image);
            continue;
          }
          this.state = "run:agegender";
          timeStamp = now();
          const ssrData = this.config.face.age.enabled || this.config.face.gender.enabled ? await ssrnet.predict(face.image, this.config) : {};
          perf.agegender = Math.trunc(now() - timeStamp);
          this.state = "run:emotion";
          timeStamp = now();
          const emotionData = this.config.face.emotion.enabled ? await emotion.predict(face.image, this.config) : {};
          perf.emotion = Math.trunc(now() - timeStamp);
          face.image.dispose();
          // Widest horizontal extent between iris landmarks 1 and 3 of either eye;
          // 0 when iris annotations are missing.
          const iris = face.annotations.leftEyeIris && face.annotations.rightEyeIris ? Math.max(face.annotations.leftEyeIris[3][0] - face.annotations.leftEyeIris[1][0], face.annotations.rightEyeIris[3][0] - face.annotations.rightEyeIris[1][0]) : 0;
          faceRes.push({
            confidence: face.confidence,
            box: face.box,
            mesh: face.mesh,
            annotations: face.annotations,
            age: ssrData.age,
            gender: ssrData.gender,
            agConfidence: ssrData.confidence,
            emotion: emotionData,
            // 11.7 divided by the measured iris extent, truncated to two decimals.
            // NOTE(review): presumably a distance estimate based on an ~11.7 mm iris — confirm.
            iris: iris !== 0 ? Math.trunc(100 * 11.7 / iris) / 100 : 0
          });
          this.analyze("End FaceMesh:");
        }
      }
    } finally {
      // Release the input tensor and close the scope even when a model throws.
      imageTensor.dispose();
      this.state = "idle";
      if (this.config.scoped)
        tf.engine().endScope();
    }
    this.analyze("End Scope:");
    perf.total = Math.trunc(now() - timeStart);
    return {face: faceRes, body: poseRes, hand: handRes, performance: perf, canvas: image.canvas};
  }
}
});
return require_human();
})();
//# sourceMappingURL=human.js.map