diff --git a/config.js b/config.js
index 716e6766..c49c1999 100644
--- a/config.js
+++ b/config.js
@@ -113,6 +113,7 @@ export default {
scoreThreshold: 0.8, // threshold for deciding when to remove boxes based on score in non-maximum suppression
enlargeFactor: 1.65, // empiric tuning as skeleton prediction prefers hand box with some whitespace
maxHands: 1, // maximum number of hands detected in the input, should be set to the minimum number for performance
+ landmarks: true, // detect hand landmarks or just hand boundary box
detector: {
modelPath: '../models/handdetect.json',
},
diff --git a/demo/browser.js b/demo/browser.js
index 3fa25bc7..a5e436e1 100644
--- a/demo/browser.js
+++ b/demo/browser.js
@@ -27,6 +27,10 @@ const ui = {
maxFrames: 10,
modelsPreload: true,
modelsWarmup: true,
+ menuWidth: 0,
+ menuHeight: 0,
+ camera: {},
+ fps: [],
};
// global variables
@@ -34,8 +38,6 @@ let menu;
let menuFX;
let worker;
let timeStamp;
-let camera = {};
-const fps = [];
// helper function: translates json to human readable string
function str(...msg) {
@@ -62,17 +64,22 @@ const status = (msg) => {
// draws processed results and starts processing of a next frame
function drawResults(input, result, canvas) {
// update fps data
- fps.push(1000 / (performance.now() - timeStamp));
- if (fps.length > ui.maxFrames) fps.shift();
+ const elapsed = performance.now() - timeStamp;
+ ui.fps.push(1000 / elapsed);
+ if (ui.fps.length > ui.maxFrames) ui.fps.shift();
// enable for continous performance monitoring
// console.log(result.performance);
- // eslint-disable-next-line no-use-before-define
- if (input.srcObject) requestAnimationFrame(() => runHumanDetect(input, canvas)); // immediate loop before we even draw results
-
+ // immediate loop before we even draw results, but limit frame rate to 30
+ if (input.srcObject) {
+ // eslint-disable-next-line no-use-before-define
+ if (elapsed > 33) requestAnimationFrame(() => runHumanDetect(input, canvas));
+ // eslint-disable-next-line no-use-before-define
+ else setTimeout(() => runHumanDetect(input, canvas), 33 - elapsed);
+ }
// draw fps chart
- menu.updateChart('FPS', fps);
+ menu.updateChart('FPS', ui.fps);
// draw image from video
const ctx = canvas.getContext('2d');
ctx.fillStyle = ui.baseBackground;
@@ -94,9 +101,9 @@ function drawResults(input, result, canvas) {
const gpu = engine.backendInstance ? `gpu: ${(engine.backendInstance.numBytesInGPU ? engine.backendInstance.numBytesInGPU : 0).toLocaleString()} bytes` : '';
const memory = `system: ${engine.state.numBytes.toLocaleString()} bytes ${gpu} | tensors: ${engine.state.numTensors.toLocaleString()}`;
const processing = result.canvas ? `processing: ${result.canvas.width} x ${result.canvas.height}` : '';
- const avg = Math.trunc(10 * fps.reduce((a, b) => a + b) / fps.length) / 10;
+ const avg = Math.trunc(10 * ui.fps.reduce((a, b) => a + b) / ui.fps.length) / 10;
document.getElementById('log').innerText = `
- video: ${camera.name} | facing: ${camera.facing} | resolution: ${camera.width} x ${camera.height} ${processing}
+ video: ${ui.camera.name} | facing: ${ui.camera.facing} | resolution: ${ui.camera.width} x ${ui.camera.height} ${processing}
backend: ${human.tf.getBackend()} | ${memory}
performance: ${str(result.performance)} FPS:${avg}
`;
@@ -147,7 +154,7 @@ async function setupCamera() {
const track = stream.getVideoTracks()[0];
const settings = track.getSettings();
log('camera constraints:', constraints, 'window:', { width: window.innerWidth, height: window.innerHeight }, 'settings:', settings, 'track:', track);
- camera = { name: track.label, width: settings.width, height: settings.height, facing: settings.facingMode === 'user' ? 'front' : 'back' };
+ ui.camera = { name: track.label, width: settings.width, height: settings.height, facing: settings.facingMode === 'user' ? 'front' : 'back' };
return new Promise((resolve) => {
video.onloadeddata = async () => {
video.width = video.videoWidth;
@@ -156,6 +163,8 @@ async function setupCamera() {
canvas.height = video.height;
canvas.style.width = canvas.width > canvas.height ? '100vw' : '';
canvas.style.height = canvas.width > canvas.height ? '' : '100vh';
+ ui.menuWidth.input.setAttribute('value', video.width);
+ ui.menuHeight.input.setAttribute('value', video.height);
// silly font resizing for paint-on-canvas since viewport can be zoomed
const size = 14 + (6 * canvas.width / window.innerWidth);
ui.baseFont = ui.baseFontProto.replace(/{size}/, `${size}px`);
@@ -351,8 +360,8 @@ function setupMenu() {
menuFX.addHTML('
');
menuFX.addLabel('Image Processing');
menuFX.addBool('Enabled', human.config.filter, 'enabled');
- menuFX.addRange('Image width', human.config.filter, 'width', 0, 3840, 10, (val) => human.config.filter.width = parseInt(val));
- menuFX.addRange('Image height', human.config.filter, 'height', 0, 2160, 10, (val) => human.config.filter.height = parseInt(val));
+ ui.menuWidth = menuFX.addRange('Image width', human.config.filter, 'width', 0, 3840, 10, (val) => human.config.filter.width = parseInt(val));
+ ui.menuHeight = menuFX.addRange('Image height', human.config.filter, 'height', 0, 2160, 10, (val) => human.config.filter.height = parseInt(val));
menuFX.addRange('Brightness', human.config.filter, 'brightness', -1.0, 1.0, 0.05, (val) => human.config.filter.brightness = parseFloat(val));
menuFX.addRange('Contrast', human.config.filter, 'contrast', -1.0, 1.0, 0.05, (val) => human.config.filter.contrast = parseFloat(val));
menuFX.addRange('Sharpness', human.config.filter, 'sharpness', 0, 1.0, 0.05, (val) => human.config.filter.sharpness = parseFloat(val));
diff --git a/demo/menu.js b/demo/menu.js
index 2e68ef1a..a5fb0b51 100644
--- a/demo/menu.js
+++ b/demo/menu.js
@@ -219,6 +219,7 @@ class Menu {
evt.target.setAttribute('value', evt.target.value);
if (callback) callback(evt.target.value);
});
+ el.input = el.children[0];
return el;
}
diff --git a/changelog.js b/dev-server/changelog.js
similarity index 94%
rename from changelog.js
rename to dev-server/changelog.js
index 34cc8fff..e64d6b60 100644
--- a/changelog.js
+++ b/dev-server/changelog.js
@@ -3,7 +3,7 @@ const path = require('path');
const dayjs = require('dayjs');
const simpleGit = require('simple-git/promise');
const logger = require('@vladmandic/pilogger');
-const app = require('./package.json');
+const app = require('../package.json');
const git = simpleGit();
@@ -45,5 +45,5 @@ async function update(f) {
exports.update = update;
if (!module.parent) {
- update('wiki/Change-Log.md');
+ update('../wiki/Change-Log.md');
}
diff --git a/dev-server.crt b/dev-server/dev-server.crt
similarity index 100%
rename from dev-server.crt
rename to dev-server/dev-server.crt
diff --git a/dev-server.js b/dev-server/dev-server.js
similarity index 98%
rename from dev-server.js
rename to dev-server/dev-server.js
index ea065d19..bd823b32 100755
--- a/dev-server.js
+++ b/dev-server/dev-server.js
@@ -26,9 +26,9 @@ const log = require('@vladmandic/pilogger');
const options = {
// key: fs.readFileSync('/home/vlado/dev/piproxy/cert/private.pem'),
// cert: fs.readFileSync('/home/vlado/dev/piproxy/cert/fullchain.pem'),
- key: fs.readFileSync('./dev-server.key'),
- cert: fs.readFileSync('./dev-server.crt'),
- root: '.',
+ key: fs.readFileSync('dev-server/dev-server.key'),
+ cert: fs.readFileSync('dev-server/dev-server.crt'),
+ root: '..',
default: 'demo/index.html',
port: 8000,
monitor: ['package.json', 'config.js', 'demo', 'src'],
diff --git a/dev-server.key b/dev-server/dev-server.key
similarity index 100%
rename from dev-server.key
rename to dev-server/dev-server.key
diff --git a/dist/demo-browser-index.js b/dist/demo-browser-index.js
index 343f506c..84b65ec7 100644
--- a/dist/demo-browser-index.js
+++ b/dist/demo-browser-index.js
@@ -1,18071 +1,43578 @@
// dist/human.esm.js
-var af = Object.defineProperty;
-var hL = (n) => af(n, "__esModule", {value: true});
-var we = (n, t) => () => (t || (t = {exports: {}}, n(t.exports, t)), t.exports);
-var Is = (n, t) => {
- hL(n);
- for (var e in t)
- af(n, e, {get: t[e], enumerable: true});
-};
-var sf = we(() => {
-});
-var of = we(() => {
-});
-var As = we(() => {
-});
-var Qr = we((A) => {
- "use strict";
- Object.defineProperty(A, "__esModule", {value: true});
- var gc = function(n, t) {
- return gc = Object.setPrototypeOf || {__proto__: []} instanceof Array && function(e, r) {
- e.__proto__ = r;
- } || function(e, r) {
- for (var i in r)
- r.hasOwnProperty(i) && (e[i] = r[i]);
- }, gc(n, t);
- };
- function qn(n, t) {
- gc(n, t);
- function e() {
- this.constructor = n;
- }
- n.prototype = t === null ? Object.create(t) : (e.prototype = t.prototype, new e());
+var __defProp = Object.defineProperty;
+var __markAsModule = (target) => __defProp(target, "__esModule", {value: true});
+var __commonJS = (callback, module) => () => {
+ if (!module) {
+ module = {exports: {}};
+ callback(module.exports, module);
}
- function pe(n, t, e, r) {
- return new (e || (e = Promise))(function(i, a) {
- function s(l) {
+ return module.exports;
+};
+var __export = (target, all) => {
+ __markAsModule(target);
+ for (var name in all)
+ __defProp(target, name, {get: all[name], enumerable: true});
+};
+var require_browser = __commonJS(() => {
+});
+var require_util = __commonJS(() => {
+});
+var require_crypto = __commonJS(() => {
+});
+var require_tf_core_node = __commonJS((exports) => {
+ /**
+ * @license
+ * Copyright 2020 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ "use strict";
+ Object.defineProperty(exports, "__esModule", {value: true});
+ /*! *****************************************************************************
+ Copyright (c) Microsoft Corporation. All rights reserved.
+ Licensed under the Apache License, Version 2.0 (the "License"); you may not use
+ this file except in compliance with the License. You may obtain a copy of the
+ License at http://www.apache.org/licenses/LICENSE-2.0
+
+ THIS CODE IS PROVIDED ON AN *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
+ WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
+ MERCHANTABLITY OR NON-INFRINGEMENT.
+
+ See the Apache Version 2.0 License for specific language governing permissions
+ and limitations under the License.
+ ***************************************************************************** */
+ var extendStatics = function(d, b) {
+ extendStatics = Object.setPrototypeOf || {__proto__: []} instanceof Array && function(d2, b2) {
+ d2.__proto__ = b2;
+ } || function(d2, b2) {
+ for (var p in b2)
+ if (b2.hasOwnProperty(p))
+ d2[p] = b2[p];
+ };
+ return extendStatics(d, b);
+ };
+ function __extends(d, b) {
+ extendStatics(d, b);
+ function __() {
+ this.constructor = d;
+ }
+ d.prototype = b === null ? Object.create(b) : (__.prototype = b.prototype, new __());
+ }
+ function __awaiter(thisArg, _arguments, P, generator) {
+ return new (P || (P = Promise))(function(resolve, reject) {
+ function fulfilled(value) {
try {
- c(r.next(l));
- } catch (u) {
- a(u);
+ step2(generator.next(value));
+ } catch (e) {
+ reject(e);
}
}
- function o(l) {
+ function rejected(value) {
try {
- c(r.throw(l));
- } catch (u) {
- a(u);
+ step2(generator["throw"](value));
+ } catch (e) {
+ reject(e);
}
}
- function c(l) {
- l.done ? i(l.value) : new e(function(u) {
- u(l.value);
- }).then(s, o);
+ function step2(result) {
+ result.done ? resolve(result.value) : new P(function(resolve2) {
+ resolve2(result.value);
+ }).then(fulfilled, rejected);
}
- c((r = r.apply(n, t || [])).next());
+ step2((generator = generator.apply(thisArg, _arguments || [])).next());
});
}
- function fe(n, t) {
- var e = {label: 0, sent: function() {
- if (a[0] & 1)
- throw a[1];
- return a[1];
- }, trys: [], ops: []}, r, i, a, s;
- return s = {next: o(0), throw: o(1), return: o(2)}, typeof Symbol == "function" && (s[Symbol.iterator] = function() {
+ function __generator(thisArg, body) {
+ var _ = {label: 0, sent: function() {
+ if (t[0] & 1)
+ throw t[1];
+ return t[1];
+ }, trys: [], ops: []}, f, y, t, g;
+ return g = {next: verb(0), throw: verb(1), return: verb(2)}, typeof Symbol === "function" && (g[Symbol.iterator] = function() {
return this;
- }), s;
- function o(l) {
- return function(u) {
- return c([l, u]);
+ }), g;
+ function verb(n) {
+ return function(v) {
+ return step2([n, v]);
};
}
- function c(l) {
- if (r)
+ function step2(op2) {
+ if (f)
throw new TypeError("Generator is already executing.");
- for (; e; )
+ while (_)
try {
- if (r = 1, i && (a = l[0] & 2 ? i.return : l[0] ? i.throw || ((a = i.return) && a.call(i), 0) : i.next) && !(a = a.call(i, l[1])).done)
- return a;
- (i = 0, a) && (l = [l[0] & 2, a.value]);
- switch (l[0]) {
+ if (f = 1, y && (t = op2[0] & 2 ? y["return"] : op2[0] ? y["throw"] || ((t = y["return"]) && t.call(y), 0) : y.next) && !(t = t.call(y, op2[1])).done)
+ return t;
+ if (y = 0, t)
+ op2 = [op2[0] & 2, t.value];
+ switch (op2[0]) {
case 0:
case 1:
- a = l;
+ t = op2;
break;
case 4:
- return e.label++, {value: l[1], done: false};
+ _.label++;
+ return {value: op2[1], done: false};
case 5:
- e.label++, i = l[1], l = [0];
+ _.label++;
+ y = op2[1];
+ op2 = [0];
continue;
case 7:
- l = e.ops.pop(), e.trys.pop();
+ op2 = _.ops.pop();
+ _.trys.pop();
continue;
default:
- if (!(a = e.trys, a = a.length > 0 && a[a.length - 1]) && (l[0] === 6 || l[0] === 2)) {
- e = 0;
+ if (!(t = _.trys, t = t.length > 0 && t[t.length - 1]) && (op2[0] === 6 || op2[0] === 2)) {
+ _ = 0;
continue;
}
- if (l[0] === 3 && (!a || l[1] > a[0] && l[1] < a[3])) {
- e.label = l[1];
+ if (op2[0] === 3 && (!t || op2[1] > t[0] && op2[1] < t[3])) {
+ _.label = op2[1];
break;
}
- if (l[0] === 6 && e.label < a[1]) {
- e.label = a[1], a = l;
+ if (op2[0] === 6 && _.label < t[1]) {
+ _.label = t[1];
+ t = op2;
break;
}
- if (a && e.label < a[2]) {
- e.label = a[2], e.ops.push(l);
+ if (t && _.label < t[2]) {
+ _.label = t[2];
+ _.ops.push(op2);
break;
}
- a[2] && e.ops.pop(), e.trys.pop();
+ if (t[2])
+ _.ops.pop();
+ _.trys.pop();
continue;
}
- l = t.call(n, e);
- } catch (u) {
- l = [6, u], i = 0;
+ op2 = body.call(thisArg, _);
+ } catch (e) {
+ op2 = [6, e];
+ y = 0;
} finally {
- r = a = 0;
+ f = t = 0;
}
- if (l[0] & 5)
- throw l[1];
- return {value: l[0] ? l[1] : void 0, done: true};
+ if (op2[0] & 5)
+ throw op2[1];
+ return {value: op2[0] ? op2[1] : void 0, done: true};
}
}
- var dL = 1e-7, pL = 1e-4, fL = function() {
- function n(t, e) {
- this.backend = t, this.dataMover = e, this.data = new WeakMap(), this.dataIdsCount = 0;
+ /**
+ * @license
+ * Copyright 2020 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ var EPSILON_FLOAT32 = 1e-7;
+ var EPSILON_FLOAT16 = 1e-4;
+ var DataStorage = function() {
+ function DataStorage2(backend2, dataMover) {
+ this.backend = backend2;
+ this.dataMover = dataMover;
+ this.data = new WeakMap();
+ this.dataIdsCount = 0;
}
- return n.prototype.get = function(t) {
- return this.data.has(t) || this.dataMover.moveData(this.backend, t), this.data.get(t);
- }, n.prototype.set = function(t, e) {
- this.dataIdsCount++, this.data.set(t, e);
- }, n.prototype.has = function(t) {
- return this.data.has(t);
- }, n.prototype.delete = function(t) {
- return this.dataIdsCount--, this.data.delete(t);
- }, n.prototype.numDataIds = function() {
+ DataStorage2.prototype.get = function(dataId) {
+ if (!this.data.has(dataId)) {
+ this.dataMover.moveData(this.backend, dataId);
+ }
+ return this.data.get(dataId);
+ };
+ DataStorage2.prototype.set = function(dataId, value) {
+ this.dataIdsCount++;
+ this.data.set(dataId, value);
+ };
+ DataStorage2.prototype.has = function(dataId) {
+ return this.data.has(dataId);
+ };
+ DataStorage2.prototype.delete = function(dataId) {
+ this.dataIdsCount--;
+ return this.data.delete(dataId);
+ };
+ DataStorage2.prototype.numDataIds = function() {
return this.dataIdsCount;
- }, n;
- }(), cf = function() {
- function n() {
- }
- return n.prototype.time = function(t) {
- return X("time");
- }, n.prototype.read = function(t) {
- return X("read");
- }, n.prototype.readSync = function(t) {
- return X("readSync");
- }, n.prototype.numDataIds = function() {
- return X("numDataIds");
- }, n.prototype.disposeData = function(t) {
- return X("disposeData");
- }, n.prototype.write = function(t, e, r) {
- return X("write");
- }, n.prototype.move = function(t, e, r, i) {
- return X("move");
- }, n.prototype.memory = function() {
- return X("memory");
- }, n.prototype.floatPrecision = function() {
- return X("floatPrecision");
- }, n.prototype.epsilon = function() {
- return this.floatPrecision() === 32 ? dL : pL;
- }, n.prototype.batchMatMul = function(t, e, r, i) {
- return X("batchMatMul");
- }, n.prototype.fusedBatchMatMul = function(t) {
- var e = t.a, r = t.b, i = t.transposeA, a = t.transposeB, s = t.bias, o = t.activation, c = t.preluActivationWeights;
- return X("fusedBatchMatMul");
- }, n.prototype.slice = function(t, e, r) {
- return X("slice");
- }, n.prototype.stridedSlice = function(t, e, r, i) {
- return X("stridedSlice");
- }, n.prototype.unstack = function(t, e) {
- return X("unstack");
- }, n.prototype.reverse = function(t, e) {
- return X("reverse");
- }, n.prototype.concat = function(t, e) {
- return X("concat");
- }, n.prototype.neg = function(t) {
- return X("neg");
- }, n.prototype.add = function(t, e) {
- return X("add");
- }, n.prototype.addN = function(t) {
- return X("addN");
- }, n.prototype.subtract = function(t, e) {
- return X("subtract");
- }, n.prototype.multiply = function(t, e) {
- return X("multiply");
- }, n.prototype.realDivide = function(t, e) {
- return X("realDivide");
- }, n.prototype.floorDiv = function(t, e) {
- return X("floorDiv");
- }, n.prototype.sum = function(t, e) {
- return X("sum");
- }, n.prototype.prod = function(t, e) {
- return X("prod");
- }, n.prototype.unsortedSegmentSum = function(t, e, r) {
- return X("unsortedSegmentSum");
- }, n.prototype.argMin = function(t, e) {
- return X("argMin");
- }, n.prototype.argMax = function(t, e) {
- return X("argMax");
- }, n.prototype.equal = function(t, e) {
- return X("equal");
- }, n.prototype.notEqual = function(t, e) {
- return X("notEqual");
- }, n.prototype.less = function(t, e) {
- return X("less");
- }, n.prototype.lessEqual = function(t, e) {
- return X("lessEqual");
- }, n.prototype.greater = function(t, e) {
- return X("greater");
- }, n.prototype.greaterEqual = function(t, e) {
- return X("greaterEqual");
- }, n.prototype.logicalNot = function(t) {
- return X("logicalNot");
- }, n.prototype.logicalAnd = function(t, e) {
- return X("logicalAnd");
- }, n.prototype.logicalOr = function(t, e) {
- return X("logicalOr");
- }, n.prototype.where = function(t) {
- return X("where");
- }, n.prototype.select = function(t, e, r) {
- return X("select");
- }, n.prototype.topk = function(t, e, r) {
- return X("topk");
- }, n.prototype.min = function(t, e) {
- return X("min");
- }, n.prototype.minimum = function(t, e) {
- return X("minimum");
- }, n.prototype.mod = function(t, e) {
- return X("mod");
- }, n.prototype.max = function(t, e) {
- return X("max");
- }, n.prototype.maximum = function(t, e) {
- return X("maximum");
- }, n.prototype.all = function(t, e) {
- return X("all");
- }, n.prototype.any = function(t, e) {
- return X("any");
- }, n.prototype.squaredDifference = function(t, e) {
- return X("squaredDifference");
- }, n.prototype.ceil = function(t) {
- return X("ceil");
- }, n.prototype.floor = function(t) {
- return X("floor");
- }, n.prototype.round = function(t) {
- return X("round");
- }, n.prototype.sign = function(t) {
- return X("sign");
- }, n.prototype.isNaN = function(t) {
- return X("isNaN");
- }, n.prototype.isInf = function(t) {
- return X("isInf");
- }, n.prototype.isFinite = function(t) {
- return X("isFinite");
- }, n.prototype.pow = function(t, e) {
- return X("pow");
- }, n.prototype.exp = function(t) {
- return X("exp");
- }, n.prototype.expm1 = function(t) {
- return X("expm1");
- }, n.prototype.softmax = function(t, e) {
- return X("softmax");
- }, n.prototype.log = function(t) {
- return X("log");
- }, n.prototype.log1p = function(t) {
- return X("log1p");
- }, n.prototype.sqrt = function(t) {
- return X("sqrt");
- }, n.prototype.rsqrt = function(t) {
- return X("rsqrt");
- }, n.prototype.square = function(t) {
- return X("square");
- }, n.prototype.reciprocal = function(t) {
- return X("reciprocal");
- }, n.prototype.relu = function(t) {
- return X("relu");
- }, n.prototype.relu6 = function(t) {
- return X("relu6");
- }, n.prototype.prelu = function(t, e) {
- return X("prelu");
- }, n.prototype.elu = function(t) {
- return X("elu");
- }, n.prototype.eluDer = function(t, e) {
- return X("eluDer");
- }, n.prototype.selu = function(t) {
- return X("selu");
- }, n.prototype.int = function(t) {
- return X("int");
- }, n.prototype.clip = function(t, e, r) {
- return X("clip");
- }, n.prototype.abs = function(t) {
- return X("abs");
- }, n.prototype.complexAbs = function(t) {
- return X("complexAbs");
- }, n.prototype.sigmoid = function(t) {
- return X("sigmoid");
- }, n.prototype.softplus = function(t) {
- return X("softplus");
- }, n.prototype.sin = function(t) {
- return X("sin");
- }, n.prototype.cos = function(t) {
- return X("cos");
- }, n.prototype.tan = function(t) {
- return X("tan");
- }, n.prototype.asin = function(t) {
- return X("asin");
- }, n.prototype.acos = function(t) {
- return X("acos");
- }, n.prototype.atan = function(t) {
- return X("atan");
- }, n.prototype.atan2 = function(t, e) {
- return X("atan2");
- }, n.prototype.sinh = function(t) {
- return X("sinh");
- }, n.prototype.cosh = function(t) {
- return X("cosh");
- }, n.prototype.tanh = function(t) {
- return X("tanh");
- }, n.prototype.asinh = function(t) {
- return X("asinh");
- }, n.prototype.acosh = function(t) {
- return X("acosh");
- }, n.prototype.atanh = function(t) {
- return X("atanh");
- }, n.prototype.erf = function(t) {
- return X("erf");
- }, n.prototype.step = function(t, e) {
- return X("step");
- }, n.prototype.fusedConv2d = function(t) {
- var e = t.input, r = t.filter, i = t.convInfo, a = t.bias, s = t.activation, o = t.preluActivationWeights;
- return X("fusedConv2d");
- }, n.prototype.conv2d = function(t, e, r) {
- return X("conv2d");
- }, n.prototype.conv2dDerInput = function(t, e, r) {
- return X("conv2dDerInput");
- }, n.prototype.conv2dDerFilter = function(t, e, r) {
- return X("conv2dDerFilter");
- }, n.prototype.fusedDepthwiseConv2D = function(t) {
- var e = t.input, r = t.filter, i = t.convInfo, a = t.bias, s = t.activation, o = t.preluActivationWeights;
- return X("fusedDepthwiseConv2D");
- }, n.prototype.depthwiseConv2D = function(t, e, r) {
- return X("depthwiseConv2D");
- }, n.prototype.depthwiseConv2DDerInput = function(t, e, r) {
- return X("depthwiseConv2DDerInput");
- }, n.prototype.depthwiseConv2DDerFilter = function(t, e, r) {
- return X("depthwiseConv2DDerFilter");
- }, n.prototype.conv3d = function(t, e, r) {
- return X("conv3d");
- }, n.prototype.conv3dDerInput = function(t, e, r) {
- return X("conv3dDerInput");
- }, n.prototype.conv3dDerFilter = function(t, e, r) {
- return X("conv3dDerFilter");
- }, n.prototype.maxPool = function(t, e) {
- return X("maxPool");
- }, n.prototype.maxPoolBackprop = function(t, e, r, i) {
- return X("maxPoolBackprop");
- }, n.prototype.avgPool = function(t, e) {
- return X("avgPool");
- }, n.prototype.avgPoolBackprop = function(t, e, r) {
- return X("avgPoolBackprop");
- }, n.prototype.avgPool3d = function(t, e) {
- return X("avgPool3d");
- }, n.prototype.avgPool3dBackprop = function(t, e, r) {
- return X("avgPool3dBackprop");
- }, n.prototype.maxPool3d = function(t, e) {
- return X("maxPool3d");
- }, n.prototype.maxPool3dBackprop = function(t, e, r, i) {
- return X("maxPool3dBackprop");
- }, n.prototype.reshape = function(t, e) {
- return X("reshape");
- }, n.prototype.cast = function(t, e) {
- return X("cast");
- }, n.prototype.tile = function(t, e) {
- return X("tile");
- }, n.prototype.pad = function(t, e, r) {
- return X("pad");
- }, n.prototype.transpose = function(t, e) {
- return X("transpose");
- }, n.prototype.gather = function(t, e, r) {
- return X("gather");
- }, n.prototype.gatherND = function(t, e) {
- return X("gatherND");
- }, n.prototype.scatterND = function(t, e, r) {
- return X("scatterND");
- }, n.prototype.batchToSpaceND = function(t, e, r) {
- return X("batchToSpaceND");
- }, n.prototype.spaceToBatchND = function(t, e, r) {
- return X("spaceToBatchND");
- }, n.prototype.resizeBilinear = function(t, e, r, i) {
- return X("resizeBilinear");
- }, n.prototype.resizeBilinearBackprop = function(t, e, r) {
- return X("resizeBilinearBackprop");
- }, n.prototype.resizeNearestNeighbor = function(t, e, r, i) {
- return X("resizeNearestNeighbor");
- }, n.prototype.resizeNearestNeighborBackprop = function(t, e, r) {
- return X("resizeNearestNeighborBackprop");
- }, n.prototype.batchNorm = function(t, e, r, i, a, s) {
- return X("batchNorm");
- }, n.prototype.localResponseNormalization4D = function(t, e, r, i, a) {
- return X("localResponseNormalization4D");
- }, n.prototype.LRNGrad = function(t, e, r, i, a, s, o) {
- return X("LRNGrad");
- }, n.prototype.multinomial = function(t, e, r, i) {
- return X("multinomial");
- }, n.prototype.oneHot = function(t, e, r, i) {
- return X("oneHot");
- }, n.prototype.cumsum = function(t, e, r, i) {
- return X("cumsum");
- }, n.prototype.nonMaxSuppression = function(t, e, r, i, a) {
- return X("nonMaxSuppression");
- }, n.prototype.fft = function(t) {
- return X("fft");
- }, n.prototype.ifft = function(t) {
- return X("ifft");
- }, n.prototype.complex = function(t, e) {
- return X("complex");
- }, n.prototype.real = function(t) {
- return X("real");
- }, n.prototype.imag = function(t) {
- return X("imag");
- }, n.prototype.cropAndResize = function(t, e, r, i, a, s) {
- return X("cropAndResize");
- }, n.prototype.depthToSpace = function(t, e, r) {
- return X("depthToSpace");
- }, n.prototype.split = function(t, e, r) {
- return X("split");
- }, n.prototype.sparseToDense = function(t, e, r, i) {
- return X("sparseToDense");
- }, n.prototype.diag = function(t) {
- return X("diag");
- }, n.prototype.fill = function(t, e, r) {
- return X("fill");
- }, n.prototype.onesLike = function(t) {
- return X("onesLike");
- }, n.prototype.zerosLike = function(t) {
- return X("zerosLike");
- }, n.prototype.linspace = function(t, e, r) {
- return X("linspace");
- }, n.prototype.dispose = function() {
- return X("dispose");
- }, n;
+ };
+ return DataStorage2;
}();
- function X(n) {
- throw new Error("'" + n + "' not yet implemented or not found in the registry. This kernel may not be supported by the tfjs backend you have chosen");
- }
- function lf(n) {
- for (var t = n.length, e = 0, r = 0; t > 0; )
- r = Math.random() * t | 0, t--, e = n[t], n[t] = n[r], n[r] = e;
- }
- function ga(n, t, e) {
- return Math.max(n, Math.min(t, e));
- }
- function mL(n) {
- return n % 2 === 0 ? n : n + 1;
- }
- function gL(n) {
- for (var t = 0, e = 0; e < n.length; e++)
- t += n[e];
- return t;
- }
- function yL(n, t) {
- var e = Math.random();
- return t * e + (1 - e) * n;
- }
- function vL(n, t) {
- for (var e = 0, r = 0; r < n.length; r++) {
- var i = Number(n[r]) - Number(t[r]);
- e += i * i;
+ var KernelBackend = function() {
+ function KernelBackend2() {
}
- return e;
+ KernelBackend2.prototype.time = function(f) {
+ return notYetImplemented("time");
+ };
+ KernelBackend2.prototype.read = function(dataId) {
+ return notYetImplemented("read");
+ };
+ KernelBackend2.prototype.readSync = function(dataId) {
+ return notYetImplemented("readSync");
+ };
+ KernelBackend2.prototype.numDataIds = function() {
+ return notYetImplemented("numDataIds");
+ };
+ KernelBackend2.prototype.disposeData = function(dataId) {
+ return notYetImplemented("disposeData");
+ };
+ KernelBackend2.prototype.write = function(values, shape, dtype) {
+ return notYetImplemented("write");
+ };
+ KernelBackend2.prototype.move = function(dataId, values, shape, dtype) {
+ return notYetImplemented("move");
+ };
+ KernelBackend2.prototype.memory = function() {
+ return notYetImplemented("memory");
+ };
+ KernelBackend2.prototype.floatPrecision = function() {
+ return notYetImplemented("floatPrecision");
+ };
+ KernelBackend2.prototype.epsilon = function() {
+ return this.floatPrecision() === 32 ? EPSILON_FLOAT32 : EPSILON_FLOAT16;
+ };
+ KernelBackend2.prototype.batchMatMul = function(a, b, transposeA, transposeB) {
+ return notYetImplemented("batchMatMul");
+ };
+ KernelBackend2.prototype.fusedBatchMatMul = function(_a) {
+ var a = _a.a, b = _a.b, transposeA = _a.transposeA, transposeB = _a.transposeB, bias = _a.bias, activation = _a.activation, preluActivationWeights = _a.preluActivationWeights;
+ return notYetImplemented("fusedBatchMatMul");
+ };
+ KernelBackend2.prototype.slice = function(x, begin, size) {
+ return notYetImplemented("slice");
+ };
+ KernelBackend2.prototype.stridedSlice = function(x, begin, end, strides) {
+ return notYetImplemented("stridedSlice");
+ };
+ KernelBackend2.prototype.unstack = function(x, axis) {
+ return notYetImplemented("unstack");
+ };
+ KernelBackend2.prototype.reverse = function(a, axis) {
+ return notYetImplemented("reverse");
+ };
+ KernelBackend2.prototype.concat = function(tensors, axis) {
+ return notYetImplemented("concat");
+ };
+ KernelBackend2.prototype.neg = function(a) {
+ return notYetImplemented("neg");
+ };
+ KernelBackend2.prototype.add = function(a, b) {
+ return notYetImplemented("add");
+ };
+ KernelBackend2.prototype.addN = function(tensors) {
+ return notYetImplemented("addN");
+ };
+ KernelBackend2.prototype.subtract = function(a, b) {
+ return notYetImplemented("subtract");
+ };
+ KernelBackend2.prototype.multiply = function(a, b) {
+ return notYetImplemented("multiply");
+ };
+ KernelBackend2.prototype.realDivide = function(a, b) {
+ return notYetImplemented("realDivide");
+ };
+ KernelBackend2.prototype.floorDiv = function(a, b) {
+ return notYetImplemented("floorDiv");
+ };
+ KernelBackend2.prototype.sum = function(x, axes) {
+ return notYetImplemented("sum");
+ };
+ KernelBackend2.prototype.prod = function(x, axes) {
+ return notYetImplemented("prod");
+ };
+ KernelBackend2.prototype.unsortedSegmentSum = function(x, segmentIds, numSegments) {
+ return notYetImplemented("unsortedSegmentSum");
+ };
+ KernelBackend2.prototype.argMin = function(x, axis) {
+ return notYetImplemented("argMin");
+ };
+ KernelBackend2.prototype.argMax = function(x, axis) {
+ return notYetImplemented("argMax");
+ };
+ KernelBackend2.prototype.equal = function(a, b) {
+ return notYetImplemented("equal");
+ };
+ KernelBackend2.prototype.notEqual = function(a, b) {
+ return notYetImplemented("notEqual");
+ };
+ KernelBackend2.prototype.less = function(a, b) {
+ return notYetImplemented("less");
+ };
+ KernelBackend2.prototype.lessEqual = function(a, b) {
+ return notYetImplemented("lessEqual");
+ };
+ KernelBackend2.prototype.greater = function(a, b) {
+ return notYetImplemented("greater");
+ };
+ KernelBackend2.prototype.greaterEqual = function(a, b) {
+ return notYetImplemented("greaterEqual");
+ };
+ KernelBackend2.prototype.logicalNot = function(a) {
+ return notYetImplemented("logicalNot");
+ };
+ KernelBackend2.prototype.logicalAnd = function(a, b) {
+ return notYetImplemented("logicalAnd");
+ };
+ KernelBackend2.prototype.logicalOr = function(a, b) {
+ return notYetImplemented("logicalOr");
+ };
+ KernelBackend2.prototype.where = function(condition) {
+ return notYetImplemented("where");
+ };
+ KernelBackend2.prototype.select = function(condition, a, b) {
+ return notYetImplemented("select");
+ };
+ KernelBackend2.prototype.topk = function(x, k, sorted) {
+ return notYetImplemented("topk");
+ };
+ KernelBackend2.prototype.min = function(x, axes) {
+ return notYetImplemented("min");
+ };
+ KernelBackend2.prototype.minimum = function(a, b) {
+ return notYetImplemented("minimum");
+ };
+ KernelBackend2.prototype.mod = function(a, b) {
+ return notYetImplemented("mod");
+ };
+ KernelBackend2.prototype.max = function(x, axes) {
+ return notYetImplemented("max");
+ };
+ KernelBackend2.prototype.maximum = function(a, b) {
+ return notYetImplemented("maximum");
+ };
+ KernelBackend2.prototype.all = function(x, axes) {
+ return notYetImplemented("all");
+ };
+ KernelBackend2.prototype.any = function(x, axes) {
+ return notYetImplemented("any");
+ };
+ KernelBackend2.prototype.squaredDifference = function(a, b) {
+ return notYetImplemented("squaredDifference");
+ };
+ KernelBackend2.prototype.ceil = function(x) {
+ return notYetImplemented("ceil");
+ };
+ KernelBackend2.prototype.floor = function(x) {
+ return notYetImplemented("floor");
+ };
+ KernelBackend2.prototype.round = function(x) {
+ return notYetImplemented("round");
+ };
+ KernelBackend2.prototype.sign = function(x) {
+ return notYetImplemented("sign");
+ };
+ KernelBackend2.prototype.isNaN = function(x) {
+ return notYetImplemented("isNaN");
+ };
+ KernelBackend2.prototype.isInf = function(x) {
+ return notYetImplemented("isInf");
+ };
+ KernelBackend2.prototype.isFinite = function(x) {
+ return notYetImplemented("isFinite");
+ };
+ KernelBackend2.prototype.pow = function(a, b) {
+ return notYetImplemented("pow");
+ };
+ KernelBackend2.prototype.exp = function(x) {
+ return notYetImplemented("exp");
+ };
+ KernelBackend2.prototype.expm1 = function(x) {
+ return notYetImplemented("expm1");
+ };
+ KernelBackend2.prototype.softmax = function(x, dim) {
+ return notYetImplemented("softmax");
+ };
+ KernelBackend2.prototype.log = function(x) {
+ return notYetImplemented("log");
+ };
+ KernelBackend2.prototype.log1p = function(x) {
+ return notYetImplemented("log1p");
+ };
+ KernelBackend2.prototype.sqrt = function(x) {
+ return notYetImplemented("sqrt");
+ };
+ KernelBackend2.prototype.rsqrt = function(x) {
+ return notYetImplemented("rsqrt");
+ };
+ KernelBackend2.prototype.square = function(x) {
+ return notYetImplemented("square");
+ };
+ KernelBackend2.prototype.reciprocal = function(x) {
+ return notYetImplemented("reciprocal");
+ };
+ KernelBackend2.prototype.relu = function(x) {
+ return notYetImplemented("relu");
+ };
+ KernelBackend2.prototype.relu6 = function(x) {
+ return notYetImplemented("relu6");
+ };
+ KernelBackend2.prototype.prelu = function(x, a) {
+ return notYetImplemented("prelu");
+ };
+ KernelBackend2.prototype.elu = function(x) {
+ return notYetImplemented("elu");
+ };
+ KernelBackend2.prototype.eluDer = function(dy, y) {
+ return notYetImplemented("eluDer");
+ };
+ KernelBackend2.prototype.selu = function(x) {
+ return notYetImplemented("selu");
+ };
+ KernelBackend2.prototype.int = function(x) {
+ return notYetImplemented("int");
+ };
+ KernelBackend2.prototype.clip = function(x, min2, max2) {
+ return notYetImplemented("clip");
+ };
+ KernelBackend2.prototype.abs = function(x) {
+ return notYetImplemented("abs");
+ };
+ KernelBackend2.prototype.complexAbs = function(x) {
+ return notYetImplemented("complexAbs");
+ };
+ KernelBackend2.prototype.sigmoid = function(x) {
+ return notYetImplemented("sigmoid");
+ };
+ KernelBackend2.prototype.softplus = function(x) {
+ return notYetImplemented("softplus");
+ };
+ KernelBackend2.prototype.sin = function(x) {
+ return notYetImplemented("sin");
+ };
+ KernelBackend2.prototype.cos = function(x) {
+ return notYetImplemented("cos");
+ };
+ KernelBackend2.prototype.tan = function(x) {
+ return notYetImplemented("tan");
+ };
+ KernelBackend2.prototype.asin = function(x) {
+ return notYetImplemented("asin");
+ };
+ KernelBackend2.prototype.acos = function(x) {
+ return notYetImplemented("acos");
+ };
+ KernelBackend2.prototype.atan = function(x) {
+ return notYetImplemented("atan");
+ };
+ KernelBackend2.prototype.atan2 = function(a, b) {
+ return notYetImplemented("atan2");
+ };
+ KernelBackend2.prototype.sinh = function(x) {
+ return notYetImplemented("sinh");
+ };
+ KernelBackend2.prototype.cosh = function(x) {
+ return notYetImplemented("cosh");
+ };
+ KernelBackend2.prototype.tanh = function(x) {
+ return notYetImplemented("tanh");
+ };
+ KernelBackend2.prototype.asinh = function(x) {
+ return notYetImplemented("asinh");
+ };
+ KernelBackend2.prototype.acosh = function(x) {
+ return notYetImplemented("acosh");
+ };
+ KernelBackend2.prototype.atanh = function(x) {
+ return notYetImplemented("atanh");
+ };
+ KernelBackend2.prototype.erf = function(x) {
+ return notYetImplemented("erf");
+ };
+ KernelBackend2.prototype.step = function(x, alpha) {
+ return notYetImplemented("step");
+ };
+ KernelBackend2.prototype.fusedConv2d = function(_a) {
+ var input = _a.input, filter = _a.filter, convInfo = _a.convInfo, bias = _a.bias, activation = _a.activation, preluActivationWeights = _a.preluActivationWeights;
+ return notYetImplemented("fusedConv2d");
+ };
+ KernelBackend2.prototype.conv2d = function(x, filter, convInfo) {
+ return notYetImplemented("conv2d");
+ };
+ KernelBackend2.prototype.conv2dDerInput = function(dy, filter, convInfo) {
+ return notYetImplemented("conv2dDerInput");
+ };
+ KernelBackend2.prototype.conv2dDerFilter = function(x, dY, convInfo) {
+ return notYetImplemented("conv2dDerFilter");
+ };
+ KernelBackend2.prototype.fusedDepthwiseConv2D = function(_a) {
+ var input = _a.input, filter = _a.filter, convInfo = _a.convInfo, bias = _a.bias, activation = _a.activation, preluActivationWeights = _a.preluActivationWeights;
+ return notYetImplemented("fusedDepthwiseConv2D");
+ };
+ KernelBackend2.prototype.depthwiseConv2D = function(input, filter, convInfo) {
+ return notYetImplemented("depthwiseConv2D");
+ };
+ KernelBackend2.prototype.depthwiseConv2DDerInput = function(dy, filter, convInfo) {
+ return notYetImplemented("depthwiseConv2DDerInput");
+ };
+ KernelBackend2.prototype.depthwiseConv2DDerFilter = function(x, dY, convInfo) {
+ return notYetImplemented("depthwiseConv2DDerFilter");
+ };
+ KernelBackend2.prototype.conv3d = function(x, filter, convInfo) {
+ return notYetImplemented("conv3d");
+ };
+ KernelBackend2.prototype.conv3dDerInput = function(dy, filter, convInfo) {
+ return notYetImplemented("conv3dDerInput");
+ };
+ KernelBackend2.prototype.conv3dDerFilter = function(x, dY, convInfo) {
+ return notYetImplemented("conv3dDerFilter");
+ };
+ KernelBackend2.prototype.maxPool = function(x, convInfo) {
+ return notYetImplemented("maxPool");
+ };
+ KernelBackend2.prototype.maxPoolBackprop = function(dy, x, y, convInfo) {
+ return notYetImplemented("maxPoolBackprop");
+ };
+ KernelBackend2.prototype.avgPool = function(x, convInfo) {
+ return notYetImplemented("avgPool");
+ };
+ KernelBackend2.prototype.avgPoolBackprop = function(dy, x, convInfo) {
+ return notYetImplemented("avgPoolBackprop");
+ };
+ KernelBackend2.prototype.avgPool3d = function(x, convInfo) {
+ return notYetImplemented("avgPool3d");
+ };
+ KernelBackend2.prototype.avgPool3dBackprop = function(dy, x, convInfo) {
+ return notYetImplemented("avgPool3dBackprop");
+ };
+ KernelBackend2.prototype.maxPool3d = function(x, convInfo) {
+ return notYetImplemented("maxPool3d");
+ };
+ KernelBackend2.prototype.maxPool3dBackprop = function(dy, x, y, convInfo) {
+ return notYetImplemented("maxPool3dBackprop");
+ };
+ KernelBackend2.prototype.reshape = function(x, shape) {
+ return notYetImplemented("reshape");
+ };
+ KernelBackend2.prototype.cast = function(x, dtype) {
+ return notYetImplemented("cast");
+ };
+ KernelBackend2.prototype.tile = function(x, reps) {
+ return notYetImplemented("tile");
+ };
+ KernelBackend2.prototype.pad = function(x, paddings, constantValue) {
+ return notYetImplemented("pad");
+ };
+ KernelBackend2.prototype.transpose = function(x, perm) {
+ return notYetImplemented("transpose");
+ };
+ KernelBackend2.prototype.gather = function(x, indices, axis) {
+ return notYetImplemented("gather");
+ };
+ KernelBackend2.prototype.gatherND = function(x, indices) {
+ return notYetImplemented("gatherND");
+ };
+ KernelBackend2.prototype.scatterND = function(indices, updates, shape) {
+ return notYetImplemented("scatterND");
+ };
+ KernelBackend2.prototype.batchToSpaceND = function(x, blockShape, crops) {
+ return notYetImplemented("batchToSpaceND");
+ };
+ KernelBackend2.prototype.spaceToBatchND = function(x, blockShape, paddings) {
+ return notYetImplemented("spaceToBatchND");
+ };
+ KernelBackend2.prototype.resizeBilinear = function(x, newHeight, newWidth, alignCorners) {
+ return notYetImplemented("resizeBilinear");
+ };
+ KernelBackend2.prototype.resizeBilinearBackprop = function(dy, x, alignCorners) {
+ return notYetImplemented("resizeBilinearBackprop");
+ };
+ KernelBackend2.prototype.resizeNearestNeighbor = function(x, newHEight, newWidth, alignCorners) {
+ return notYetImplemented("resizeNearestNeighbor");
+ };
+ KernelBackend2.prototype.resizeNearestNeighborBackprop = function(dy, x, alignCorners) {
+ return notYetImplemented("resizeNearestNeighborBackprop");
+ };
+ KernelBackend2.prototype.batchNorm = function(x, mean2, variance, offset, scale, varianceEpsilon) {
+ return notYetImplemented("batchNorm");
+ };
+ KernelBackend2.prototype.localResponseNormalization4D = function(x, radius, bias, alpha, beta) {
+ return notYetImplemented("localResponseNormalization4D");
+ };
+ KernelBackend2.prototype.LRNGrad = function(dy, inputImage, outputImage, radius, bias, alpha, beta) {
+ return notYetImplemented("LRNGrad");
+ };
+ KernelBackend2.prototype.multinomial = function(logits, normalized, numSamples, seed) {
+ return notYetImplemented("multinomial");
+ };
+ KernelBackend2.prototype.oneHot = function(indices, depth, onValue, offValue) {
+ return notYetImplemented("oneHot");
+ };
+ KernelBackend2.prototype.cumsum = function(x, axis, exclusive, reverse2) {
+ return notYetImplemented("cumsum");
+ };
+ KernelBackend2.prototype.nonMaxSuppression = function(boxes, scores, maxOutputSize, iouThreshold, scoreThreshold) {
+ return notYetImplemented("nonMaxSuppression");
+ };
+ KernelBackend2.prototype.fft = function(x) {
+ return notYetImplemented("fft");
+ };
+ KernelBackend2.prototype.ifft = function(x) {
+ return notYetImplemented("ifft");
+ };
+ KernelBackend2.prototype.complex = function(real2, imag2) {
+ return notYetImplemented("complex");
+ };
+ KernelBackend2.prototype.real = function(input) {
+ return notYetImplemented("real");
+ };
+ KernelBackend2.prototype.imag = function(input) {
+ return notYetImplemented("imag");
+ };
+ KernelBackend2.prototype.cropAndResize = function(image3, boxes, boxIndex, cropSize, method, extrapolationValue) {
+ return notYetImplemented("cropAndResize");
+ };
+ KernelBackend2.prototype.depthToSpace = function(x, blockSize, dataFormat) {
+ return notYetImplemented("depthToSpace");
+ };
+ KernelBackend2.prototype.split = function(value, sizeSplits, axis) {
+ return notYetImplemented("split");
+ };
+ KernelBackend2.prototype.sparseToDense = function(sparseIndices, sparseValues, outputShape, defaultValue) {
+ return notYetImplemented("sparseToDense");
+ };
+ KernelBackend2.prototype.diag = function(x) {
+ return notYetImplemented("diag");
+ };
+ KernelBackend2.prototype.fill = function(shape, value, dtype) {
+ return notYetImplemented("fill");
+ };
+ KernelBackend2.prototype.onesLike = function(x) {
+ return notYetImplemented("onesLike");
+ };
+ KernelBackend2.prototype.zerosLike = function(x) {
+ return notYetImplemented("zerosLike");
+ };
+ KernelBackend2.prototype.linspace = function(start, stop, num) {
+ return notYetImplemented("linspace");
+ };
+ KernelBackend2.prototype.dispose = function() {
+ return notYetImplemented("dispose");
+ };
+ return KernelBackend2;
+ }();
+ function notYetImplemented(kernelName) {
+ throw new Error("'" + kernelName + "' not yet implemented or not found in the registry. This kernel may not be supported by the tfjs backend you have chosen");
}
- function E(n, t) {
- if (!n)
- throw new Error(typeof t == "string" ? t : t());
+ /**
+ * @license
+ * Copyright 2020 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ function shuffle(array) {
+ var counter = array.length;
+ var temp = 0;
+ var index = 0;
+ while (counter > 0) {
+ index = Math.random() * counter | 0;
+ counter--;
+ temp = array[counter];
+ array[counter] = array[index];
+ array[index] = temp;
+ }
}
- function Pe(n, t, e) {
- e === void 0 && (e = ""), E(pn(n, t), function() {
- return e + (" Shapes " + n + " and " + t + " must match");
+ function clamp(min2, x, max2) {
+ return Math.max(min2, Math.min(x, max2));
+ }
+ function nearestLargerEven(val) {
+ return val % 2 === 0 ? val : val + 1;
+ }
+ function sum(arr) {
+ var sum2 = 0;
+ for (var i = 0; i < arr.length; i++) {
+ sum2 += arr[i];
+ }
+ return sum2;
+ }
+ function randUniform(a, b) {
+ var r = Math.random();
+ return b * r + (1 - r) * a;
+ }
+ function distSquared(a, b) {
+ var result = 0;
+ for (var i = 0; i < a.length; i++) {
+ var diff = Number(a[i]) - Number(b[i]);
+ result += diff * diff;
+ }
+ return result;
+ }
+ function assert(expr, msg) {
+ if (!expr) {
+ throw new Error(typeof msg === "string" ? msg : msg());
+ }
+ }
+ function assertShapesMatch(shapeA, shapeB, errorMessagePrefix) {
+ if (errorMessagePrefix === void 0) {
+ errorMessagePrefix = "";
+ }
+ assert(arraysEqual(shapeA, shapeB), function() {
+ return errorMessagePrefix + (" Shapes " + shapeA + " and " + shapeB + " must match");
});
}
- function Ur(n) {
- E(n != null, function() {
+ function assertNonNull(a) {
+ assert(a != null, function() {
return "The input to the tensor constructor must be a non-null value.";
});
}
- function Br(n, t, e) {
- if (t === void 0 && (t = []), e === void 0 && (e = false), t == null && (t = []), Array.isArray(n) || Ft(n) && !e)
- for (var r = 0; r < n.length; ++r)
- Br(n[r], t, e);
- else
- t.push(n);
- return t;
+ function flatten(arr, result, skipTypedArray) {
+ if (result === void 0) {
+ result = [];
+ }
+ if (skipTypedArray === void 0) {
+ skipTypedArray = false;
+ }
+ if (result == null) {
+ result = [];
+ }
+ if (Array.isArray(arr) || isTypedArray(arr) && !skipTypedArray) {
+ for (var i = 0; i < arr.length; ++i) {
+ flatten(arr[i], result, skipTypedArray);
+ }
+ } else {
+ result.push(arr);
+ }
+ return result;
}
- function pt(n) {
- if (n.length === 0)
+ function sizeFromShape(shape) {
+ if (shape.length === 0) {
return 1;
- for (var t = n[0], e = 1; e < n.length; e++)
- t *= n[e];
- return t;
+ }
+ var size = shape[0];
+ for (var i = 1; i < shape.length; i++) {
+ size *= shape[i];
+ }
+ return size;
}
- function wL(n) {
- return n.length === 0;
+ function isScalarShape(shape) {
+ return shape.length === 0;
}
- function pn(n, t) {
- if (n === t)
+ function arraysEqual(n1, n2) {
+ if (n1 === n2) {
return true;
- if (n == null || t == null)
+ }
+ if (n1 == null || n2 == null) {
return false;
- if (n.length !== t.length)
+ }
+ if (n1.length !== n2.length) {
return false;
- for (var e = 0; e < n.length; e++)
- if (n[e] !== t[e])
+ }
+ for (var i = 0; i < n1.length; i++) {
+ if (n1[i] !== n2[i]) {
return false;
+ }
+ }
return true;
}
- function ot(n) {
- return n % 1 === 0;
+ function isInt(a) {
+ return a % 1 === 0;
}
- function bL(n) {
- if (Math.tanh != null)
- return Math.tanh(n);
- if (n === Infinity)
+ function tanh(x) {
+ if (Math.tanh != null) {
+ return Math.tanh(x);
+ }
+ if (x === Infinity) {
return 1;
- if (n === -Infinity)
+ } else if (x === -Infinity) {
return -1;
- var t = Math.exp(2 * n);
- return (t - 1) / (t + 1);
+ } else {
+ var e2x = Math.exp(2 * x);
+ return (e2x - 1) / (e2x + 1);
+ }
}
- function xL(n) {
- var t = Math.ceil(Math.sqrt(n));
- return [t, Math.ceil(n / t)];
+ function sizeToSquarishShape(size) {
+ var width = Math.ceil(Math.sqrt(size));
+ return [width, Math.ceil(size / width)];
}
- function LL(n) {
- for (var t = new Uint32Array(n), e = 0; e < n; ++e)
- t[e] = e;
- return lf(t), t;
+ function createShuffledIndices(n) {
+ var shuffledIndices = new Uint32Array(n);
+ for (var i = 0; i < n; ++i) {
+ shuffledIndices[i] = i;
+ }
+ shuffle(shuffledIndices);
+ return shuffledIndices;
}
- function ya(n, t) {
- return t <= n.length ? n : n + " ".repeat(t - n.length);
+ function rightPad(a, size) {
+ if (size <= a.length) {
+ return a;
+ }
+ return a + " ".repeat(size - a.length);
}
- function SL(n, t, e) {
- return t === void 0 && (t = function(r) {
- return 0;
- }), new Promise(function(r, i) {
- var a = 0, s = function() {
- if (n()) {
- r();
- return;
- }
- a++;
- var o = t(a);
- if (e != null && a >= e) {
- i();
- return;
- }
- setTimeout(s, o);
+ function repeatedTry(checkFn, delayFn, maxCounter) {
+ if (delayFn === void 0) {
+ delayFn = function(counter) {
+ return 0;
};
- s();
- });
- }
- function uf(n, t) {
- for (var e = 1, r = -1, i = 0; i < n.length; ++i)
- if (n[i] >= 0)
- e *= n[i];
- else if (n[i] === -1) {
- if (r !== -1)
- throw Error("Shapes can only have 1 implicit size. " + ("Found -1 at dim " + r + " and dim " + i));
- r = i;
- } else if (n[i] < 0)
- throw Error("Shapes can not be < 0. Found " + n[i] + " at dim " + i);
- if (r === -1) {
- if (t > 0 && t !== e)
- throw Error("Size(" + t + ") must match the product of shape " + n);
- return n;
}
- if (e === 0)
- throw Error("Cannot infer the missing size in [" + n + "] when there are 0 elements");
- if (t % e !== 0)
- throw Error("The implicit shape can't be a fractional number. " + ("Got " + t + " / " + e));
- var a = n.slice();
- return a[r] = t / e, a;
- }
- function rt(n, t) {
- var e = t.length;
- return n = n == null ? t.map(function(r, i) {
- return i;
- }) : [].concat(n), E(n.every(function(r) {
- return r >= -e && r < e;
- }), function() {
- return "All values in axis param must be in range [-" + e + ", " + e + ") but " + ("got axis " + n);
- }), E(n.every(function(r) {
- return ot(r);
- }), function() {
- return "All values in axis param must be integers but " + ("got axis " + n);
- }), n.map(function(r) {
- return r < 0 ? e + r : r;
+ return new Promise(function(resolve, reject) {
+ var tryCount = 0;
+ var tryFn = function() {
+ if (checkFn()) {
+ resolve();
+ return;
+ }
+ tryCount++;
+ var nextBackoff = delayFn(tryCount);
+ if (maxCounter != null && tryCount >= maxCounter) {
+ reject();
+ return;
+ }
+ setTimeout(tryFn, nextBackoff);
+ };
+ tryFn();
});
}
- function hf(n, t) {
- for (var e = [], r = [], i = t != null && Array.isArray(t) && t.length === 0, a = t == null || i ? null : rt(t, n).sort(), s = 0, o = 0; o < n.length; ++o) {
- if (a != null) {
- if (a[s] === o && n[o] !== 1)
- throw new Error("Can't squeeze axis " + o + " since its dim '" + n[o] + "' is not 1");
- (a[s] == null || a[s] > o) && n[o] === 1 && (e.push(n[o]), r.push(o)), a[s] <= o && s++;
+ function inferFromImplicitShape(shape, size) {
+ var shapeProd = 1;
+ var implicitIdx = -1;
+ for (var i = 0; i < shape.length; ++i) {
+ if (shape[i] >= 0) {
+ shapeProd *= shape[i];
+ } else if (shape[i] === -1) {
+ if (implicitIdx !== -1) {
+ throw Error("Shapes can only have 1 implicit size. " + ("Found -1 at dim " + implicitIdx + " and dim " + i));
+ }
+ implicitIdx = i;
+ } else if (shape[i] < 0) {
+ throw Error("Shapes can not be < 0. Found " + shape[i] + " at dim " + i);
}
- n[o] !== 1 && (e.push(n[o]), r.push(o));
}
- return {newShape: e, keptDims: r};
- }
- function Ts(n, t) {
- var e = null;
- if (n == null || n === "float32")
- e = new Float32Array(t);
- else if (n === "int32")
- e = new Int32Array(t);
- else if (n === "bool")
- e = new Uint8Array(t);
- else
- throw new Error("Unknown data type " + n);
- return e;
- }
- function df(n, t) {
- var e = null;
- if (n == null || n === "float32")
- e = new Float32Array(t);
- else if (n === "int32")
- e = new Int32Array(t);
- else if (n === "bool")
- e = new Uint8Array(t);
- else if (n === "string")
- e = new Array(t);
- else
- throw new Error("Unknown data type " + n);
- return e;
- }
- function pf(n, t) {
- for (var e = 0; e < n.length; e++) {
- var r = n[e];
- if (isNaN(r) || !isFinite(r))
- throw Error("A tensor of type " + t + " being uploaded contains " + r + ".");
+ if (implicitIdx === -1) {
+ if (size > 0 && size !== shapeProd) {
+ throw Error("Size(" + size + ") must match the product of shape " + shape);
+ }
+ return shape;
}
+ if (shapeProd === 0) {
+ throw Error("Cannot infer the missing size in [" + shape + "] when there are 0 elements");
+ }
+ if (size % shapeProd !== 0) {
+ throw Error("The implicit shape can't be a fractional number. " + ("Got " + size + " / " + shapeProd));
+ }
+ var newShape = shape.slice();
+ newShape[implicitIdx] = size / shapeProd;
+ return newShape;
}
- function ff(n) {
- return n === "bool" || n === "complex64" || n === "float32" || n === "int32" || n === "string";
- }
- function mf(n, t) {
- return t === "complex64" || (t === "float32" && n !== "complex64" || t === "int32" && n !== "float32" && n !== "complex64") ? false : !(t === "bool" && n === "bool");
- }
- function Ft(n) {
- return n instanceof Float32Array || n instanceof Int32Array || n instanceof Uint8Array;
- }
- function gf(n) {
- if (n === "float32" || n === "int32")
- return 4;
- if (n === "complex64")
- return 8;
- if (n === "bool")
- return 1;
- throw new Error("Unknown dtype " + n);
- }
- function yf(n) {
- if (n == null)
- return 0;
- var t = 0;
- return n.forEach(function(e) {
- return t += e.length;
- }), t;
- }
- function or(n) {
- return typeof n == "string" || n instanceof String;
- }
- function vf(n) {
- return typeof n == "boolean";
- }
- function wf(n) {
- return typeof n == "number";
- }
- function Ns(n) {
- return Array.isArray(n) ? Ns(n[0]) : n instanceof Float32Array ? "float32" : n instanceof Int32Array || n instanceof Uint8Array ? "int32" : wf(n) ? "float32" : or(n) ? "string" : vf(n) ? "bool" : "float32";
- }
- function cr(n) {
- return !!(n && n.constructor && n.call && n.apply);
- }
- function _s(n, t) {
- for (var e = t; e < n; ++e)
- if (n % e === 0)
- return e;
- return n;
- }
- function bi(n) {
- var t = n.length;
- if (t < 2)
- return [];
- var e = new Array(t - 1);
- e[t - 2] = n[t - 1];
- for (var r = t - 3; r >= 0; --r)
- e[r] = e[r + 1] * n[r + 1];
- return e;
- }
- function bf(n, t, e) {
- var r = new Array();
- if (t.length === 1)
- for (var i = t[0], a = 0; a < i; a++)
- r[a] = e[n + a];
- else
- for (var i = t[0], s = t.slice(1), o = s.reduce(function(l, u) {
- return l * u;
- }), a = 0; a < i; a++)
- r[a] = bf(n + a * o, s, e);
- return r;
- }
- function xi(n, t) {
- if (n.length === 0)
- return t[0];
- var e = n.reduce(function(r, i) {
- return r * i;
+ function parseAxisParam(axis, shape) {
+ var rank = shape.length;
+ axis = axis == null ? shape.map(function(s, i) {
+ return i;
+ }) : [].concat(axis);
+ assert(axis.every(function(ax) {
+ return ax >= -rank && ax < rank;
+ }), function() {
+ return "All values in axis param must be in range [-" + rank + ", " + rank + ") but " + ("got axis " + axis);
});
- if (e === 0)
+ assert(axis.every(function(ax) {
+ return isInt(ax);
+ }), function() {
+ return "All values in axis param must be integers but " + ("got axis " + axis);
+ });
+ return axis.map(function(a) {
+ return a < 0 ? rank + a : a;
+ });
+ }
+ function squeezeShape(shape, axis) {
+ var newShape = [];
+ var keptDims = [];
+ var isEmptyArray = axis != null && Array.isArray(axis) && axis.length === 0;
+ var axes = axis == null || isEmptyArray ? null : parseAxisParam(axis, shape).sort();
+ var j = 0;
+ for (var i = 0; i < shape.length; ++i) {
+ if (axes != null) {
+ if (axes[j] === i && shape[i] !== 1) {
+ throw new Error("Can't squeeze axis " + i + " since its dim '" + shape[i] + "' is not 1");
+ }
+ if ((axes[j] == null || axes[j] > i) && shape[i] === 1) {
+ newShape.push(shape[i]);
+ keptDims.push(i);
+ }
+ if (axes[j] <= i) {
+ j++;
+ }
+ }
+ if (shape[i] !== 1) {
+ newShape.push(shape[i]);
+ keptDims.push(i);
+ }
+ }
+ return {newShape, keptDims};
+ }
+ function getTypedArrayFromDType(dtype, size) {
+ var values = null;
+ if (dtype == null || dtype === "float32") {
+ values = new Float32Array(size);
+ } else if (dtype === "int32") {
+ values = new Int32Array(size);
+ } else if (dtype === "bool") {
+ values = new Uint8Array(size);
+ } else {
+ throw new Error("Unknown data type " + dtype);
+ }
+ return values;
+ }
+ function getArrayFromDType(dtype, size) {
+ var values = null;
+ if (dtype == null || dtype === "float32") {
+ values = new Float32Array(size);
+ } else if (dtype === "int32") {
+ values = new Int32Array(size);
+ } else if (dtype === "bool") {
+ values = new Uint8Array(size);
+ } else if (dtype === "string") {
+ values = new Array(size);
+ } else {
+ throw new Error("Unknown data type " + dtype);
+ }
+ return values;
+ }
+ function checkConversionForErrors(vals, dtype) {
+ for (var i = 0; i < vals.length; i++) {
+ var num = vals[i];
+ if (isNaN(num) || !isFinite(num)) {
+ throw Error("A tensor of type " + dtype + " being uploaded contains " + num + ".");
+ }
+ }
+ }
+ function isValidDtype(dtype) {
+ return dtype === "bool" || dtype === "complex64" || dtype === "float32" || dtype === "int32" || dtype === "string";
+ }
+ function hasEncodingLoss(oldType, newType) {
+ if (newType === "complex64") {
+ return false;
+ }
+ if (newType === "float32" && oldType !== "complex64") {
+ return false;
+ }
+ if (newType === "int32" && oldType !== "float32" && oldType !== "complex64") {
+ return false;
+ }
+ if (newType === "bool" && oldType === "bool") {
+ return false;
+ }
+ return true;
+ }
+ function isTypedArray(a) {
+ return a instanceof Float32Array || a instanceof Int32Array || a instanceof Uint8Array;
+ }
+ function bytesPerElement(dtype) {
+ if (dtype === "float32" || dtype === "int32") {
+ return 4;
+ } else if (dtype === "complex64") {
+ return 8;
+ } else if (dtype === "bool") {
+ return 1;
+ } else {
+ throw new Error("Unknown dtype " + dtype);
+ }
+ }
+ function bytesFromStringArray(arr) {
+ if (arr == null) {
+ return 0;
+ }
+ var bytes = 0;
+ arr.forEach(function(x) {
+ return bytes += x.length;
+ });
+ return bytes;
+ }
+ function isString(value) {
+ return typeof value === "string" || value instanceof String;
+ }
+ function isBoolean(value) {
+ return typeof value === "boolean";
+ }
+ function isNumber(value) {
+ return typeof value === "number";
+ }
+ function inferDtype(values) {
+ if (Array.isArray(values)) {
+ return inferDtype(values[0]);
+ }
+ if (values instanceof Float32Array) {
+ return "float32";
+ } else if (values instanceof Int32Array || values instanceof Uint8Array) {
+ return "int32";
+ } else if (isNumber(values)) {
+ return "float32";
+ } else if (isString(values)) {
+ return "string";
+ } else if (isBoolean(values)) {
+ return "bool";
+ }
+ return "float32";
+ }
+ function isFunction(f) {
+ return !!(f && f.constructor && f.call && f.apply);
+ }
+ function nearestDivisor(size, start) {
+ for (var i = start; i < size; ++i) {
+ if (size % i === 0) {
+ return i;
+ }
+ }
+ return size;
+ }
+ function computeStrides(shape) {
+ var rank = shape.length;
+ if (rank < 2) {
return [];
- if (e !== t.length)
- throw new Error("[" + n + "] does not match the input size " + t.length + ".");
- return bf(0, n, t);
+ }
+ var strides = new Array(rank - 1);
+ strides[rank - 2] = shape[rank - 1];
+ for (var i = rank - 3; i >= 0; --i) {
+ strides[i] = strides[i + 1] * shape[i + 1];
+ }
+ return strides;
}
- function yc(n, t) {
- for (var e = Li(n, t), r = 0; r < e.length; r++)
- e[r] = 1;
- return e;
+ function createNestedArray(offset, shape, a) {
+ var ret = new Array();
+ if (shape.length === 1) {
+ var d = shape[0];
+ for (var i = 0; i < d; i++) {
+ ret[i] = a[offset + i];
+ }
+ } else {
+ var d = shape[0];
+ var rest = shape.slice(1);
+ var len = rest.reduce(function(acc, c) {
+ return acc * c;
+ });
+ for (var i = 0; i < d; i++) {
+ ret[i] = createNestedArray(offset + i * len, rest, a);
+ }
+ }
+ return ret;
}
- function Li(n, t) {
- if (t == null || t === "float32" || t === "complex64")
- return new Float32Array(n);
- if (t === "int32")
- return new Int32Array(n);
- if (t === "bool")
- return new Uint8Array(n);
- throw new Error("Unknown data type " + t);
+ function toNestedArray(shape, a) {
+ if (shape.length === 0) {
+ return a[0];
+ }
+ var size = shape.reduce(function(acc, c) {
+ return acc * c;
+ });
+ if (size === 0) {
+ return [];
+ }
+ if (size !== a.length) {
+ throw new Error("[" + shape + "] does not match the input size " + a.length + ".");
+ }
+ return createNestedArray(0, shape, a);
}
- function IL(n, t) {
- var e = n.reduce(function(r, i) {
- return r * i;
+ function makeOnesTypedArray(size, dtype) {
+ var array = makeZerosTypedArray(size, dtype);
+ for (var i = 0; i < array.length; i++) {
+ array[i] = 1;
+ }
+ return array;
+ }
+ function makeZerosTypedArray(size, dtype) {
+ if (dtype == null || dtype === "float32" || dtype === "complex64") {
+ return new Float32Array(size);
+ } else if (dtype === "int32") {
+ return new Int32Array(size);
+ } else if (dtype === "bool") {
+ return new Uint8Array(size);
+ } else {
+ throw new Error("Unknown data type " + dtype);
+ }
+ }
+ function makeZerosNestedTypedArray(shape, dtype) {
+ var size = shape.reduce(function(prev, curr) {
+ return prev * curr;
}, 1);
- if (t == null || t === "float32")
- return xi(n, new Float32Array(e));
- if (t === "int32")
- return xi(n, new Int32Array(e));
- if (t === "bool")
- return xi(n, new Uint8Array(e));
- throw new Error("Unknown data type " + t);
+ if (dtype == null || dtype === "float32") {
+ return toNestedArray(shape, new Float32Array(size));
+ } else if (dtype === "int32") {
+ return toNestedArray(shape, new Int32Array(size));
+ } else if (dtype === "bool") {
+ return toNestedArray(shape, new Uint8Array(size));
+ } else {
+ throw new Error("Unknown data type " + dtype);
+ }
}
- function vc(n) {
- n.forEach(function(t) {
- E(Number.isInteger(t) && t >= 0, function() {
- return "Tensor must have a shape comprised of positive integers but got " + ("shape [" + n + "].");
+ function assertNonNegativeIntegerDimensions(shape) {
+ shape.forEach(function(dimSize) {
+ assert(Number.isInteger(dimSize) && dimSize >= 0, function() {
+ return "Tensor must have a shape comprised of positive integers but got " + ("shape [" + shape + "].");
});
});
}
- function AL(n, t, e) {
- if (t === 0)
+ function locToIndex(locs, rank, strides) {
+ if (rank === 0) {
return 0;
- if (t === 1)
- return n[0];
- for (var r = n[n.length - 1], i = 0; i < n.length - 1; ++i)
- r += e[i] * n[i];
- return r;
- }
- function TL(n, t, e) {
- if (t === 0)
- return [];
- if (t === 1)
- return [n];
- for (var r = new Array(t), i = 0; i < r.length - 1; ++i)
- r[i] = Math.floor(n / e[i]), n -= r[i] * e[i];
- return r[r.length - 1] = n, r;
- }
- function wc(n) {
- return n && n.then && typeof n.then == "function";
- }
- var xf = "tfjsflags", Lf = function() {
- function n(t) {
- this.global = t, this.flags = {}, this.flagRegistry = {}, this.urlFlags = {}, this.populateURLFlags();
+ } else if (rank === 1) {
+ return locs[0];
}
- return n.prototype.setPlatform = function(t, e) {
- this.platform != null && console.warn("Platform " + this.platformName + " has already been set. " + ("Overwriting the platform with " + e + ".")), this.platformName = t, this.platform = e;
- }, n.prototype.registerFlag = function(t, e, r) {
- if (this.flagRegistry[t] = {evaluationFn: e, setHook: r}, this.urlFlags[t] != null) {
- var i = this.urlFlags[t];
- console.warn("Setting feature override from URL " + t + ": " + i + "."), this.set(t, i);
+ var index = locs[locs.length - 1];
+ for (var i = 0; i < locs.length - 1; ++i) {
+ index += strides[i] * locs[i];
+ }
+ return index;
+ }
+ function indexToLoc(index, rank, strides) {
+ if (rank === 0) {
+ return [];
+ } else if (rank === 1) {
+ return [index];
+ }
+ var locs = new Array(rank);
+ for (var i = 0; i < locs.length - 1; ++i) {
+ locs[i] = Math.floor(index / strides[i]);
+ index -= locs[i] * strides[i];
+ }
+ locs[locs.length - 1] = index;
+ return locs;
+ }
+ function isPromise(object) {
+ return object && object.then && typeof object.then === "function";
+ }
+ /**
+ * @license
+ * Copyright 2017 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ var TENSORFLOWJS_FLAGS_PREFIX = "tfjsflags";
+ var Environment = function() {
+ function Environment2(global2) {
+ this.global = global2;
+ this.flags = {};
+ this.flagRegistry = {};
+ this.urlFlags = {};
+ this.populateURLFlags();
+ }
+ Environment2.prototype.setPlatform = function(platformName, platform) {
+ if (this.platform != null) {
+ console.warn("Platform " + this.platformName + " has already been set. " + ("Overwriting the platform with " + platform + "."));
}
- }, n.prototype.getAsync = function(t) {
- return pe(this, void 0, void 0, function() {
- var e, r;
- return fe(this, function(i) {
- switch (i.label) {
+ this.platformName = platformName;
+ this.platform = platform;
+ };
+ Environment2.prototype.registerFlag = function(flagName, evaluationFn, setHook) {
+ this.flagRegistry[flagName] = {evaluationFn, setHook};
+ if (this.urlFlags[flagName] != null) {
+ var flagValue = this.urlFlags[flagName];
+ console.warn("Setting feature override from URL " + flagName + ": " + flagValue + ".");
+ this.set(flagName, flagValue);
+ }
+ };
+ Environment2.prototype.getAsync = function(flagName) {
+ return __awaiter(this, void 0, void 0, function() {
+ var _a, _b;
+ return __generator(this, function(_c) {
+ switch (_c.label) {
case 0:
- return t in this.flags ? [2, this.flags[t]] : (e = this.flags, r = t, [4, this.evaluateFlag(t)]);
+ if (flagName in this.flags) {
+ return [2, this.flags[flagName]];
+ }
+ _a = this.flags;
+ _b = flagName;
+ return [4, this.evaluateFlag(flagName)];
case 1:
- return e[r] = i.sent(), [2, this.flags[t]];
+ _a[_b] = _c.sent();
+ return [2, this.flags[flagName]];
}
});
});
- }, n.prototype.get = function(t) {
- if (t in this.flags)
- return this.flags[t];
- var e = this.evaluateFlag(t);
- if (wc(e))
- throw new Error("Flag " + t + " cannot be synchronously evaluated. Please use getAsync() instead.");
- return this.flags[t] = e, this.flags[t];
- }, n.prototype.getNumber = function(t) {
- return this.get(t);
- }, n.prototype.getBool = function(t) {
- return this.get(t);
- }, n.prototype.getFlags = function() {
+ };
+ Environment2.prototype.get = function(flagName) {
+ if (flagName in this.flags) {
+ return this.flags[flagName];
+ }
+ var flagValue = this.evaluateFlag(flagName);
+ if (isPromise(flagValue)) {
+ throw new Error("Flag " + flagName + " cannot be synchronously evaluated. Please use getAsync() instead.");
+ }
+ this.flags[flagName] = flagValue;
+ return this.flags[flagName];
+ };
+ Environment2.prototype.getNumber = function(flagName) {
+ return this.get(flagName);
+ };
+ Environment2.prototype.getBool = function(flagName) {
+ return this.get(flagName);
+ };
+ Environment2.prototype.getFlags = function() {
return this.flags;
- }, Object.defineProperty(n.prototype, "features", {get: function() {
- return this.flags;
- }, enumerable: true, configurable: true}), n.prototype.set = function(t, e) {
- if (this.flagRegistry[t] == null)
- throw new Error("Cannot set flag " + t + " as it has not been registered.");
- this.flags[t] = e, this.flagRegistry[t].setHook != null && this.flagRegistry[t].setHook(e);
- }, n.prototype.evaluateFlag = function(t) {
- if (this.flagRegistry[t] == null)
- throw new Error("Cannot evaluate flag '" + t + "': no evaluation function found.");
- return this.flagRegistry[t].evaluationFn();
- }, n.prototype.setFlags = function(t) {
- this.flags = Object.assign({}, t);
- }, n.prototype.reset = function() {
- this.flags = {}, this.urlFlags = {}, this.populateURLFlags();
- }, n.prototype.populateURLFlags = function() {
- var t = this;
- if (typeof this.global == "undefined" || typeof this.global.location == "undefined" || typeof this.global.location.search == "undefined")
+ };
+ Object.defineProperty(Environment2.prototype, "features", {
+ get: function() {
+ return this.flags;
+ },
+ enumerable: true,
+ configurable: true
+ });
+ Environment2.prototype.set = function(flagName, value) {
+ if (this.flagRegistry[flagName] == null) {
+ throw new Error("Cannot set flag " + flagName + " as it has not been registered.");
+ }
+ this.flags[flagName] = value;
+ if (this.flagRegistry[flagName].setHook != null) {
+ this.flagRegistry[flagName].setHook(value);
+ }
+ };
+ Environment2.prototype.evaluateFlag = function(flagName) {
+ if (this.flagRegistry[flagName] == null) {
+ throw new Error("Cannot evaluate flag '" + flagName + "': no evaluation function found.");
+ }
+ return this.flagRegistry[flagName].evaluationFn();
+ };
+ Environment2.prototype.setFlags = function(flags) {
+ this.flags = Object.assign({}, flags);
+ };
+ Environment2.prototype.reset = function() {
+ this.flags = {};
+ this.urlFlags = {};
+ this.populateURLFlags();
+ };
+ Environment2.prototype.populateURLFlags = function() {
+ var _this = this;
+ if (typeof this.global === "undefined" || typeof this.global.location === "undefined" || typeof this.global.location.search === "undefined") {
return;
- var e = NL(this.global.location.search);
- if (xf in e) {
- var r = e[xf].split(",");
- r.forEach(function(i) {
- var a = i.split(":"), s = a[0], o = a[1];
- t.urlFlags[s] = _L(s, o);
+ }
+ var urlParams = getQueryParams(this.global.location.search);
+ if (TENSORFLOWJS_FLAGS_PREFIX in urlParams) {
+ var keyValues = urlParams[TENSORFLOWJS_FLAGS_PREFIX].split(",");
+ keyValues.forEach(function(keyValue) {
+ var _a = keyValue.split(":"), key = _a[0], value = _a[1];
+ _this.urlFlags[key] = parseValue(key, value);
});
}
- }, n;
+ };
+ return Environment2;
}();
- function NL(n) {
- var t = {};
- return n.replace(/[?&]([^=?&]+)(?:=([^&]*))?/g, function(e) {
- for (var r = [], i = 1; i < arguments.length; i++)
- r[i - 1] = arguments[i];
- return CL(t, r[0], r[1]), r.join("=");
- }), t;
+ function getQueryParams(queryString) {
+ var params = {};
+ queryString.replace(/[?&]([^=?&]+)(?:=([^&]*))?/g, function(s) {
+ var t = [];
+ for (var _i2 = 1; _i2 < arguments.length; _i2++) {
+ t[_i2 - 1] = arguments[_i2];
+ }
+ decodeParam(params, t[0], t[1]);
+ return t.join("=");
+ });
+ return params;
}
- function CL(n, t, e) {
- n[decodeURIComponent(t)] = decodeURIComponent(e || "");
+ function decodeParam(params, name, value) {
+ params[decodeURIComponent(name)] = decodeURIComponent(value || "");
}
- function _L(n, t) {
- if (t = t.toLowerCase(), t === "true" || t === "false")
- return t === "true";
- if ("" + +t === t)
- return +t;
- throw new Error("Could not parse value flag value " + t + " for flag " + n + ".");
- }
- function Ge() {
- return A.ENV;
- }
- A.ENV = null;
- function RL(n) {
- A.ENV = n;
- }
- var bc;
- function Sf() {
- if (bc == null) {
- var n = void 0;
- if (typeof window != "undefined")
- n = window;
- else if (typeof global != "undefined")
- n = global;
- else if (typeof process != "undefined")
- n = process;
- else if (typeof self != "undefined")
- n = self;
- else
- throw new Error("Could not find a global object");
- bc = n;
+ function parseValue(flagName, value) {
+ value = value.toLowerCase();
+ if (value === "true" || value === "false") {
+ return value === "true";
+ } else if ("" + +value === value) {
+ return +value;
}
- return bc;
+ throw new Error("Could not parse value flag value " + value + " for flag " + flagName + ".");
}
- function OL() {
- var n = Sf();
- return n._tfGlobals == null && (n._tfGlobals = new Map()), n._tfGlobals;
+ function env() {
+ return exports.ENV;
}
- function If(n, t) {
- var e = OL();
- if (e.has(n))
- return e.get(n);
- var r = t();
- return e.set(n, r), e.get(n);
+ exports.ENV = null;
+ function setEnvironmentGlobal(environment) {
+ exports.ENV = environment;
}
- var xc = "Abs", Lc = "Acos", Sc = "Acosh", Cs = "Add", Ic = "AddN", Af = "All", Tf = "Any", Ac = "ArgMax", Tc = "ArgMin", Nc = "Asin", _c = "Asinh", Cc = "Atan", Rc = "Atanh", Oc = "Atan2", Ec = "AvgPool", Nf = "AvgPoolBackprop", Dc = "AvgPool3D", _f = "AvgPool3DBackprop", kc = "BatchMatMul", Fc = "BatchToSpaceND", Wc = "BroadcastTo", Rs = "Cast", Uc = "Ceil", Bc = "ClipByValue", Cf = "Complex", zc = "Concat", Pc = "Conv2D", Rf = "Conv2DBackpropFilter", Mc = "Conv2DBackpropInput", Hc = "Conv3D", Of = "Conv3DBackpropFilterV2", Ef = "Conv3DBackpropInputV2", Vc = "Cos", Gc = "Cosh", qc = "Cumsum", Df = "CropAndResize", kf = "DepthToSpace", Yc = "DepthwiseConv2dNative", Ff = "DepthwiseConv2dNativeBackpropFilter", Wf = "DepthwiseConv2dNativeBackpropInput", Uf = "Diag", Kc = "Dilation2D", Bf = "Dilation2DBackpropInput", zf = "Dilation2DBackpropFilter", jc = "Div", $c = "Elu", Pf = "EluGrad", Xc = "Erf", Mf = "Equal", Jc = "Exp", Zc = "Expm1", Hf = "FFT", Vf = "Fill", Gf = "FlipLeftRight", Qc = "Floor", el = "FloorDiv", tl = "FusedBatchNorm", nl = "GatherV2", qf = "GatherNd", Yf = "Greater", rl = "GreaterEqual", il = "Identity", Kf = "IFFT", jf = "Imag", al = "IsFinite", sl = "IsInf", ol = "IsNan", $f = "Less", Xf = "LessEqual", Jf = "LinSpace", cl = "Log", ll = "Log1p", Zf = "LogicalAnd", Qf = "LogicalNot", em = "LogicalOr", ul = "LogSoftmax", hl = "LRN", tm = "LRNBackprop", dl = "Max", pl = "Maximum", fl = "MaxPool", nm = "MaxPoolBackprop", ml = "MaxPool3D", rm = "MaxPool3DBackprop", im = "MaxPoolWithArgmax", am = "Mean", gl = "Min", yl = "Minimum", vl = "MirrorPad", wl = "Mod", bl = "Multiply", xl = "Negate", sm = "NotEqual", om = "NonMaxSuppressionV3", cm = "NonMaxSuppressionV4", lm = "NonMaxSuppressionV5", Ll = "OnesLike", Sl = "OneHot", Il = "PadV2", EL = "Pool", Al = "Pow", Tl = "Prelu", um = "Prod", hm = "Range", dm = "Real", Nl = "Reciprocal", _l = "Relu", Cl = "Reshape", Rl = "ResizeNearestNeighbor", pm = "ResizeNearestNeighborGrad", Ol = "ResizeBilinear", fm = "ResizeBilinearGrad", El = "Relu6", Dl = "Reverse", kl = "Round", Fl = "Rsqrt", mm = "ScatterNd", Wl = "SelectV2", Ul = "Selu", Bl = "Slice", zl = "Sin", Pl = "Sinh", Ml = "Sign", Hl = "Sigmoid", Vl = "Softplus", Gl = "Sqrt", ql = "Sum", Yl = "SpaceToBatchND", Kl = "SplitV", jl = "Softmax", $l = "SquaredDifference", gm = "Square", Xl = "Sub", ym = "SparseToDense", vm = "StridedSlice", Jl = "Tan", Zl = "Tanh", Ql = "Tile", wm = "TopK", eu = "Transpose", bm = "Unique", tu = "Unpack", nu = "UnsortedSegmentSum", ru = "ZerosLike", iu = "Step", au = "FromPixels", xm = "RotateWithOffset", su = "_FusedMatMul", ou = "FusedConv2D", cu = "FusedDepthwiseConv2D";
- var Si = If("kernelRegistry", function() {
- return new Map();
- }), va = If("gradRegistry", function() {
+ /**
+ * @license
+ * Copyright 2020 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ var globalNameSpace;
+ function getGlobalNamespace() {
+ if (globalNameSpace == null) {
+ var ns = void 0;
+ if (typeof window !== "undefined") {
+ ns = window;
+ } else if (typeof global !== "undefined") {
+ ns = global;
+ } else if (typeof process !== "undefined") {
+ ns = process;
+ } else if (typeof self !== "undefined") {
+ ns = self;
+ } else {
+ throw new Error("Could not find a global object");
+ }
+ globalNameSpace = ns;
+ }
+ return globalNameSpace;
+ }
+ function getGlobalMap() {
+ var ns = getGlobalNamespace();
+ if (ns._tfGlobals == null) {
+ ns._tfGlobals = new Map();
+ }
+ return ns._tfGlobals;
+ }
+ function getGlobal(key, init) {
+ var globalMap = getGlobalMap();
+ if (globalMap.has(key)) {
+ return globalMap.get(key);
+ } else {
+ var singleton = init();
+ globalMap.set(key, singleton);
+ return globalMap.get(key);
+ }
+ }
+ var Abs = "Abs";
+ var Acos = "Acos";
+ var Acosh = "Acosh";
+ var Add = "Add";
+ var AddN = "AddN";
+ var All = "All";
+ var Any = "Any";
+ var ArgMax = "ArgMax";
+ var ArgMin = "ArgMin";
+ var Asin = "Asin";
+ var Asinh = "Asinh";
+ var Atan = "Atan";
+ var Atanh = "Atanh";
+ var Atan2 = "Atan2";
+ var AvgPool = "AvgPool";
+ var AvgPoolBackprop = "AvgPoolBackprop";
+ var AvgPool3D = "AvgPool3D";
+ var AvgPool3DBackprop = "AvgPool3DBackprop";
+ var BatchMatMul = "BatchMatMul";
+ var BatchToSpaceND = "BatchToSpaceND";
+ var BroadcastTo = "BroadcastTo";
+ var Cast = "Cast";
+ var Ceil = "Ceil";
+ var ClipByValue = "ClipByValue";
+ var Complex = "Complex";
+ var Concat = "Concat";
+ var Conv2D = "Conv2D";
+ var Conv2DBackpropFilter = "Conv2DBackpropFilter";
+ var Conv2DBackpropInput = "Conv2DBackpropInput";
+ var Conv3D = "Conv3D";
+ var Conv3DBackpropFilterV2 = "Conv3DBackpropFilterV2";
+ var Conv3DBackpropInputV2 = "Conv3DBackpropInputV2";
+ var Cos = "Cos";
+ var Cosh = "Cosh";
+ var Cumsum = "Cumsum";
+ var CropAndResize = "CropAndResize";
+ var DepthToSpace = "DepthToSpace";
+ var DepthwiseConv2dNative = "DepthwiseConv2dNative";
+ var DepthwiseConv2dNativeBackpropFilter = "DepthwiseConv2dNativeBackpropFilter";
+ var DepthwiseConv2dNativeBackpropInput = "DepthwiseConv2dNativeBackpropInput";
+ var Diag = "Diag";
+ var Dilation2D = "Dilation2D";
+ var Dilation2DBackpropInput = "Dilation2DBackpropInput";
+ var Dilation2DBackpropFilter = "Dilation2DBackpropFilter";
+ var Div = "Div";
+ var Elu = "Elu";
+ var EluGrad = "EluGrad";
+ var Erf = "Erf";
+ var Equal = "Equal";
+ var Exp = "Exp";
+ var Expm1 = "Expm1";
+ var FFT = "FFT";
+ var Fill = "Fill";
+ var FlipLeftRight = "FlipLeftRight";
+ var Floor = "Floor";
+ var FloorDiv = "FloorDiv";
+ var FusedBatchNorm = "FusedBatchNorm";
+ var GatherV2 = "GatherV2";
+ var GatherNd = "GatherNd";
+ var Greater = "Greater";
+ var GreaterEqual = "GreaterEqual";
+ var Identity = "Identity";
+ var IFFT = "IFFT";
+ var Imag = "Imag";
+ var IsFinite = "IsFinite";
+ var IsInf = "IsInf";
+ var IsNan = "IsNan";
+ var Less = "Less";
+ var LessEqual = "LessEqual";
+ var LinSpace = "LinSpace";
+ var Log = "Log";
+ var Log1p = "Log1p";
+ var LogicalAnd = "LogicalAnd";
+ var LogicalNot = "LogicalNot";
+ var LogicalOr = "LogicalOr";
+ var LogSoftmax = "LogSoftmax";
+ var LRN = "LRN";
+ var LRNBackprop = "LRNBackprop";
+ var Max = "Max";
+ var Maximum = "Maximum";
+ var MaxPool = "MaxPool";
+ var MaxPoolBackprop = "MaxPoolBackprop";
+ var MaxPool3D = "MaxPool3D";
+ var MaxPool3DBackprop = "MaxPool3DBackprop";
+ var MaxPoolWithArgmax = "MaxPoolWithArgmax";
+ var Mean = "Mean";
+ var Min = "Min";
+ var Minimum = "Minimum";
+ var MirrorPad = "MirrorPad";
+ var Mod = "Mod";
+ var Multiply = "Multiply";
+ var Negate = "Negate";
+ var NotEqual = "NotEqual";
+ var NonMaxSuppressionV3 = "NonMaxSuppressionV3";
+ var NonMaxSuppressionV4 = "NonMaxSuppressionV4";
+ var NonMaxSuppressionV5 = "NonMaxSuppressionV5";
+ var OnesLike = "OnesLike";
+ var OneHot = "OneHot";
+ var PadV2 = "PadV2";
+ var Pool = "Pool";
+ var Pow = "Pow";
+ var Prelu = "Prelu";
+ var Prod = "Prod";
+ var Range = "Range";
+ var Real = "Real";
+ var Reciprocal = "Reciprocal";
+ var Relu = "Relu";
+ var Reshape = "Reshape";
+ var ResizeNearestNeighbor = "ResizeNearestNeighbor";
+ var ResizeNearestNeighborGrad = "ResizeNearestNeighborGrad";
+ var ResizeBilinear = "ResizeBilinear";
+ var ResizeBilinearGrad = "ResizeBilinearGrad";
+ var Relu6 = "Relu6";
+ var Reverse = "Reverse";
+ var Round = "Round";
+ var Rsqrt = "Rsqrt";
+ var ScatterNd = "ScatterNd";
+ var SelectV2 = "SelectV2";
+ var Selu = "Selu";
+ var Slice = "Slice";
+ var Sin = "Sin";
+ var Sinh = "Sinh";
+ var Sign = "Sign";
+ var Sigmoid = "Sigmoid";
+ var Softplus = "Softplus";
+ var Sqrt = "Sqrt";
+ var Sum = "Sum";
+ var SpaceToBatchND = "SpaceToBatchND";
+ var SplitV = "SplitV";
+ var Softmax = "Softmax";
+ var SquaredDifference = "SquaredDifference";
+ var Square = "Square";
+ var Sub = "Sub";
+ var SparseToDense = "SparseToDense";
+ var StridedSlice = "StridedSlice";
+ var Tan = "Tan";
+ var Tanh = "Tanh";
+ var Tile = "Tile";
+ var TopK = "TopK";
+ var Transpose = "Transpose";
+ var Unique = "Unique";
+ var Unpack = "Unpack";
+ var UnsortedSegmentSum = "UnsortedSegmentSum";
+ var ZerosLike = "ZerosLike";
+ var Step = "Step";
+ var FromPixels = "FromPixels";
+ var RotateWithOffset = "RotateWithOffset";
+ var _FusedMatMul = "_FusedMatMul";
+ var FusedConv2D = "FusedConv2D";
+ var FusedDepthwiseConv2D = "FusedDepthwiseConv2D";
+ /**
+ * @license
+ * Copyright 2019 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ var kernelRegistry = getGlobal("kernelRegistry", function() {
return new Map();
});
- function uu(n, t) {
- var e = lu(n, t);
- return Si.get(e);
+ var gradRegistry = getGlobal("gradRegistry", function() {
+ return new Map();
+ });
+ function getKernel(kernelName, backendName) {
+ var key = makeKey(kernelName, backendName);
+ return kernelRegistry.get(key);
}
- function hu(n) {
- return va.get(n);
+ function getGradient(kernelName) {
+ return gradRegistry.get(kernelName);
}
- function Os(n) {
- for (var t = Si.entries(), e = []; ; ) {
- var r = t.next(), i = r.done, a = r.value;
- if (i)
+ function getKernelsForBackend(backendName) {
+ var it = kernelRegistry.entries();
+ var result = [];
+ while (true) {
+ var _a = it.next(), done = _a.done, value = _a.value;
+ if (done) {
break;
- var s = a[0], o = a[1], c = s.split("_")[0];
- c === n && e.push(o);
+ }
+ var key = value[0], config = value[1];
+ var backend2 = key.split("_")[0];
+ if (backend2 === backendName) {
+ result.push(config);
+ }
}
- return e;
+ return result;
}
- function Lm(n) {
- var t = n.kernelName, e = n.backendName, r = lu(t, e);
- Si.has(r) && console.warn("The kernel '" + t + "' for backend " + ("'" + e + "' is already registered")), Si.set(r, n);
+ function registerKernel(config) {
+ var kernelName = config.kernelName, backendName = config.backendName;
+ var key = makeKey(kernelName, backendName);
+ if (kernelRegistry.has(key)) {
+ console.warn("The kernel '" + kernelName + "' for backend " + ("'" + backendName + "' is already registered"));
+ }
+ kernelRegistry.set(key, config);
}
- function Sm(n) {
- var t = n.kernelName;
- va.has(t) && (Ge().getBool("DEBUG") && console.warn("Overriding the gradient for '" + t + "'")), va.set(t, n);
+ function registerGradient(config) {
+ var kernelName = config.kernelName;
+ if (gradRegistry.has(kernelName)) {
+ if (env().getBool("DEBUG")) {
+ console.warn("Overriding the gradient for '" + kernelName + "'");
+ }
+ }
+ gradRegistry.set(kernelName, config);
}
- function DL(n, t) {
- var e = lu(n, t);
- if (!Si.has(e))
- throw new Error("The kernel '" + n + "' for backend " + ("'" + t + "' is not registered"));
- Si.delete(e);
+ function unregisterKernel(kernelName, backendName) {
+ var key = makeKey(kernelName, backendName);
+ if (!kernelRegistry.has(key)) {
+ throw new Error("The kernel '" + kernelName + "' for backend " + ("'" + backendName + "' is not registered"));
+ }
+ kernelRegistry.delete(key);
}
- function kL(n) {
- if (!va.has(n))
- throw new Error("The gradient '" + n + "' for backend is not registered");
- va.delete(n);
+ function unregisterGradient(kernelName) {
+ if (!gradRegistry.has(kernelName)) {
+ throw new Error("The gradient '" + kernelName + "' for backend is not registered");
+ }
+ gradRegistry.delete(kernelName);
}
- function FL(n, t) {
- var e = Os(n);
- e.forEach(function(r) {
- var i = Object.assign({}, r, {backendName: t});
- Lm(i);
+ function copyRegisteredKernels(registeredBackendName, newBackendName) {
+ var kernels = getKernelsForBackend(registeredBackendName);
+ kernels.forEach(function(kernelConfig) {
+ var newKernelConfig = Object.assign({}, kernelConfig, {backendName: newBackendName});
+ registerKernel(newKernelConfig);
});
}
- function lu(n, t) {
- return t + "_" + n;
+ function makeKey(kernelName, backendName) {
+ return backendName + "_" + kernelName;
}
- function WL(n, t) {
- return t === "string" ? du(n) : Es([n], t);
- }
- function UL(n, t) {
- return n instanceof Float32Array && t === "float32" || n instanceof Int32Array && t === "int32" || n instanceof Uint8Array && t === "bool";
- }
- function Es(n, t) {
- if (t === "string")
- throw new Error("Cannot convert a string[] to a TypedArray");
- if (Array.isArray(n) && (n = Br(n)), Ge().getBool("DEBUG") && pf(n, t), UL(n, t))
- return n;
- if (t == null || t === "float32" || t === "complex64")
- return new Float32Array(n);
- if (t === "int32")
- return new Int32Array(n);
- if (t === "bool") {
- for (var e = new Uint8Array(n.length), r = 0; r < e.length; ++r)
- Math.round(n[r]) !== 0 && (e[r] = 1);
- return e;
- } else
- throw new Error("Unknown data type " + t);
- }
- function pu() {
- return Ge().platform.now();
- }
- function BL(n, t) {
- return Ge().platform.fetch(n, t);
- }
- function du(n, t) {
- return t === void 0 && (t = "utf-8"), t = t || "utf-8", Ge().platform.encode(n, t);
- }
- function fu(n, t) {
- return t === void 0 && (t = "utf-8"), t = t || "utf-8", Ge().platform.decode(n, t);
- }
- var zL = {__proto__: null, createScalarValue: WL, toTypedArray: Es, now: pu, fetch: BL, encodeString: du, decodeString: fu, shuffle: lf, clamp: ga, nearestLargerEven: mL, sum: gL, randUniform: yL, distSquared: vL, assert: E, assertShapesMatch: Pe, assertNonNull: Ur, flatten: Br, sizeFromShape: pt, isScalarShape: wL, arraysEqual: pn, isInt: ot, tanh: bL, sizeToSquarishShape: xL, createShuffledIndices: LL, rightPad: ya, repeatedTry: SL, inferFromImplicitShape: uf, parseAxisParam: rt, squeezeShape: hf, getTypedArrayFromDType: Ts, getArrayFromDType: df, checkConversionForErrors: pf, isValidDtype: ff, hasEncodingLoss: mf, isTypedArray: Ft, bytesPerElement: gf, bytesFromStringArray: yf, isString: or, isBoolean: vf, isNumber: wf, inferDtype: Ns, isFunction: cr, nearestDivisor: _s, computeStrides: bi, toNestedArray: xi, makeOnesTypedArray: yc, makeZerosTypedArray: Li, makeZerosNestedTypedArray: IL, assertNonNegativeIntegerDimensions: vc, locToIndex: AL, indexToLoc: TL, isPromise: wc};
- var HL = function() {
- function n(t, e) {
- this.backendTimer = t, this.logger = e, e == null && (this.logger = new ML());
+ /**
+ * @license
+ * Copyright 2017 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ function createScalarValue(value, dtype) {
+ if (dtype === "string") {
+ return encodeString(value);
}
- return n.prototype.profileKernel = function(t, e, r) {
- for (var i, a = function() {
- i = r();
- }, s = this.backendTimer.time(a), o = function(u) {
- var h = i[u];
- h.data().then(function(d) {
- PL(d, h.dtype, t);
+ return toTypedArray([value], dtype);
+ }
+ function noConversionNeeded(a, dtype) {
+ return a instanceof Float32Array && dtype === "float32" || a instanceof Int32Array && dtype === "int32" || a instanceof Uint8Array && dtype === "bool";
+ }
+ function toTypedArray(a, dtype) {
+ if (dtype === "string") {
+ throw new Error("Cannot convert a string[] to a TypedArray");
+ }
+ if (Array.isArray(a)) {
+ a = flatten(a);
+ }
+ if (env().getBool("DEBUG")) {
+ checkConversionForErrors(a, dtype);
+ }
+ if (noConversionNeeded(a, dtype)) {
+ return a;
+ }
+ if (dtype == null || dtype === "float32" || dtype === "complex64") {
+ return new Float32Array(a);
+ } else if (dtype === "int32") {
+ return new Int32Array(a);
+ } else if (dtype === "bool") {
+ var bool = new Uint8Array(a.length);
+ for (var i = 0; i < bool.length; ++i) {
+ if (Math.round(a[i]) !== 0) {
+ bool[i] = 1;
+ }
+ }
+ return bool;
+ } else {
+ throw new Error("Unknown data type " + dtype);
+ }
+ }
+ function now2() {
+ return env().platform.now();
+ }
+ function fetch$1(path, requestInits) {
+ return env().platform.fetch(path, requestInits);
+ }
+ function encodeString(s, encoding) {
+ if (encoding === void 0) {
+ encoding = "utf-8";
+ }
+ encoding = encoding || "utf-8";
+ return env().platform.encode(s, encoding);
+ }
+ function decodeString(bytes, encoding) {
+ if (encoding === void 0) {
+ encoding = "utf-8";
+ }
+ encoding = encoding || "utf-8";
+ return env().platform.decode(bytes, encoding);
+ }
+ var util = {
+ __proto__: null,
+ createScalarValue,
+ toTypedArray,
+ now: now2,
+ fetch: fetch$1,
+ encodeString,
+ decodeString,
+ shuffle,
+ clamp,
+ nearestLargerEven,
+ sum,
+ randUniform,
+ distSquared,
+ assert,
+ assertShapesMatch,
+ assertNonNull,
+ flatten,
+ sizeFromShape,
+ isScalarShape,
+ arraysEqual,
+ isInt,
+ tanh,
+ sizeToSquarishShape,
+ createShuffledIndices,
+ rightPad,
+ repeatedTry,
+ inferFromImplicitShape,
+ parseAxisParam,
+ squeezeShape,
+ getTypedArrayFromDType,
+ getArrayFromDType,
+ checkConversionForErrors,
+ isValidDtype,
+ hasEncodingLoss,
+ isTypedArray,
+ bytesPerElement,
+ bytesFromStringArray,
+ isString,
+ isBoolean,
+ isNumber,
+ inferDtype,
+ isFunction,
+ nearestDivisor,
+ computeStrides,
+ toNestedArray,
+ makeOnesTypedArray,
+ makeZerosTypedArray,
+ makeZerosNestedTypedArray,
+ assertNonNegativeIntegerDimensions,
+ locToIndex,
+ indexToLoc,
+ isPromise
+ };
+ /**
+ * @license
+ * Copyright 2018 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ var Profiler = function() {
+ function Profiler2(backendTimer, logger) {
+ this.backendTimer = backendTimer;
+ this.logger = logger;
+ if (logger == null) {
+ this.logger = new Logger();
+ }
+ }
+ Profiler2.prototype.profileKernel = function(kernelName, inputs, f) {
+ var outputs;
+ var holdResultWrapperFn = function() {
+ outputs = f();
+ };
+ var timer = this.backendTimer.time(holdResultWrapperFn);
+ var _loop_1 = function(i2) {
+ var output = outputs[i2];
+ output.data().then(function(tensorVals) {
+ checkComputationForErrors(tensorVals, output.dtype, kernelName);
});
- }, c = 0; c < i.length; c++)
- o(c);
- var l = {kernelName: t, outputs: i, inputs: e, timeMs: s.then(function(u) {
- return u.kernelMs;
- }), extraInfo: s.then(function(u) {
- return u.getExtraProfileInfo != null ? u.getExtraProfileInfo() : "";
- })};
- return l;
- }, n.prototype.logKernelProfile = function(t) {
- var e = this, r = t.kernelName, i = t.outputs, a = t.timeMs, s = t.inputs, o = t.extraInfo;
- i.forEach(function(c) {
- Promise.all([c.data(), a, o]).then(function(l) {
- e.logger.logKernelProfile(r, c, l[0], l[1], s, l[2]);
+ };
+ for (var i = 0; i < outputs.length; i++) {
+ _loop_1(i);
+ }
+ var kernelProfile = {
+ kernelName,
+ outputs,
+ inputs,
+ timeMs: timer.then(function(timing) {
+ return timing.kernelMs;
+ }),
+ extraInfo: timer.then(function(timing) {
+ return timing.getExtraProfileInfo != null ? timing.getExtraProfileInfo() : "";
+ })
+ };
+ return kernelProfile;
+ };
+ Profiler2.prototype.logKernelProfile = function(kernelProfile) {
+ var _this = this;
+ var kernelName = kernelProfile.kernelName, outputs = kernelProfile.outputs, timeMs = kernelProfile.timeMs, inputs = kernelProfile.inputs, extraInfo = kernelProfile.extraInfo;
+ outputs.forEach(function(result) {
+ Promise.all([result.data(), timeMs, extraInfo]).then(function(valueContainer) {
+ _this.logger.logKernelProfile(kernelName, result, valueContainer[0], valueContainer[1], inputs, valueContainer[2]);
});
});
- }, n;
+ };
+ return Profiler2;
}();
- function PL(n, t, e) {
- if (t !== "float32")
+ function checkComputationForErrors(vals, dtype, kernelName) {
+ if (dtype !== "float32") {
return false;
- for (var r = 0; r < n.length; r++) {
- var i = n[r];
- if (isNaN(i) || !isFinite(i))
- return console.warn("Found " + i + " in the result of '" + e + "'"), true;
+ }
+ for (var i = 0; i < vals.length; i++) {
+ var num = vals[i];
+ if (isNaN(num) || !isFinite(num)) {
+ console.warn("Found " + num + " in the result of '" + kernelName + "'");
+ return true;
+ }
}
return false;
}
- var ML = function() {
- function n() {
+ var Logger = function() {
+ function Logger2() {
}
- return n.prototype.logKernelProfile = function(t, e, r, i, a, s) {
- var o = typeof i == "number" ? ya(i + "ms", 9) : i.error, c = ya(t, 25), l = e.rank, u = e.size, h = ya(e.shape.toString(), 14), d = "";
- for (var p in a) {
- var f = a[p];
- if (f != null) {
- var m = f.shape || e.shape, g = m.length;
- d += p + ": " + g + "D " + (g > 0 ? m : "") + " ";
+ Logger2.prototype.logKernelProfile = function(name, result, vals, timeMs, inputs, extraInfo) {
+ var time2 = typeof timeMs === "number" ? rightPad(timeMs + "ms", 9) : timeMs["error"];
+ var paddedName = rightPad(name, 25);
+ var rank = result.rank;
+ var size = result.size;
+ var shape = rightPad(result.shape.toString(), 14);
+ var inputShapesDescription = "";
+ for (var name_1 in inputs) {
+ var input = inputs[name_1];
+ if (input != null) {
+ var inputShape = input.shape || result.shape;
+ var inputRank = inputShape.length;
+ inputShapesDescription += name_1 + ": " + inputRank + "D " + (inputRank > 0 ? inputShape : "") + " ";
}
}
- console.log("%c" + c + " %c" + o + " %c" + l + "D " + h + " %c" + u + " %c" + d + " %c" + s, "font-weight:bold", "color:red", "color:blue", "color: orange", "color: green", "color: steelblue");
- }, n;
+ console.log("%c" + paddedName + " %c" + time2 + " %c" + rank + "D " + shape + " %c" + size + " %c" + inputShapesDescription + " %c" + extraInfo, "font-weight:bold", "color:red", "color:blue", "color: orange", "color: green", "color: steelblue");
+ };
+ return Logger2;
}();
- function VL(n, t, e) {
- for (var r = {}, i = {}, a = 0; a < t.length; a++)
- r[t[a].id] = true;
- for (var a = 0; a < n.length; a++) {
- var s = n[a], o = s.inputs;
- for (var c in o) {
- for (var l = o[c], u = false, h = 0; h < t.length; h++)
- if (r[l.id]) {
- s.outputs.forEach(function(b) {
- return r[b.id] = true;
- }), u = true, i[s.id] = true;
+ /**
+ * @license
+ * Copyright 2017 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ function getFilteredNodesXToY(tape, xs, y) {
+ var tensorsFromX = {};
+ var nodesFromX = {};
+ for (var i = 0; i < xs.length; i++) {
+ tensorsFromX[xs[i].id] = true;
+ }
+ for (var i = 0; i < tape.length; i++) {
+ var node = tape[i];
+ var nodeInputs = node.inputs;
+ for (var inputName in nodeInputs) {
+ var input = nodeInputs[inputName];
+ var anyInputFromX = false;
+ for (var j = 0; j < xs.length; j++) {
+ if (tensorsFromX[input.id]) {
+ node.outputs.forEach(function(output) {
+ return tensorsFromX[output.id] = true;
+ });
+ anyInputFromX = true;
+ nodesFromX[node.id] = true;
break;
}
- if (u)
- break;
- }
- }
- var d = {};
- d[e.id] = true;
- for (var p = {}, a = n.length - 1; a >= 0; a--)
- for (var s = n[a], o = s.inputs, h = 0; h < s.outputs.length; h++)
- if (d[s.outputs[h].id]) {
- for (var c in o)
- d[o[c].id] = true, p[s.id] = true;
+ }
+ if (anyInputFromX) {
break;
}
- for (var f = [], a = 0; a < n.length; a++) {
- var s = n[a];
- if (i[s.id] && p[s.id]) {
- var m = {};
- for (var c in s.inputs) {
- var g = s.inputs[c];
- r[g.id] && (m[c] = g);
- }
- var y = Object.assign({}, s);
- y.inputs = m, y.outputs = s.outputs, f.push(y);
}
}
- return f;
+ var tensorsLeadToY = {};
+ tensorsLeadToY[y.id] = true;
+ var nodesToY = {};
+ for (var i = tape.length - 1; i >= 0; i--) {
+ var node = tape[i];
+ var nodeInputs = node.inputs;
+ for (var j = 0; j < node.outputs.length; j++) {
+ if (tensorsLeadToY[node.outputs[j].id]) {
+ for (var inputName in nodeInputs) {
+ tensorsLeadToY[nodeInputs[inputName].id] = true;
+ nodesToY[node.id] = true;
+ }
+ break;
+ }
+ }
+ }
+ var filteredTape = [];
+ for (var i = 0; i < tape.length; i++) {
+ var node = tape[i];
+ if (nodesFromX[node.id] && nodesToY[node.id]) {
+ var prunedInputs = {};
+ for (var inputName in node.inputs) {
+ var nodeInput = node.inputs[inputName];
+ if (tensorsFromX[nodeInput.id]) {
+ prunedInputs[inputName] = nodeInput;
+ }
+ }
+ var prunedNode = Object.assign({}, node);
+ prunedNode.inputs = prunedInputs;
+ prunedNode.outputs = node.outputs;
+ filteredTape.push(prunedNode);
+ }
+ }
+ return filteredTape;
}
- function GL(n, t, e, r) {
- for (var i = function(s) {
- var o = t[s], c = [];
- if (o.outputs.forEach(function(d) {
- var p = n[d.id];
- p != null ? c.push(p) : c.push(null);
- }), o.gradient == null)
- throw new Error("Cannot compute gradient: gradient function not found " + ("for " + o.kernelName + "."));
- var l = o.gradient(c), u = function(d) {
- if (!(d in l))
- throw new Error("Cannot backprop through input " + d + ". " + ("Available gradients found: " + Object.keys(l) + "."));
- var p = e(function() {
- return l[d]();
+ function backpropagateGradients(tensorAccumulatedGradientMap, filteredTape, tidy2, add2) {
+ var _loop_1 = function(i2) {
+ var node = filteredTape[i2];
+ var dys = [];
+ node.outputs.forEach(function(o) {
+ var gradTensor = tensorAccumulatedGradientMap[o.id];
+ if (gradTensor != null) {
+ dys.push(gradTensor);
+ } else {
+ dys.push(null);
+ }
+ });
+ if (node.gradient == null) {
+ throw new Error("Cannot compute gradient: gradient function not found " + ("for " + node.kernelName + "."));
+ }
+ var inputGradients = node.gradient(dys);
+ var _loop_2 = function(inputName2) {
+ if (!(inputName2 in inputGradients)) {
+ throw new Error("Cannot backprop through input " + inputName2 + ". " + ("Available gradients found: " + Object.keys(inputGradients) + "."));
+ }
+ var dx = tidy2(function() {
+ return inputGradients[inputName2]();
});
- if (p.dtype !== "float32")
- throw new Error("Error in gradient for op " + o.kernelName + ". The gradient of input " + (d + " must have 'float32' dtype, but has '" + p.dtype + "'"));
- var f = o.inputs[d];
- if (!pn(p.shape, f.shape))
- throw new Error("Error in gradient for op " + o.kernelName + ". The gradient of input " + ("'" + d + "' has shape '" + p.shape + "', which does not match ") + ("the shape of the input '" + f.shape + "'"));
- if (n[f.id] == null)
- n[f.id] = p;
- else {
- var m = n[f.id];
- n[f.id] = r(m, p), m.dispose();
+ if (dx.dtype !== "float32") {
+ throw new Error("Error in gradient for op " + node.kernelName + ". The gradient of input " + (inputName2 + " must have 'float32' dtype, but has '" + dx.dtype + "'"));
+ }
+ var x = node.inputs[inputName2];
+ if (!arraysEqual(dx.shape, x.shape)) {
+ throw new Error("Error in gradient for op " + node.kernelName + ". The gradient of input " + ("'" + inputName2 + "' has shape '" + dx.shape + "', which does not match ") + ("the shape of the input '" + x.shape + "'"));
+ }
+ if (tensorAccumulatedGradientMap[x.id] == null) {
+ tensorAccumulatedGradientMap[x.id] = dx;
+ } else {
+ var curGradient = tensorAccumulatedGradientMap[x.id];
+ tensorAccumulatedGradientMap[x.id] = add2(curGradient, dx);
+ curGradient.dispose();
}
};
- for (var h in o.inputs)
- u(h);
- }, a = t.length - 1; a >= 0; a--)
- i(a);
+ for (var inputName in node.inputs) {
+ _loop_2(inputName);
+ }
+ };
+ for (var i = filteredTape.length - 1; i >= 0; i--) {
+ _loop_1(i);
+ }
}
- var Im = 20, wa = 3, mu = 7;
- function YL(n, t, e, r) {
- var i = bi(t), a = qL(n, t, e, i), s = t.length, o = Ds(n, t, e, i, a), c = ["Tensor"];
- return r && (c.push(" dtype: " + e), c.push(" rank: " + s), c.push(" shape: [" + t + "]"), c.push(" values:")), c.push(o.map(function(l) {
+ /**
+ * @license
+ * Copyright 2018 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ var FORMAT_LIMIT_NUM_VALS = 20;
+ var FORMAT_NUM_FIRST_LAST_VALS = 3;
+ var FORMAT_NUM_SIG_DIGITS = 7;
+ function tensorToString(vals, shape, dtype, verbose) {
+ var strides = computeStrides(shape);
+ var padPerCol = computeMaxSizePerColumn(vals, shape, dtype, strides);
+ var rank = shape.length;
+ var valsLines = subTensorToString(vals, shape, dtype, strides, padPerCol);
+ var lines = ["Tensor"];
+ if (verbose) {
+ lines.push(" dtype: " + dtype);
+ lines.push(" rank: " + rank);
+ lines.push(" shape: [" + shape + "]");
+ lines.push(" values:");
+ }
+ lines.push(valsLines.map(function(l) {
return " " + l;
- }).join(`
-`)), c.join(`
-`);
+ }).join("\n"));
+ return lines.join("\n");
}
- function qL(n, t, e, r) {
- var i = pt(t), a = r[r.length - 1], s = new Array(a).fill(0), o = t.length, c = e === "complex64" ? xa(n) : n;
- if (o > 1)
- for (var l = 0; l < i / a; l++)
- for (var u = l * a, h = 0; h < a; h++)
- s[h] = Math.max(s[h], ba(c[u + h], 0, e).length);
- return s;
- }
- function ba(n, t, e) {
- var r;
- return Array.isArray(n) ? r = parseFloat(n[0].toFixed(mu)) + " + " + (parseFloat(n[1].toFixed(mu)) + "j") : or(n) ? r = "'" + n + "'" : e === "bool" ? r = Am(n) : r = parseFloat(n.toFixed(mu)).toString(), ya(r, t);
- }
- function Am(n) {
- return n === 0 ? "false" : "true";
- }
- function Ds(n, t, e, r, i, a) {
- a === void 0 && (a = true);
- var s = e === "complex64" ? 2 : 1, o = t[0], c = t.length;
- if (c === 0) {
- if (e === "complex64") {
- var l = xa(n);
- return [ba(l[0], 0, e)];
- }
- return e === "bool" ? [Am(n[0])] : [n[0].toString()];
- }
- if (c === 1) {
- if (o > Im) {
- var u = wa * s, h = Array.from(n.slice(0, u)), d = Array.from(n.slice((o - wa) * s, o * s));
- return e === "complex64" && (h = xa(h), d = xa(d)), ["[" + h.map(function(I, C) {
- return ba(I, i[C], e);
- }).join(", ") + ", ..., " + d.map(function(I, C) {
- return ba(I, i[o - wa + C], e);
- }).join(", ") + "]"];
- }
- var p = e === "complex64" ? xa(n) : Array.from(n);
- return ["[" + p.map(function(I, C) {
- return ba(I, i[C], e);
- }).join(", ") + "]"];
- }
- var f = t.slice(1), m = r.slice(1), g = r[0] * s, y = [];
- if (o > Im) {
- for (var w = 0; w < wa; w++) {
- var b = w * g, L = b + g;
- y.push.apply(y, Ds(n.slice(b, L), f, e, m, i, false));
- }
- y.push("...");
- for (var w = o - wa; w < o; w++) {
- var b = w * g, L = b + g;
- y.push.apply(y, Ds(n.slice(b, L), f, e, m, i, w === o - 1));
- }
- } else
- for (var w = 0; w < o; w++) {
- var b = w * g, L = b + g;
- y.push.apply(y, Ds(n.slice(b, L), f, e, m, i, w === o - 1));
- }
- var x = c === 2 ? "," : "";
- y[0] = "[" + y[0] + x;
- for (var w = 1; w < y.length - 1; w++)
- y[w] = " " + y[w] + x;
- for (var N = `,
-`, w = 2; w < c; w++)
- N += `
-`;
- return y[y.length - 1] = " " + y[y.length - 1] + "]" + (a ? "" : N), y;
- }
- function xa(n) {
- for (var t = [], e = 0; e < n.length; e += 2)
- t.push([n[e], n[e + 1]]);
- return t;
- }
- var ks = function() {
- function n(t, e, r) {
- var i = this;
- if (this.dtype = e, this.shape = t.slice(), this.size = pt(t), r != null) {
- var a = r.length;
- E(a === this.size, function() {
- return "Length of values '" + a + "' does not match the size " + ("inferred by the shape '" + i.size + "'.");
- });
- }
- if (e === "complex64")
- throw new Error("complex64 dtype TensorBuffers are not supported. Please create a TensorBuffer for the real and imaginary parts separately and call tf.complex(real, imag).");
- this.values = r || df(e, this.size), this.strides = bi(t);
- }
- return n.prototype.set = function(t) {
- for (var e = this, r = [], i = 1; i < arguments.length; i++)
- r[i - 1] = arguments[i];
- r.length === 0 && (r = [0]), E(r.length === this.rank, function() {
- return "The number of provided coordinates (" + r.length + ") must " + ("match the rank (" + e.rank + ")");
- });
- var a = this.locToIndex(r);
- this.values[a] = t;
- }, n.prototype.get = function() {
- for (var t = [], e = 0; e < arguments.length; e++)
- t[e] = arguments[e];
- t.length === 0 && (t = [0]);
- for (var r = 0, i = 0, a = t; i < a.length; i++) {
- var s = a[i];
- if (s < 0 || s >= this.shape[r]) {
- var o = "Requested out of range element at " + t + ". " + (" Buffer shape=" + this.shape);
- throw new Error(o);
+ function computeMaxSizePerColumn(vals, shape, dtype, strides) {
+ var n = sizeFromShape(shape);
+ var numCols = strides[strides.length - 1];
+ var padPerCol = new Array(numCols).fill(0);
+ var rank = shape.length;
+ var valuesOrTuples = dtype === "complex64" ? createComplexTuples(vals) : vals;
+ if (rank > 1) {
+ for (var row = 0; row < n / numCols; row++) {
+ var offset = row * numCols;
+ for (var j = 0; j < numCols; j++) {
+ padPerCol[j] = Math.max(padPerCol[j], valToString(valuesOrTuples[offset + j], 0, dtype).length);
}
- r++;
}
- for (var c = t[t.length - 1], l = 0; l < t.length - 1; ++l)
- c += this.strides[l] * t[l];
- return this.values[c];
- }, n.prototype.locToIndex = function(t) {
- if (this.rank === 0)
- return 0;
- if (this.rank === 1)
- return t[0];
- for (var e = t[t.length - 1], r = 0; r < t.length - 1; ++r)
- e += this.strides[r] * t[r];
- return e;
- }, n.prototype.indexToLoc = function(t) {
- if (this.rank === 0)
- return [];
- if (this.rank === 1)
- return [t];
- for (var e = new Array(this.shape.length), r = 0; r < e.length - 1; ++r)
- e[r] = Math.floor(t / this.strides[r]), t -= e[r] * this.strides[r];
- return e[e.length - 1] = t, e;
- }, Object.defineProperty(n.prototype, "rank", {get: function() {
- return this.shape.length;
- }, enumerable: true, configurable: true}), n.prototype.toTensor = function() {
- return Cn().makeTensor(this.values, this.shape, this.dtype);
- }, n;
- }(), Cn = null, Ii = null;
- function KL(n) {
- Cn = n;
- }
- function jL(n) {
- Ii = n;
- }
- var K = function() {
- function n(t, e, r, i) {
- this.kept = false, this.isDisposedInternal = false, this.shape = t.slice(), this.dtype = e || "float32", this.size = pt(t), this.strides = bi(t), this.dataId = r, this.id = i, this.rankType = this.rank < 5 ? this.rank.toString() : "higher";
}
- return Object.defineProperty(n.prototype, "rank", {get: function() {
- return this.shape.length;
- }, enumerable: true, configurable: true}), n.prototype.buffer = function() {
- return pe(this, void 0, void 0, function() {
- var t;
- return fe(this, function(e) {
- switch (e.label) {
+ return padPerCol;
+ }
+ function valToString(val, pad2, dtype) {
+ var valStr;
+ if (Array.isArray(val)) {
+ valStr = parseFloat(val[0].toFixed(FORMAT_NUM_SIG_DIGITS)) + " + " + (parseFloat(val[1].toFixed(FORMAT_NUM_SIG_DIGITS)) + "j");
+ } else if (isString(val)) {
+ valStr = "'" + val + "'";
+ } else if (dtype === "bool") {
+ valStr = boolNumToString(val);
+ } else {
+ valStr = parseFloat(val.toFixed(FORMAT_NUM_SIG_DIGITS)).toString();
+ }
+ return rightPad(valStr, pad2);
+ }
+ function boolNumToString(v) {
+ return v === 0 ? "false" : "true";
+ }
+ function subTensorToString(vals, shape, dtype, strides, padPerCol, isLast) {
+ if (isLast === void 0) {
+ isLast = true;
+ }
+ var storagePerElement = dtype === "complex64" ? 2 : 1;
+ var size = shape[0];
+ var rank = shape.length;
+ if (rank === 0) {
+ if (dtype === "complex64") {
+ var complexTuple = createComplexTuples(vals);
+ return [valToString(complexTuple[0], 0, dtype)];
+ }
+ if (dtype === "bool") {
+ return [boolNumToString(vals[0])];
+ }
+ return [vals[0].toString()];
+ }
+ if (rank === 1) {
+ if (size > FORMAT_LIMIT_NUM_VALS) {
+ var firstValsSize = FORMAT_NUM_FIRST_LAST_VALS * storagePerElement;
+ var firstVals = Array.from(vals.slice(0, firstValsSize));
+ var lastVals = Array.from(vals.slice((size - FORMAT_NUM_FIRST_LAST_VALS) * storagePerElement, size * storagePerElement));
+ if (dtype === "complex64") {
+ firstVals = createComplexTuples(firstVals);
+ lastVals = createComplexTuples(lastVals);
+ }
+ return [
+ "[" + firstVals.map(function(x, i2) {
+ return valToString(x, padPerCol[i2], dtype);
+ }).join(", ") + ", ..., " + lastVals.map(function(x, i2) {
+ return valToString(x, padPerCol[size - FORMAT_NUM_FIRST_LAST_VALS + i2], dtype);
+ }).join(", ") + "]"
+ ];
+ }
+ var displayVals = dtype === "complex64" ? createComplexTuples(vals) : Array.from(vals);
+ return [
+ "[" + displayVals.map(function(x, i2) {
+ return valToString(x, padPerCol[i2], dtype);
+ }).join(", ") + "]"
+ ];
+ }
+ var subshape = shape.slice(1);
+ var substrides = strides.slice(1);
+ var stride = strides[0] * storagePerElement;
+ var lines = [];
+ if (size > FORMAT_LIMIT_NUM_VALS) {
+ for (var i = 0; i < FORMAT_NUM_FIRST_LAST_VALS; i++) {
+ var start = i * stride;
+ var end = start + stride;
+ lines.push.apply(lines, subTensorToString(vals.slice(start, end), subshape, dtype, substrides, padPerCol, false));
+ }
+ lines.push("...");
+ for (var i = size - FORMAT_NUM_FIRST_LAST_VALS; i < size; i++) {
+ var start = i * stride;
+ var end = start + stride;
+ lines.push.apply(lines, subTensorToString(vals.slice(start, end), subshape, dtype, substrides, padPerCol, i === size - 1));
+ }
+ } else {
+ for (var i = 0; i < size; i++) {
+ var start = i * stride;
+ var end = start + stride;
+ lines.push.apply(lines, subTensorToString(vals.slice(start, end), subshape, dtype, substrides, padPerCol, i === size - 1));
+ }
+ }
+ var sep = rank === 2 ? "," : "";
+ lines[0] = "[" + lines[0] + sep;
+ for (var i = 1; i < lines.length - 1; i++) {
+ lines[i] = " " + lines[i] + sep;
+ }
+ var newLineSep = ",\n";
+ for (var i = 2; i < rank; i++) {
+ newLineSep += "\n";
+ }
+ lines[lines.length - 1] = " " + lines[lines.length - 1] + "]" + (isLast ? "" : newLineSep);
+ return lines;
+ }
+ function createComplexTuples(vals) {
+ var complexTuples = [];
+ for (var i = 0; i < vals.length; i += 2) {
+ complexTuples.push([vals[i], vals[i + 1]]);
+ }
+ return complexTuples;
+ }
+ /**
+ * @license
+ * Copyright 2017 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ var TensorBuffer = function() {
+ function TensorBuffer2(shape, dtype, values) {
+ var _this = this;
+ this.dtype = dtype;
+ this.shape = shape.slice();
+ this.size = sizeFromShape(shape);
+ if (values != null) {
+ var n_1 = values.length;
+ assert(n_1 === this.size, function() {
+ return "Length of values '" + n_1 + "' does not match the size " + ("inferred by the shape '" + _this.size + "'.");
+ });
+ }
+ if (dtype === "complex64") {
+ throw new Error("complex64 dtype TensorBuffers are not supported. Please create a TensorBuffer for the real and imaginary parts separately and call tf.complex(real, imag).");
+ }
+ this.values = values || getArrayFromDType(dtype, this.size);
+ this.strides = computeStrides(shape);
+ }
+ TensorBuffer2.prototype.set = function(value) {
+ var _this = this;
+ var locs = [];
+ for (var _i2 = 1; _i2 < arguments.length; _i2++) {
+ locs[_i2 - 1] = arguments[_i2];
+ }
+ if (locs.length === 0) {
+ locs = [0];
+ }
+ assert(locs.length === this.rank, function() {
+ return "The number of provided coordinates (" + locs.length + ") must " + ("match the rank (" + _this.rank + ")");
+ });
+ var index = this.locToIndex(locs);
+ this.values[index] = value;
+ };
+ TensorBuffer2.prototype.get = function() {
+ var locs = [];
+ for (var _i2 = 0; _i2 < arguments.length; _i2++) {
+ locs[_i2] = arguments[_i2];
+ }
+ if (locs.length === 0) {
+ locs = [0];
+ }
+ var i = 0;
+ for (var _a = 0, locs_1 = locs; _a < locs_1.length; _a++) {
+ var loc = locs_1[_a];
+ if (loc < 0 || loc >= this.shape[i]) {
+ var msg = "Requested out of range element at " + locs + ". " + (" Buffer shape=" + this.shape);
+ throw new Error(msg);
+ }
+ i++;
+ }
+ var index = locs[locs.length - 1];
+ for (var i_1 = 0; i_1 < locs.length - 1; ++i_1) {
+ index += this.strides[i_1] * locs[i_1];
+ }
+ return this.values[index];
+ };
+ TensorBuffer2.prototype.locToIndex = function(locs) {
+ if (this.rank === 0) {
+ return 0;
+ } else if (this.rank === 1) {
+ return locs[0];
+ }
+ var index = locs[locs.length - 1];
+ for (var i = 0; i < locs.length - 1; ++i) {
+ index += this.strides[i] * locs[i];
+ }
+ return index;
+ };
+ TensorBuffer2.prototype.indexToLoc = function(index) {
+ if (this.rank === 0) {
+ return [];
+ } else if (this.rank === 1) {
+ return [index];
+ }
+ var locs = new Array(this.shape.length);
+ for (var i = 0; i < locs.length - 1; ++i) {
+ locs[i] = Math.floor(index / this.strides[i]);
+ index -= locs[i] * this.strides[i];
+ }
+ locs[locs.length - 1] = index;
+ return locs;
+ };
+ Object.defineProperty(TensorBuffer2.prototype, "rank", {
+ get: function() {
+ return this.shape.length;
+ },
+ enumerable: true,
+ configurable: true
+ });
+ TensorBuffer2.prototype.toTensor = function() {
+ return trackerFn().makeTensor(this.values, this.shape, this.dtype);
+ };
+ return TensorBuffer2;
+ }();
+ var trackerFn = null;
+ var opHandler = null;
+ function setTensorTracker(fn) {
+ trackerFn = fn;
+ }
+ function setOpHandler(handler) {
+ opHandler = handler;
+ }
+ var Tensor = function() {
+ function Tensor2(shape, dtype, dataId, id) {
+ this.kept = false;
+ this.isDisposedInternal = false;
+ this.shape = shape.slice();
+ this.dtype = dtype || "float32";
+ this.size = sizeFromShape(shape);
+ this.strides = computeStrides(shape);
+ this.dataId = dataId;
+ this.id = id;
+ this.rankType = this.rank < 5 ? this.rank.toString() : "higher";
+ }
+ Object.defineProperty(Tensor2.prototype, "rank", {
+ get: function() {
+ return this.shape.length;
+ },
+ enumerable: true,
+ configurable: true
+ });
+ Tensor2.prototype.buffer = function() {
+ return __awaiter(this, void 0, void 0, function() {
+ var vals;
+ return __generator(this, function(_a) {
+ switch (_a.label) {
case 0:
return [4, this.data()];
case 1:
- return t = e.sent(), [2, Ii.buffer(this.shape, this.dtype, t)];
+ vals = _a.sent();
+ return [2, opHandler.buffer(this.shape, this.dtype, vals)];
}
});
});
- }, n.prototype.bufferSync = function() {
- return Ii.buffer(this.shape, this.dtype, this.dataSync());
- }, n.prototype.array = function() {
- return pe(this, void 0, void 0, function() {
- var t;
- return fe(this, function(e) {
- switch (e.label) {
+ };
+ Tensor2.prototype.bufferSync = function() {
+ return opHandler.buffer(this.shape, this.dtype, this.dataSync());
+ };
+ Tensor2.prototype.array = function() {
+ return __awaiter(this, void 0, void 0, function() {
+ var vals;
+ return __generator(this, function(_a) {
+ switch (_a.label) {
case 0:
return [4, this.data()];
case 1:
- return t = e.sent(), [2, xi(this.shape, t)];
+ vals = _a.sent();
+ return [2, toNestedArray(this.shape, vals)];
}
});
});
- }, n.prototype.arraySync = function() {
- return xi(this.shape, this.dataSync());
- }, n.prototype.data = function() {
- return pe(this, void 0, void 0, function() {
- var t, e;
- return fe(this, function(r) {
- switch (r.label) {
+ };
+ Tensor2.prototype.arraySync = function() {
+ return toNestedArray(this.shape, this.dataSync());
+ };
+ Tensor2.prototype.data = function() {
+ return __awaiter(this, void 0, void 0, function() {
+ var data, bytes;
+ return __generator(this, function(_a) {
+ switch (_a.label) {
case 0:
- return this.throwIfDisposed(), t = Cn().read(this.dataId), this.dtype === "string" ? [4, t] : [3, 2];
+ this.throwIfDisposed();
+ data = trackerFn().read(this.dataId);
+ if (!(this.dtype === "string"))
+ return [3, 2];
+ return [4, data];
case 1:
- e = r.sent();
+ bytes = _a.sent();
try {
- return [2, e.map(function(i) {
- return fu(i);
+ return [2, bytes.map(function(b) {
+ return decodeString(b);
})];
- } catch (i) {
+ } catch (_b) {
throw new Error("Failed to decode the string bytes into utf-8. To get the original bytes, call tensor.bytes().");
}
- r.label = 2;
+ _a.label = 2;
case 2:
- return [2, t];
+ return [2, data];
}
});
});
- }, n.prototype.dataSync = function() {
+ };
+ Tensor2.prototype.dataSync = function() {
this.throwIfDisposed();
- var t = Cn().readSync(this.dataId);
- if (this.dtype === "string")
+ var data = trackerFn().readSync(this.dataId);
+ if (this.dtype === "string") {
try {
- return t.map(function(e) {
- return fu(e);
+ return data.map(function(b) {
+ return decodeString(b);
});
- } catch (e) {
+ } catch (_a) {
throw new Error("Failed to decode the string bytes into utf-8. To get the original bytes, call tensor.bytes().");
}
- return t;
- }, n.prototype.bytes = function() {
- return pe(this, void 0, void 0, function() {
- var t;
- return fe(this, function(e) {
- switch (e.label) {
+ }
+ return data;
+ };
+ Tensor2.prototype.bytes = function() {
+ return __awaiter(this, void 0, void 0, function() {
+ var data;
+ return __generator(this, function(_a) {
+ switch (_a.label) {
case 0:
- return this.throwIfDisposed(), [4, Cn().read(this.dataId)];
+ this.throwIfDisposed();
+ return [4, trackerFn().read(this.dataId)];
case 1:
- return t = e.sent(), this.dtype === "string" ? [2, t] : [2, new Uint8Array(t.buffer)];
+ data = _a.sent();
+ if (this.dtype === "string") {
+ return [2, data];
+ } else {
+ return [2, new Uint8Array(data.buffer)];
+ }
}
});
});
- }, n.prototype.dispose = function() {
- if (this.isDisposed)
+ };
+ Tensor2.prototype.dispose = function() {
+ if (this.isDisposed) {
return;
- Cn().disposeTensor(this), this.isDisposedInternal = true;
- }, Object.defineProperty(n.prototype, "isDisposed", {get: function() {
- return this.isDisposedInternal;
- }, enumerable: true, configurable: true}), n.prototype.throwIfDisposed = function() {
- if (this.isDisposed)
+ }
+ trackerFn().disposeTensor(this);
+ this.isDisposedInternal = true;
+ };
+ Object.defineProperty(Tensor2.prototype, "isDisposed", {
+ get: function() {
+ return this.isDisposedInternal;
+ },
+ enumerable: true,
+ configurable: true
+ });
+ Tensor2.prototype.throwIfDisposed = function() {
+ if (this.isDisposed) {
throw new Error("Tensor is disposed.");
- }, n.prototype.print = function(t) {
- return t === void 0 && (t = false), Ii.print(this, t);
- }, n.prototype.clone = function() {
- return this.throwIfDisposed(), Ii.clone(this);
- }, n.prototype.toString = function(t) {
- t === void 0 && (t = false);
- var e = this.dataSync();
- return YL(e, this.shape, this.dtype, t);
- }, n.prototype.cast = function(t) {
- return this.throwIfDisposed(), Ii.cast(this, t);
- }, n.prototype.variable = function(t, e, r) {
- return t === void 0 && (t = true), this.throwIfDisposed(), Cn().makeVariable(this, t, e, r);
- }, n;
+ }
+ };
+ Tensor2.prototype.print = function(verbose) {
+ if (verbose === void 0) {
+ verbose = false;
+ }
+ return opHandler.print(this, verbose);
+ };
+ Tensor2.prototype.clone = function() {
+ this.throwIfDisposed();
+ return opHandler.clone(this);
+ };
+ Tensor2.prototype.toString = function(verbose) {
+ if (verbose === void 0) {
+ verbose = false;
+ }
+ var vals = this.dataSync();
+ return tensorToString(vals, this.shape, this.dtype, verbose);
+ };
+ Tensor2.prototype.cast = function(dtype) {
+ this.throwIfDisposed();
+ return opHandler.cast(this, dtype);
+ };
+ Tensor2.prototype.variable = function(trainable, name, dtype) {
+ if (trainable === void 0) {
+ trainable = true;
+ }
+ this.throwIfDisposed();
+ return trackerFn().makeVariable(this, trainable, name, dtype);
+ };
+ return Tensor2;
}();
- Object.defineProperty(K, Symbol.hasInstance, {value: function(n) {
- return !!n && n.data != null && n.dataSync != null && n.throwIfDisposed != null;
- }});
- var La = function(n) {
- qn(t, n);
- function t(e, r, i, a) {
- var s = n.call(this, e.shape, e.dtype, e.dataId, a) || this;
- return s.trainable = r, s.name = i, s;
+ Object.defineProperty(Tensor, Symbol.hasInstance, {
+ value: function(instance2) {
+ return !!instance2 && instance2.data != null && instance2.dataSync != null && instance2.throwIfDisposed != null;
}
- return t.prototype.assign = function(e) {
- if (e.dtype !== this.dtype)
- throw new Error("dtype of the new value (" + e.dtype + ") and " + ("previous value (" + this.dtype + ") must match"));
- if (!pn(e.shape, this.shape))
- throw new Error("shape of the new value (" + e.shape + ") and " + ("previous value (" + this.shape + ") must match"));
- Cn().disposeTensor(this), this.dataId = e.dataId, Cn().incRef(this, null);
- }, t.prototype.dispose = function() {
- Cn().disposeVariable(this), this.isDisposedInternal = true;
- }, t;
- }(K);
- Object.defineProperty(La, Symbol.hasInstance, {value: function(n) {
- return n instanceof K && n.assign != null && n.assign instanceof Function;
- }});
- (function(n) {
- n.R0 = "R0", n.R1 = "R1", n.R2 = "R2", n.R3 = "R3", n.R4 = "R4", n.R5 = "R5", n.R6 = "R6";
- })(A.Rank || (A.Rank = {}));
- var gu;
- (function(n) {
- n.float32 = "float32", n.int32 = "int32", n.bool = "int32", n.complex64 = "complex64";
- })(gu || (gu = {}));
- var yu;
- (function(n) {
- n.float32 = "float32", n.int32 = "int32", n.bool = "bool", n.complex64 = "complex64";
- })(yu || (yu = {}));
- var vu;
- (function(n) {
- n.float32 = "float32", n.int32 = "float32", n.bool = "float32", n.complex64 = "complex64";
- })(vu || (vu = {}));
- var wu;
- (function(n) {
- n.float32 = "complex64", n.int32 = "complex64", n.bool = "complex64", n.complex64 = "complex64";
- })(wu || (wu = {}));
- var $L = {float32: vu, int32: gu, bool: yu, complex64: wu};
- function Fs(n, t) {
- if (n === "string" || t === "string") {
- if (n === "string" && t === "string")
+ });
+ var Variable = function(_super) {
+ __extends(Variable2, _super);
+ function Variable2(initialValue, trainable, name, tensorId) {
+ var _this = _super.call(this, initialValue.shape, initialValue.dtype, initialValue.dataId, tensorId) || this;
+ _this.trainable = trainable;
+ _this.name = name;
+ return _this;
+ }
+ Variable2.prototype.assign = function(newValue) {
+ if (newValue.dtype !== this.dtype) {
+ throw new Error("dtype of the new value (" + newValue.dtype + ") and " + ("previous value (" + this.dtype + ") must match"));
+ }
+ if (!arraysEqual(newValue.shape, this.shape)) {
+ throw new Error("shape of the new value (" + newValue.shape + ") and " + ("previous value (" + this.shape + ") must match"));
+ }
+ trackerFn().disposeTensor(this);
+ this.dataId = newValue.dataId;
+ trackerFn().incRef(this, null);
+ };
+ Variable2.prototype.dispose = function() {
+ trackerFn().disposeVariable(this);
+ this.isDisposedInternal = true;
+ };
+ return Variable2;
+ }(Tensor);
+ Object.defineProperty(Variable, Symbol.hasInstance, {
+ value: function(instance2) {
+ return instance2 instanceof Tensor && instance2.assign != null && instance2.assign instanceof Function;
+ }
+ });
+ /**
+ * @license
+ * Copyright 2017 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ (function(Rank) {
+ Rank["R0"] = "R0";
+ Rank["R1"] = "R1";
+ Rank["R2"] = "R2";
+ Rank["R3"] = "R3";
+ Rank["R4"] = "R4";
+ Rank["R5"] = "R5";
+ Rank["R6"] = "R6";
+ })(exports.Rank || (exports.Rank = {}));
+ var UpcastInt32AndMap;
+ (function(UpcastInt32AndMap2) {
+ UpcastInt32AndMap2["float32"] = "float32";
+ UpcastInt32AndMap2["int32"] = "int32";
+ UpcastInt32AndMap2["bool"] = "int32";
+ UpcastInt32AndMap2["complex64"] = "complex64";
+ })(UpcastInt32AndMap || (UpcastInt32AndMap = {}));
+ var UpcastBoolAndMap;
+ (function(UpcastBoolAndMap2) {
+ UpcastBoolAndMap2["float32"] = "float32";
+ UpcastBoolAndMap2["int32"] = "int32";
+ UpcastBoolAndMap2["bool"] = "bool";
+ UpcastBoolAndMap2["complex64"] = "complex64";
+ })(UpcastBoolAndMap || (UpcastBoolAndMap = {}));
+ var UpcastFloat32AndMap;
+ (function(UpcastFloat32AndMap2) {
+ UpcastFloat32AndMap2["float32"] = "float32";
+ UpcastFloat32AndMap2["int32"] = "float32";
+ UpcastFloat32AndMap2["bool"] = "float32";
+ UpcastFloat32AndMap2["complex64"] = "complex64";
+ })(UpcastFloat32AndMap || (UpcastFloat32AndMap = {}));
+ var UpcastComplex64AndMap;
+ (function(UpcastComplex64AndMap2) {
+ UpcastComplex64AndMap2["float32"] = "complex64";
+ UpcastComplex64AndMap2["int32"] = "complex64";
+ UpcastComplex64AndMap2["bool"] = "complex64";
+ UpcastComplex64AndMap2["complex64"] = "complex64";
+ })(UpcastComplex64AndMap || (UpcastComplex64AndMap = {}));
+ var upcastTypeMap = {
+ float32: UpcastFloat32AndMap,
+ int32: UpcastInt32AndMap,
+ bool: UpcastBoolAndMap,
+ complex64: UpcastComplex64AndMap
+ };
+ function upcastType(typeA, typeB) {
+ if (typeA === "string" || typeB === "string") {
+ if (typeA === "string" && typeB === "string") {
return "string";
- throw new Error("Can not upcast " + n + " with " + t);
+ }
+ throw new Error("Can not upcast " + typeA + " with " + typeB);
}
- return $L[n][t];
+ return upcastTypeMap[typeA][typeB];
}
- function XL(n) {
- return Fs(n, "int32");
+ function sumOutType(type) {
+ return upcastType(type, "int32");
}
- function ct(n, t) {
- if (n.dtype === t.dtype)
- return [n, t];
- var e = Fs(n.dtype, t.dtype);
- return [n.cast(e), t.cast(e)];
+ /**
+ * @license
+ * Copyright 2018 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ function makeTypesMatch(a, b) {
+ if (a.dtype === b.dtype) {
+ return [a, b];
+ }
+ var dtype = upcastType(a.dtype, b.dtype);
+ return [a.cast(dtype), b.cast(dtype)];
}
- function Tm(n, t) {
- E(n.dtype === t.dtype, function() {
- return "The dtypes of the first(" + n.dtype + ") and" + (" second(" + t.dtype + ") input must match");
+ function assertTypesMatch(a, b) {
+ assert(a.dtype === b.dtype, function() {
+ return "The dtypes of the first(" + a.dtype + ") and" + (" second(" + b.dtype + ") input must match");
});
}
- function JL(n, t) {
- return t.some(function(e) {
- return e.id === n.id;
+ function isTensorInList(tensor2, tensorList) {
+ return tensorList.some(function(x) {
+ return x.id === tensor2.id;
});
}
- function bu(n) {
- var t = [], e = new Set();
- return Nm(n, t, e), t;
+ function getTensorsInContainer(result) {
+ var list = [];
+ var seen = new Set();
+ walkTensorContainer(result, list, seen);
+ return list;
}
- function Nm(n, t, e) {
- if (n == null)
- return;
- if (n instanceof K) {
- t.push(n);
+ function walkTensorContainer(container, list, seen) {
+ if (container == null) {
return;
}
- if (!ZL(n))
+ if (container instanceof Tensor) {
+ list.push(container);
return;
- var r = n;
- for (var i in r) {
- var a = r[i];
- e.has(a) || (e.add(a), Nm(a, t, e));
+ }
+ if (!isIterable(container)) {
+ return;
+ }
+ var iterable = container;
+ for (var k in iterable) {
+ var val = iterable[k];
+ if (!seen.has(val)) {
+ seen.add(val);
+ walkTensorContainer(val, list, seen);
+ }
}
}
- function ZL(n) {
- return Array.isArray(n) || typeof n == "object";
+ function isIterable(obj) {
+ return Array.isArray(obj) || typeof obj === "object";
}
- var QL = {__proto__: null, makeTypesMatch: ct, assertTypesMatch: Tm, isTensorInList: JL, getTensorsInContainer: bu};
- var _m = function() {
- function n() {
- this.registeredVariables = {}, this.nextTapeNodeId = 0, this.numBytes = 0, this.numTensors = 0, this.numStringTensors = 0, this.numDataBuffers = 0, this.gradientDepth = 0, this.kernelDepth = 0, this.scopeStack = [], this.numDataMovesStack = [], this.nextScopeId = 0, this.tensorInfo = new WeakMap(), this.profiling = false, this.activeProfile = {newBytes: 0, newTensors: 0, peakBytes: 0, kernels: [], result: null};
+ var tensor_util = {
+ __proto__: null,
+ makeTypesMatch,
+ assertTypesMatch,
+ isTensorInList,
+ getTensorsInContainer
+ };
+ /**
+ * @license
+ * Copyright 2018 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ var EngineState = function() {
+ function EngineState2() {
+ this.registeredVariables = {};
+ this.nextTapeNodeId = 0;
+ this.numBytes = 0;
+ this.numTensors = 0;
+ this.numStringTensors = 0;
+ this.numDataBuffers = 0;
+ this.gradientDepth = 0;
+ this.kernelDepth = 0;
+ this.scopeStack = [];
+ this.numDataMovesStack = [];
+ this.nextScopeId = 0;
+ this.tensorInfo = new WeakMap();
+ this.profiling = false;
+ this.activeProfile = {newBytes: 0, newTensors: 0, peakBytes: 0, kernels: [], result: null};
}
- return n.prototype.dispose = function() {
- for (var t in this.registeredVariables)
- this.registeredVariables[t].dispose();
- }, n;
- }(), nS = function() {
- function n(t) {
- this.ENV = t, this.registry = {}, this.registryFactory = {}, this.pendingBackendInitId = 0, this.state = new _m();
+ EngineState2.prototype.dispose = function() {
+ for (var variableName in this.registeredVariables) {
+ this.registeredVariables[variableName].dispose();
+ }
+ };
+ return EngineState2;
+ }();
+ var Engine = function() {
+ function Engine2(ENV2) {
+ this.ENV = ENV2;
+ this.registry = {};
+ this.registryFactory = {};
+ this.pendingBackendInitId = 0;
+ this.state = new EngineState();
}
- return n.prototype.ready = function() {
- return pe(this, void 0, void 0, function() {
- var t, e, r, i;
- return fe(this, function(a) {
- switch (a.label) {
+ Engine2.prototype.ready = function() {
+ return __awaiter(this, void 0, void 0, function() {
+ var sortedBackends, i, backendName, success;
+ return __generator(this, function(_a) {
+ switch (_a.label) {
case 0:
- if (this.pendingBackendInit != null)
+ if (this.pendingBackendInit != null) {
return [2, this.pendingBackendInit.then(function() {
})];
- if (this.backendInstance != null)
+ }
+ if (this.backendInstance != null) {
return [2];
- t = this.getSortedBackends(), e = 0, a.label = 1;
+ }
+ sortedBackends = this.getSortedBackends();
+ i = 0;
+ _a.label = 1;
case 1:
- return e < t.length ? (r = t[e], [4, this.initializeBackend(r).success]) : [3, 5];
+ if (!(i < sortedBackends.length))
+ return [3, 5];
+ backendName = sortedBackends[i];
+ return [4, this.initializeBackend(backendName).success];
case 2:
- return i = a.sent(), i ? [4, this.setBackend(r)] : [3, 4];
+ success = _a.sent();
+ if (!success)
+ return [3, 4];
+ return [4, this.setBackend(backendName)];
case 3:
- return a.sent(), [2];
+ _a.sent();
+ return [2];
case 4:
- return e++, [3, 1];
+ i++;
+ return [3, 1];
case 5:
throw new Error("Could not initialize any backends, all backend initializations failed.");
}
});
});
- }, Object.defineProperty(n.prototype, "backend", {get: function() {
- if (this.pendingBackendInit != null)
- throw new Error("Backend '" + this.backendName + "' has not yet been initialized. Make sure to await tf.ready() or await tf.setBackend() before calling other methods");
- if (this.backendInstance == null) {
- var t = this.initializeBackendsAndReturnBest(), e = t.name, r = t.asyncInit;
- if (r)
- throw new Error("The highest priority backend '" + e + "' has not yet been initialized. Make sure to await tf.ready() or await tf.setBackend() before calling other methods");
- this.setBackend(e);
- }
- return this.backendInstance;
- }, enumerable: true, configurable: true}), n.prototype.backendNames = function() {
+ };
+ Object.defineProperty(Engine2.prototype, "backend", {
+ get: function() {
+ if (this.pendingBackendInit != null) {
+ throw new Error("Backend '" + this.backendName + "' has not yet been initialized. Make sure to await tf.ready() or await tf.setBackend() before calling other methods");
+ }
+ if (this.backendInstance == null) {
+ var _a = this.initializeBackendsAndReturnBest(), name_1 = _a.name, asyncInit = _a.asyncInit;
+ if (asyncInit) {
+ throw new Error("The highest priority backend '" + name_1 + "' has not yet been initialized. Make sure to await tf.ready() or await tf.setBackend() before calling other methods");
+ }
+ this.setBackend(name_1);
+ }
+ return this.backendInstance;
+ },
+ enumerable: true,
+ configurable: true
+ });
+ Engine2.prototype.backendNames = function() {
return Object.keys(this.registryFactory);
- }, n.prototype.findBackend = function(t) {
- if (!(t in this.registry))
- if (t in this.registryFactory) {
- var e = this.initializeBackend(t).asyncInit;
- if (e)
+ };
+ Engine2.prototype.findBackend = function(backendName) {
+ if (!(backendName in this.registry)) {
+ if (backendName in this.registryFactory) {
+ var asyncInit = this.initializeBackend(backendName).asyncInit;
+ if (asyncInit) {
return null;
- } else
+ }
+ } else {
return null;
- return this.registry[t];
- }, n.prototype.findBackendFactory = function(t) {
- return t in this.registryFactory ? this.registryFactory[t].factory : null;
- }, n.prototype.registerBackend = function(t, e, r) {
- return r === void 0 && (r = 1), t in this.registryFactory ? (console.warn(t + " backend was already registered. Reusing existing backend factory."), false) : (this.registryFactory[t] = {factory: e, priority: r}, true);
- }, n.prototype.setBackend = function(t) {
- return pe(this, void 0, void 0, function() {
- var e, r, i, a, s;
- return fe(this, function(o) {
- switch (o.label) {
+ }
+ }
+ return this.registry[backendName];
+ };
+ Engine2.prototype.findBackendFactory = function(backendName) {
+ if (!(backendName in this.registryFactory)) {
+ return null;
+ }
+ return this.registryFactory[backendName].factory;
+ };
+ Engine2.prototype.registerBackend = function(backendName, factory, priority) {
+ if (priority === void 0) {
+ priority = 1;
+ }
+ if (backendName in this.registryFactory) {
+ console.warn(backendName + " backend was already registered. Reusing existing backend factory.");
+ return false;
+ }
+ this.registryFactory[backendName] = {factory, priority};
+ return true;
+ };
+ Engine2.prototype.setBackend = function(backendName) {
+ return __awaiter(this, void 0, void 0, function() {
+ var _a, success, asyncInit, result, _b;
+ return __generator(this, function(_c) {
+ switch (_c.label) {
case 0:
- if (this.registryFactory[t] == null)
- throw new Error("Backend name '" + t + "' not found in registry");
- return this.backendName = t, this.registry[t] == null ? (this.backendInstance = null, e = this.initializeBackend(t), r = e.success, i = e.asyncInit, i ? [4, r] : [3, 2]) : [3, 4];
+ if (this.registryFactory[backendName] == null) {
+ throw new Error("Backend name '" + backendName + "' not found in registry");
+ }
+ this.backendName = backendName;
+ if (!(this.registry[backendName] == null))
+ return [3, 4];
+ this.backendInstance = null;
+ _a = this.initializeBackend(backendName), success = _a.success, asyncInit = _a.asyncInit;
+ if (!asyncInit)
+ return [3, 2];
+ return [4, success];
case 1:
- return s = o.sent(), [3, 3];
+ _b = _c.sent();
+ return [3, 3];
case 2:
- s = r, o.label = 3;
+ _b = success;
+ _c.label = 3;
case 3:
- if (a = s, !a)
+ result = _b;
+ if (!result) {
return [2, false];
- o.label = 4;
+ }
+ _c.label = 4;
case 4:
- return this.backendInstance = this.registry[t], this.setupRegisteredKernels(), this.profiler = new HL(this.backendInstance), [2, true];
+ this.backendInstance = this.registry[backendName];
+ this.setupRegisteredKernels();
+ this.profiler = new Profiler(this.backendInstance);
+ return [2, true];
}
});
});
- }, n.prototype.setupRegisteredKernels = function() {
- var t = this, e = Os(this.backendName);
- e.forEach(function(r) {
- r.setupFunc != null && r.setupFunc(t.backendInstance);
+ };
+ Engine2.prototype.setupRegisteredKernels = function() {
+ var _this = this;
+ var kernels = getKernelsForBackend(this.backendName);
+ kernels.forEach(function(kernel) {
+ if (kernel.setupFunc != null) {
+ kernel.setupFunc(_this.backendInstance);
+ }
});
- }, n.prototype.disposeRegisteredKernels = function(t) {
- var e = this, r = Os(t);
- r.forEach(function(i) {
- i.disposeFunc != null && i.disposeFunc(e.registry[t]);
+ };
+ Engine2.prototype.disposeRegisteredKernels = function(backendName) {
+ var _this = this;
+ var kernels = getKernelsForBackend(backendName);
+ kernels.forEach(function(kernel) {
+ if (kernel.disposeFunc != null) {
+ kernel.disposeFunc(_this.registry[backendName]);
+ }
});
- }, n.prototype.initializeBackend = function(t) {
- var e = this, r = this.registryFactory[t];
- if (r == null)
- throw new Error("Cannot initialize backend " + t + ", no registration found.");
- try {
- var i = r.factory();
- if (i && !(i instanceof cf) && typeof i.then == "function") {
- var a = ++this.pendingBackendInitId, s = i.then(function(o) {
- return a < e.pendingBackendInitId ? false : (e.registry[t] = o, e.pendingBackendInit = null, true);
- }).catch(function(o) {
- return a < e.pendingBackendInitId || (e.pendingBackendInit = null, console.warn("Initialization of backend " + t + " failed"), console.warn(o.stack || o.message)), false;
- });
- return this.pendingBackendInit = s, {success: s, asyncInit: true};
- } else
- return this.registry[t] = i, {success: true, asyncInit: false};
- } catch (o) {
- return console.warn("Initialization of backend " + t + " failed"), console.warn(o.stack || o.message), {success: false, asyncInit: false};
+ };
+ Engine2.prototype.initializeBackend = function(backendName) {
+ var _this = this;
+ var registryFactoryEntry = this.registryFactory[backendName];
+ if (registryFactoryEntry == null) {
+ throw new Error("Cannot initialize backend " + backendName + ", no registration found.");
}
- }, n.prototype.removeBackend = function(t) {
- if (!(t in this.registryFactory))
- throw new Error(t + " backend not found in registry");
- this.backendName === t && this.pendingBackendInit != null && this.pendingBackendInitId++, t in this.registry && (this.disposeRegisteredKernels(t), this.registry[t].dispose(), delete this.registry[t]), delete this.registryFactory[t], this.backendName === t && (this.pendingBackendInit = null, this.backendName = null, this.backendInstance = null);
- }, n.prototype.getSortedBackends = function() {
- var t = this;
- if (Object.keys(this.registryFactory).length === 0)
+ try {
+ var backend2 = registryFactoryEntry.factory();
+ if (backend2 && !(backend2 instanceof KernelBackend) && typeof backend2.then === "function") {
+ var promiseId_1 = ++this.pendingBackendInitId;
+ var success = backend2.then(function(backendInstance) {
+ if (promiseId_1 < _this.pendingBackendInitId) {
+ return false;
+ }
+ _this.registry[backendName] = backendInstance;
+ _this.pendingBackendInit = null;
+ return true;
+ }).catch(function(err) {
+ if (promiseId_1 < _this.pendingBackendInitId) {
+ return false;
+ }
+ _this.pendingBackendInit = null;
+ console.warn("Initialization of backend " + backendName + " failed");
+ console.warn(err.stack || err.message);
+ return false;
+ });
+ this.pendingBackendInit = success;
+ return {success, asyncInit: true};
+ } else {
+ this.registry[backendName] = backend2;
+ return {success: true, asyncInit: false};
+ }
+ } catch (err) {
+ console.warn("Initialization of backend " + backendName + " failed");
+ console.warn(err.stack || err.message);
+ return {success: false, asyncInit: false};
+ }
+ };
+ Engine2.prototype.removeBackend = function(backendName) {
+ if (!(backendName in this.registryFactory)) {
+ throw new Error(backendName + " backend not found in registry");
+ }
+ if (this.backendName === backendName && this.pendingBackendInit != null) {
+ this.pendingBackendInitId++;
+ }
+ if (backendName in this.registry) {
+ this.disposeRegisteredKernels(backendName);
+ this.registry[backendName].dispose();
+ delete this.registry[backendName];
+ }
+ delete this.registryFactory[backendName];
+ if (this.backendName === backendName) {
+ this.pendingBackendInit = null;
+ this.backendName = null;
+ this.backendInstance = null;
+ }
+ };
+ Engine2.prototype.getSortedBackends = function() {
+ var _this = this;
+ if (Object.keys(this.registryFactory).length === 0) {
throw new Error("No backend found in registry.");
- return Object.keys(this.registryFactory).sort(function(e, r) {
- return t.registryFactory[r].priority - t.registryFactory[e].priority;
+ }
+ return Object.keys(this.registryFactory).sort(function(a, b) {
+ return _this.registryFactory[b].priority - _this.registryFactory[a].priority;
});
- }, n.prototype.initializeBackendsAndReturnBest = function() {
- for (var t = this.getSortedBackends(), e = 0; e < t.length; e++) {
- var r = t[e], i = this.initializeBackend(r), a = i.success, s = i.asyncInit;
- if (s || a)
- return {name: r, asyncInit: s};
+ };
+ Engine2.prototype.initializeBackendsAndReturnBest = function() {
+ var sortedBackends = this.getSortedBackends();
+ for (var i = 0; i < sortedBackends.length; i++) {
+ var backendName = sortedBackends[i];
+ var _a = this.initializeBackend(backendName), success = _a.success, asyncInit = _a.asyncInit;
+ if (asyncInit || success) {
+ return {name: backendName, asyncInit};
+ }
}
throw new Error("Could not initialize any backends, all backend initializations failed.");
- }, n.prototype.moveData = function(t, e) {
- var r = this.state.tensorInfo.get(e), i = r.backend, a = this.readSync(e);
- i.disposeData(e), r.backend = t, t.move(e, a, r.shape, r.dtype), this.shouldCheckForMemLeaks() && this.state.numDataMovesStack[this.state.numDataMovesStack.length - 1]++;
- }, n.prototype.tidy = function(t, e) {
- var r = this, i = null;
- if (e == null) {
- if (typeof t != "function")
+ };
+ Engine2.prototype.moveData = function(backend2, dataId) {
+ var info = this.state.tensorInfo.get(dataId);
+ var srcBackend = info.backend;
+ var values = this.readSync(dataId);
+ srcBackend.disposeData(dataId);
+ info.backend = backend2;
+ backend2.move(dataId, values, info.shape, info.dtype);
+ if (this.shouldCheckForMemLeaks()) {
+ this.state.numDataMovesStack[this.state.numDataMovesStack.length - 1]++;
+ }
+ };
+ Engine2.prototype.tidy = function(nameOrFn, fn) {
+ var _this = this;
+ var name = null;
+ if (fn == null) {
+ if (typeof nameOrFn !== "function") {
throw new Error("Please provide a function to tidy()");
- e = t;
+ }
+ fn = nameOrFn;
} else {
- if (typeof t != "string" && !(t instanceof String))
+ if (typeof nameOrFn !== "string" && !(nameOrFn instanceof String)) {
throw new Error("When calling with two arguments, the first argument to tidy() must be a string");
- if (typeof e != "function")
+ }
+ if (typeof fn !== "function") {
throw new Error("When calling with two arguments, the 2nd argument to tidy() must be a function");
- i = t;
+ }
+ name = nameOrFn;
}
- var a;
+ var result;
return this.scopedRun(function() {
- return r.startScope(i);
+ return _this.startScope(name);
}, function() {
- return r.endScope(a);
+ return _this.endScope(result);
}, function() {
- return a = e(), a instanceof Promise && console.error("Cannot return a Promise inside of tidy."), a;
+ result = fn();
+ if (result instanceof Promise) {
+ console.error("Cannot return a Promise inside of tidy.");
+ }
+ return result;
});
- }, n.prototype.scopedRun = function(t, e, r) {
- t();
+ };
+ Engine2.prototype.scopedRun = function(start, end, f) {
+ start();
try {
- var i = r();
- return e(), i;
- } catch (a) {
- throw e(), a;
+ var res = f();
+ end();
+ return res;
+ } catch (ex) {
+ end();
+ throw ex;
}
- }, n.prototype.nextTensorId = function() {
- return n.nextTensorId++;
- }, n.prototype.nextVariableId = function() {
- return n.nextVariableId++;
- }, n.prototype.clone = function(t) {
- var e = this.makeTensorFromDataId(t.dataId, t.shape, t.dtype), r = {x: t}, i = function(s) {
- return {x: function() {
- var o = "float32", c = {x: s}, l = {dtype: o};
- return z.runKernelFunc(function(u) {
- return u.cast(s, o);
- }, c, null, Rs, l);
- }};
- }, a = [];
- return this.addTapeNode(this.state.activeScope.name, r, [e], i, a, {}), e;
- }, n.prototype.runKernel = function(t, e, r, i, a) {
- var s = null, o = null;
- return this.runKernelFunc(s, e, o, t, r, i, a);
- }, n.prototype.shouldCheckForMemLeaks = function() {
- return this.ENV.getBool("IS_TEST");
- }, n.prototype.checkKernelForMemLeak = function(t, e, r) {
- var i = this.backend.numDataIds(), a = 0;
- r.forEach(function(c) {
- a += c.dtype === "complex64" ? 3 : 1;
- });
- var s = this.state.numDataMovesStack[this.state.numDataMovesStack.length - 1], o = i - e - a - s;
- if (o > 0)
- throw new Error("Backend '" + this.backendName + "' has an internal memory leak " + ("(" + o + " data ids) after running '" + t + "'"));
- }, n.prototype.runKernelFunc = function(t, e, r, i, a, s, o) {
- var c = this, l, u = [], h = this.isTapeOn();
- i == null && (i = this.state.activeScope != null ? this.state.activeScope.name : "");
- var d = this.state.numBytes, p = this.state.numTensors;
- this.shouldCheckForMemLeaks() && this.state.numDataMovesStack.push(0);
- var f, m = uu(i, this.backendName), g;
- if (m != null)
- f = function() {
- var b = c.backend.numDataIds();
- g = m.kernelFunc({inputs: e, attrs: a, backend: c.backend});
- var L = Array.isArray(g) ? g : [g];
- c.shouldCheckForMemLeaks() && c.checkKernelForMemLeak(i, b, L);
- var x = L.map(function(C) {
- var O = C.dataId, D = C.shape, F = C.dtype;
- return c.makeTensorFromDataId(O, D, F);
- });
- if (h) {
- var N = c.getTensorsForGradient(i, e, x);
- if (N == null) {
- o == null && (o = []);
- var I = x.filter(function(C, O) {
- return o[O];
- });
- N = (s || []).slice().concat(I);
- }
- u = c.saveTensorsForBackwardMode(N);
+ };
+ Engine2.prototype.nextTensorId = function() {
+ return Engine2.nextTensorId++;
+ };
+ Engine2.prototype.nextVariableId = function() {
+ return Engine2.nextVariableId++;
+ };
+ Engine2.prototype.clone = function(x) {
+ var y = this.makeTensorFromDataId(x.dataId, x.shape, x.dtype);
+ var inputs = {x};
+ var grad2 = function(dy) {
+ return {
+ x: function() {
+ var dtype = "float32";
+ var gradInputs = {x: dy};
+ var attrs = {dtype};
+ return ENGINE.runKernelFunc(function(backend2) {
+ return backend2.cast(dy, dtype);
+ }, gradInputs, null, Cast, attrs);
}
- return x;
};
- else {
- var y = function(b) {
- if (!h)
+ };
+ var saved = [];
+ this.addTapeNode(this.state.activeScope.name, inputs, [y], grad2, saved, {});
+ return y;
+ };
+ Engine2.prototype.runKernel = function(kernelName, inputs, attrs, inputsToSave, outputsToSave) {
+ var forwardFunc = null;
+ var backwardsFunc = null;
+ return this.runKernelFunc(forwardFunc, inputs, backwardsFunc, kernelName, attrs, inputsToSave, outputsToSave);
+ };
+ Engine2.prototype.shouldCheckForMemLeaks = function() {
+ return this.ENV.getBool("IS_TEST");
+ };
+ Engine2.prototype.checkKernelForMemLeak = function(kernelName, numDataIdsBefore, outInfos) {
+ var numDataIdsAfter = this.backend.numDataIds();
+ var numOutputDataIds = 0;
+ outInfos.forEach(function(info) {
+ numOutputDataIds += info.dtype === "complex64" ? 3 : 1;
+ });
+ var numMoves = this.state.numDataMovesStack[this.state.numDataMovesStack.length - 1];
+ var dataIdsLeaked = numDataIdsAfter - numDataIdsBefore - numOutputDataIds - numMoves;
+ if (dataIdsLeaked > 0) {
+ throw new Error("Backend '" + this.backendName + "' has an internal memory leak " + ("(" + dataIdsLeaked + " data ids) after running '" + kernelName + "'"));
+ }
+ };
+ Engine2.prototype.runKernelFunc = function(forwardFunc, inputs, backwardsFunc, kernelName, attrs, inputsToSave, outputsToSave) {
+ var _this = this;
+ var outputs;
+ var saved = [];
+ var isTapeOn = this.isTapeOn();
+ if (kernelName == null) {
+ kernelName = this.state.activeScope != null ? this.state.activeScope.name : "";
+ }
+ var startingBytecount = this.state.numBytes;
+ var startingNumTensors = this.state.numTensors;
+ if (this.shouldCheckForMemLeaks()) {
+ this.state.numDataMovesStack.push(0);
+ }
+ var kernelFunc;
+ var kernel = getKernel(kernelName, this.backendName);
+ var out;
+ if (kernel != null) {
+ kernelFunc = function() {
+ var numDataIdsBefore = _this.backend.numDataIds();
+ out = kernel.kernelFunc({inputs, attrs, backend: _this.backend});
+ var outInfos = Array.isArray(out) ? out : [out];
+ if (_this.shouldCheckForMemLeaks()) {
+ _this.checkKernelForMemLeak(kernelName, numDataIdsBefore, outInfos);
+ }
+ var outTensors = outInfos.map(function(_a) {
+ var dataId = _a.dataId, shape = _a.shape, dtype = _a.dtype;
+ return _this.makeTensorFromDataId(dataId, shape, dtype);
+ });
+ if (isTapeOn) {
+ var tensorsToSave = _this.getTensorsForGradient(kernelName, inputs, outTensors);
+ if (tensorsToSave == null) {
+ if (outputsToSave == null) {
+ outputsToSave = [];
+ }
+ var outsToSave = outTensors.filter(function(_, i) {
+ return outputsToSave[i];
+ });
+ tensorsToSave = (inputsToSave || []).slice().concat(outsToSave);
+ }
+ saved = _this.saveTensorsForBackwardMode(tensorsToSave);
+ }
+ return outTensors;
+ };
+ } else {
+ var saveFunc_1 = function(tensors) {
+ if (!isTapeOn) {
return;
- u = b.map(function(L) {
- return c.keep(c.clone(L));
+ }
+ saved = tensors.map(function(tensor2) {
+ return _this.keep(_this.clone(tensor2));
});
};
- f = function() {
- var b = c.backend.numDataIds();
- g = c.tidy(function() {
- return t(c.backend, y);
+ kernelFunc = function() {
+ var numDataIdsBefore = _this.backend.numDataIds();
+ out = _this.tidy(function() {
+ return forwardFunc(_this.backend, saveFunc_1);
});
- var L = Array.isArray(g) ? g : [g];
- return c.shouldCheckForMemLeaks() && c.checkKernelForMemLeak(i, b, L), L;
+ var outs = Array.isArray(out) ? out : [out];
+ if (_this.shouldCheckForMemLeaks()) {
+ _this.checkKernelForMemLeak(kernelName, numDataIdsBefore, outs);
+ }
+ return outs;
};
}
- var w;
- return this.scopedRun(function() {
- return c.state.kernelDepth++;
+ var kernelProfile;
+ this.scopedRun(function() {
+ return _this.state.kernelDepth++;
}, function() {
- return c.state.kernelDepth--;
+ return _this.state.kernelDepth--;
}, function() {
- !c.ENV.getBool("DEBUG") && !c.state.profiling ? l = f() : (w = c.profiler.profileKernel(i, e, function() {
- return f();
- }), c.ENV.getBool("DEBUG") && c.profiler.logKernelProfile(w), l = w.outputs);
- }), h && this.addTapeNode(i, e, l, r, u, a), this.state.profiling && this.state.activeProfile.kernels.push({name: i, bytesAdded: this.state.numBytes - d, totalBytesSnapshot: this.state.numBytes, tensorsAdded: this.state.numTensors - p, totalTensorsSnapshot: this.state.numTensors, inputShapes: Object.keys(e).map(function(b) {
- return e[b] != null ? e[b].shape : null;
- }), outputShapes: l.map(function(b) {
- return b.shape;
- }), kernelTimeMs: w.timeMs, extraInfo: w.extraInfo}), Array.isArray(g) ? l : l[0];
- }, n.prototype.saveTensorsForBackwardMode = function(t) {
- var e = this, r = t.map(function(i) {
- return e.keep(e.clone(i));
+ if (!_this.ENV.getBool("DEBUG") && !_this.state.profiling) {
+ outputs = kernelFunc();
+ } else {
+ kernelProfile = _this.profiler.profileKernel(kernelName, inputs, function() {
+ return kernelFunc();
+ });
+ if (_this.ENV.getBool("DEBUG")) {
+ _this.profiler.logKernelProfile(kernelProfile);
+ }
+ outputs = kernelProfile.outputs;
+ }
});
- return r;
- }, n.prototype.getTensorsForGradient = function(t, e, r) {
- var i = hu(t);
- if (i != null) {
- var a = i.inputsToSave || [], s = i.outputsToSave || [], o = void 0;
- i.saveAllInputs ? (E(Array.isArray(e), function() {
- return "saveAllInputs is true, expected inputs to be an array.";
- }), o = Object.keys(e).map(function(l) {
- return e[l];
- })) : o = a.map(function(l) {
- return e[l];
+ if (isTapeOn) {
+ this.addTapeNode(kernelName, inputs, outputs, backwardsFunc, saved, attrs);
+ }
+ if (this.state.profiling) {
+ this.state.activeProfile.kernels.push({
+ name: kernelName,
+ bytesAdded: this.state.numBytes - startingBytecount,
+ totalBytesSnapshot: this.state.numBytes,
+ tensorsAdded: this.state.numTensors - startingNumTensors,
+ totalTensorsSnapshot: this.state.numTensors,
+ inputShapes: Object.keys(inputs).map(function(key) {
+ return inputs[key] != null ? inputs[key].shape : null;
+ }),
+ outputShapes: outputs.map(function(item) {
+ return item.shape;
+ }),
+ kernelTimeMs: kernelProfile.timeMs,
+ extraInfo: kernelProfile.extraInfo
});
- var c = r.filter(function(l, u) {
- return s[u];
+ }
+ return Array.isArray(out) ? outputs : outputs[0];
+ };
+ Engine2.prototype.saveTensorsForBackwardMode = function(tensors) {
+ var _this = this;
+ var saved = tensors.map(function(tensor2) {
+ return _this.keep(_this.clone(tensor2));
+ });
+ return saved;
+ };
+ Engine2.prototype.getTensorsForGradient = function(kernelName, inputs, outputs) {
+ var gradConfig = getGradient(kernelName);
+ if (gradConfig != null) {
+ var inputsToSave = gradConfig.inputsToSave || [];
+ var outputsToSave_1 = gradConfig.outputsToSave || [];
+ var inputTensorsToSave = void 0;
+ if (gradConfig.saveAllInputs) {
+ assert(Array.isArray(inputs), function() {
+ return "saveAllInputs is true, expected inputs to be an array.";
+ });
+ inputTensorsToSave = Object.keys(inputs).map(function(key) {
+ return inputs[key];
+ });
+ } else {
+ inputTensorsToSave = inputsToSave.map(function(inputName) {
+ return inputs[inputName];
+ });
+ }
+ var outputTensorsToSave = outputs.filter(function(_, i) {
+ return outputsToSave_1[i];
});
- return o.concat(c);
+ return inputTensorsToSave.concat(outputTensorsToSave);
}
return null;
- }, n.prototype.makeTensor = function(t, e, r, i) {
- if (t == null)
+ };
+ Engine2.prototype.makeTensor = function(values, shape, dtype, backend2) {
+ if (values == null) {
throw new Error("Values passed to engine.makeTensor() are null");
- r = r || "float32", i = i || this.backend;
- var a = t;
- r === "string" && or(t[0]) && (a = t.map(function(u) {
- return du(u);
- }));
- var s = i.write(a, e, r), o = new K(e, r, s, this.nextTensorId());
- if (this.incRef(o, i), r === "string") {
- var c = this.state.tensorInfo.get(s), l = yf(a);
- this.state.numBytes += l - c.bytes, c.bytes = l;
}
- return o;
- }, n.prototype.makeTensorFromDataId = function(t, e, r, i) {
- r = r || "float32";
- var a = new K(e, r, t, this.nextTensorId());
- return this.incRef(a, i), a;
- }, n.prototype.makeVariable = function(t, e, r, i) {
- e === void 0 && (e = true), r = r || this.nextVariableId().toString(), i != null && i !== t.dtype && (t = t.cast(i));
- var a = new La(t, e, r, this.nextTensorId());
- if (this.state.registeredVariables[a.name] != null)
- throw new Error("Variable with name " + a.name + " was already registered");
- return this.state.registeredVariables[a.name] = a, this.incRef(a, this.backend), a;
- }, n.prototype.incRef = function(t, e) {
- var r = this.state.tensorInfo.has(t.dataId) ? this.state.tensorInfo.get(t.dataId).refCount : 0;
- if (this.state.numTensors++, t.dtype === "string" && this.state.numStringTensors++, r === 0) {
+ dtype = dtype || "float32";
+ backend2 = backend2 || this.backend;
+ var backendVals = values;
+ if (dtype === "string" && isString(values[0])) {
+ backendVals = values.map(function(d) {
+ return encodeString(d);
+ });
+ }
+ var dataId = backend2.write(backendVals, shape, dtype);
+ var t = new Tensor(shape, dtype, dataId, this.nextTensorId());
+ this.incRef(t, backend2);
+ if (dtype === "string") {
+ var info = this.state.tensorInfo.get(dataId);
+ var newBytes = bytesFromStringArray(backendVals);
+ this.state.numBytes += newBytes - info.bytes;
+ info.bytes = newBytes;
+ }
+ return t;
+ };
+ Engine2.prototype.makeTensorFromDataId = function(dataId, shape, dtype, backend2) {
+ dtype = dtype || "float32";
+ var t = new Tensor(shape, dtype, dataId, this.nextTensorId());
+ this.incRef(t, backend2);
+ return t;
+ };
+ Engine2.prototype.makeVariable = function(initialValue, trainable, name, dtype) {
+ if (trainable === void 0) {
+ trainable = true;
+ }
+ name = name || this.nextVariableId().toString();
+ if (dtype != null && dtype !== initialValue.dtype) {
+ initialValue = initialValue.cast(dtype);
+ }
+ var v = new Variable(initialValue, trainable, name, this.nextTensorId());
+ if (this.state.registeredVariables[v.name] != null) {
+ throw new Error("Variable with name " + v.name + " was already registered");
+ }
+ this.state.registeredVariables[v.name] = v;
+ this.incRef(v, this.backend);
+ return v;
+ };
+ Engine2.prototype.incRef = function(a, backend2) {
+ var refCount = this.state.tensorInfo.has(a.dataId) ? this.state.tensorInfo.get(a.dataId).refCount : 0;
+ this.state.numTensors++;
+ if (a.dtype === "string") {
+ this.state.numStringTensors++;
+ }
+ if (refCount === 0) {
this.state.numDataBuffers++;
- var i = 0;
- t.dtype !== "complex64" && t.dtype !== "string" && (i = t.size * gf(t.dtype)), this.state.tensorInfo.set(t.dataId, {backend: e || this.backend, dtype: t.dtype, shape: t.shape, bytes: i, refCount: 0}), this.state.numBytes += i;
+ var bytes = 0;
+ if (a.dtype !== "complex64" && a.dtype !== "string") {
+ bytes = a.size * bytesPerElement(a.dtype);
+ }
+ this.state.tensorInfo.set(a.dataId, {
+ backend: backend2 || this.backend,
+ dtype: a.dtype,
+ shape: a.shape,
+ bytes,
+ refCount: 0
+ });
+ this.state.numBytes += bytes;
}
- this.state.tensorInfo.get(t.dataId).refCount++, t instanceof La || this.track(t);
- }, n.prototype.disposeTensor = function(t) {
- if (!this.state.tensorInfo.has(t.dataId))
+ this.state.tensorInfo.get(a.dataId).refCount++;
+ if (!(a instanceof Variable)) {
+ this.track(a);
+ }
+ };
+ Engine2.prototype.disposeTensor = function(a) {
+ if (!this.state.tensorInfo.has(a.dataId)) {
return;
- this.state.numTensors--, t.dtype === "string" && this.state.numStringTensors--;
- var e = this.state.tensorInfo.get(t.dataId), r = e.refCount;
- r <= 1 ? (t.dtype !== "complex64" && (this.state.numBytes -= e.bytes), this.state.numDataBuffers--, e.backend.disposeData(t.dataId), this.state.tensorInfo.delete(t.dataId)) : this.state.tensorInfo.get(t.dataId).refCount--;
- }, n.prototype.disposeVariables = function() {
- for (var t in this.state.registeredVariables) {
- var e = this.state.registeredVariables[t];
- this.disposeVariable(e);
}
- }, n.prototype.disposeVariable = function(t) {
- this.disposeTensor(t), this.state.registeredVariables[t.name] != null && delete this.state.registeredVariables[t.name];
- }, n.prototype.memory = function() {
- var t = this.backend.memory();
- return t.numTensors = this.state.numTensors, t.numDataBuffers = this.state.numDataBuffers, t.numBytes = this.state.numBytes, this.state.numStringTensors > 0 && (t.unreliable = true, t.reasons == null && (t.reasons = []), t.reasons.push("Memory usage by string tensors is approximate (2 bytes per character)")), t;
- }, n.prototype.profile = function(t) {
- return pe(this, void 0, void 0, function() {
- var e, r, i, a, s, o, c, l;
- return fe(this, function(u) {
- switch (u.label) {
+ this.state.numTensors--;
+ if (a.dtype === "string") {
+ this.state.numStringTensors--;
+ }
+ var info = this.state.tensorInfo.get(a.dataId);
+ var refCount = info.refCount;
+ if (refCount <= 1) {
+ if (a.dtype !== "complex64") {
+ this.state.numBytes -= info.bytes;
+ }
+ this.state.numDataBuffers--;
+ info.backend.disposeData(a.dataId);
+ this.state.tensorInfo.delete(a.dataId);
+ } else {
+ this.state.tensorInfo.get(a.dataId).refCount--;
+ }
+ };
+ Engine2.prototype.disposeVariables = function() {
+ for (var varName in this.state.registeredVariables) {
+ var v = this.state.registeredVariables[varName];
+ this.disposeVariable(v);
+ }
+ };
+ Engine2.prototype.disposeVariable = function(v) {
+ this.disposeTensor(v);
+ if (this.state.registeredVariables[v.name] != null) {
+ delete this.state.registeredVariables[v.name];
+ }
+ };
+ Engine2.prototype.memory = function() {
+ var info = this.backend.memory();
+ info.numTensors = this.state.numTensors;
+ info.numDataBuffers = this.state.numDataBuffers;
+ info.numBytes = this.state.numBytes;
+ if (this.state.numStringTensors > 0) {
+ info.unreliable = true;
+ if (info.reasons == null) {
+ info.reasons = [];
+ }
+ info.reasons.push("Memory usage by string tensors is approximate (2 bytes per character)");
+ }
+ return info;
+ };
+ Engine2.prototype.profile = function(query) {
+ return __awaiter(this, void 0, void 0, function() {
+ var startBytes, startNumTensors, _a, _i2, _b, kernel, _c, _d;
+ return __generator(this, function(_e) {
+ switch (_e.label) {
case 0:
- return this.state.profiling = true, e = this.state.numBytes, r = this.state.numTensors, this.state.activeProfile.kernels = [], i = this.state.activeProfile, [4, t()];
+ this.state.profiling = true;
+ startBytes = this.state.numBytes;
+ startNumTensors = this.state.numTensors;
+ this.state.activeProfile.kernels = [];
+ _a = this.state.activeProfile;
+ return [4, query()];
case 1:
- i.result = u.sent(), this.state.profiling = false, this.state.activeProfile.peakBytes = Math.max.apply(Math, this.state.activeProfile.kernels.map(function(h) {
- return h.totalBytesSnapshot;
- })), this.state.activeProfile.newBytes = this.state.numBytes - e, this.state.activeProfile.newTensors = this.state.numTensors - r, a = 0, s = this.state.activeProfile.kernels, u.label = 2;
+ _a.result = _e.sent();
+ this.state.profiling = false;
+ this.state.activeProfile.peakBytes = Math.max.apply(Math, this.state.activeProfile.kernels.map(function(d) {
+ return d.totalBytesSnapshot;
+ }));
+ this.state.activeProfile.newBytes = this.state.numBytes - startBytes;
+ this.state.activeProfile.newTensors = this.state.numTensors - startNumTensors;
+ _i2 = 0, _b = this.state.activeProfile.kernels;
+ _e.label = 2;
case 2:
- return a < s.length ? (o = s[a], c = o, [4, o.kernelTimeMs]) : [3, 6];
+ if (!(_i2 < _b.length))
+ return [3, 6];
+ kernel = _b[_i2];
+ _c = kernel;
+ return [4, kernel.kernelTimeMs];
case 3:
- return c.kernelTimeMs = u.sent(), l = o, [4, o.extraInfo];
+ _c.kernelTimeMs = _e.sent();
+ _d = kernel;
+ return [4, kernel.extraInfo];
case 4:
- l.extraInfo = u.sent(), u.label = 5;
+ _d.extraInfo = _e.sent();
+ _e.label = 5;
case 5:
- return a++, [3, 2];
+ _i2++;
+ return [3, 2];
case 6:
return [2, this.state.activeProfile];
}
});
});
- }, n.prototype.isTapeOn = function() {
+ };
+ Engine2.prototype.isTapeOn = function() {
return this.state.gradientDepth > 0 && this.state.kernelDepth === 0;
- }, n.prototype.addTapeNode = function(t, e, r, i, a, s) {
- var o = this, c = {id: this.state.nextTapeNodeId++, kernelName: t, inputs: e, outputs: r, saved: a}, l = hu(t);
- l != null && (i = l.gradFunc), i != null && (c.gradient = function(u) {
- return u = u.map(function(h, d) {
- if (h == null) {
- var p = r[d], f = Li(p.size, p.dtype);
- return o.makeTensor(f, p.shape, p.dtype);
- }
- return h;
- }), i(u.length > 1 ? u : u[0], a, s);
- }), this.state.activeTape.push(c);
- }, n.prototype.keep = function(t) {
- return t.kept = true, t;
- }, n.prototype.startTape = function() {
- this.state.gradientDepth === 0 && (this.state.activeTape = []), this.state.gradientDepth++;
- }, n.prototype.endTape = function() {
- this.state.gradientDepth--;
- }, n.prototype.startScope = function(t) {
- var e = {track: [], name: "unnamed scope", id: this.state.nextScopeId++};
- t && (e.name = t), this.state.scopeStack.push(e), this.state.activeScope = e;
- }, n.prototype.endScope = function(t) {
- for (var e = this, r = bu(t), i = new Set(r.map(function(c) {
- return c.id;
- })), a = 0; a < this.state.activeScope.track.length; a++) {
- var s = this.state.activeScope.track[a];
- !s.kept && !i.has(s.id) && s.dispose();
+ };
+ Engine2.prototype.addTapeNode = function(kernelName, inputs, outputs, gradientsFunc, saved, attrs) {
+ var _this = this;
+ var tapeNode = {id: this.state.nextTapeNodeId++, kernelName, inputs, outputs, saved};
+ var gradConfig = getGradient(kernelName);
+ if (gradConfig != null) {
+ gradientsFunc = gradConfig.gradFunc;
}
- var o = this.state.scopeStack.pop();
- this.state.activeScope = this.state.scopeStack.length === 0 ? null : this.state.scopeStack[this.state.scopeStack.length - 1], r.forEach(function(c) {
- !c.kept && c.scopeId === o.id && e.track(c);
+ if (gradientsFunc != null) {
+ tapeNode.gradient = function(dys) {
+ dys = dys.map(function(dy, i) {
+ if (dy == null) {
+ var output = outputs[i];
+ var vals = makeZerosTypedArray(output.size, output.dtype);
+ return _this.makeTensor(vals, output.shape, output.dtype);
+ }
+ return dy;
+ });
+ return gradientsFunc(dys.length > 1 ? dys : dys[0], saved, attrs);
+ };
+ }
+ this.state.activeTape.push(tapeNode);
+ };
+ Engine2.prototype.keep = function(result) {
+ result.kept = true;
+ return result;
+ };
+ Engine2.prototype.startTape = function() {
+ if (this.state.gradientDepth === 0) {
+ this.state.activeTape = [];
+ }
+ this.state.gradientDepth++;
+ };
+ Engine2.prototype.endTape = function() {
+ this.state.gradientDepth--;
+ };
+ Engine2.prototype.startScope = function(name) {
+ var scopeInfo = {
+ track: [],
+ name: "unnamed scope",
+ id: this.state.nextScopeId++
+ };
+ if (name) {
+ scopeInfo.name = name;
+ }
+ this.state.scopeStack.push(scopeInfo);
+ this.state.activeScope = scopeInfo;
+ };
+ Engine2.prototype.endScope = function(result) {
+ var _this = this;
+ var tensorsToTrackInParent = getTensorsInContainer(result);
+ var tensorsToTrackInParentSet = new Set(tensorsToTrackInParent.map(function(t) {
+ return t.id;
+ }));
+ for (var i = 0; i < this.state.activeScope.track.length; i++) {
+ var tensor2 = this.state.activeScope.track[i];
+ if (!tensor2.kept && !tensorsToTrackInParentSet.has(tensor2.id)) {
+ tensor2.dispose();
+ }
+ }
+ var oldScope = this.state.scopeStack.pop();
+ this.state.activeScope = this.state.scopeStack.length === 0 ? null : this.state.scopeStack[this.state.scopeStack.length - 1];
+ tensorsToTrackInParent.forEach(function(tensor3) {
+ if (!tensor3.kept && tensor3.scopeId === oldScope.id) {
+ _this.track(tensor3);
+ }
});
- }, n.prototype.gradients = function(t, e, r, i) {
- var a = this;
- if (i === void 0 && (i = false), E(e.length > 0, function() {
+ };
+ Engine2.prototype.gradients = function(f, xs, dy, allowNoGradients) {
+ var _this = this;
+ if (allowNoGradients === void 0) {
+ allowNoGradients = false;
+ }
+ assert(xs.length > 0, function() {
return "gradients() received an empty list of xs.";
- }), r != null && r.dtype !== "float32")
- throw new Error("dy must have 'float32' dtype, but has '" + r.dtype + "'");
- var s = this.scopedRun(function() {
- return a.startTape();
- }, function() {
- return a.endTape();
- }, function() {
- return a.tidy("forward", t);
});
- E(s instanceof K, function() {
+ if (dy != null && dy.dtype !== "float32") {
+ throw new Error("dy must have 'float32' dtype, but has '" + dy.dtype + "'");
+ }
+ var y = this.scopedRun(function() {
+ return _this.startTape();
+ }, function() {
+ return _this.endTape();
+ }, function() {
+ return _this.tidy("forward", f);
+ });
+ assert(y instanceof Tensor, function() {
return "The result y returned by f() must be a tensor.";
});
- var o = VL(this.state.activeTape, e, s);
- if (!i && o.length === 0 && e.length > 0)
+ var filteredTape = getFilteredNodesXToY(this.state.activeTape, xs, y);
+ if (!allowNoGradients && filteredTape.length === 0 && xs.length > 0) {
throw new Error("Cannot compute gradient of y=f(x) with respect to x. Make sure that the f you passed encloses all operations that lead from x to y.");
+ }
return this.tidy("backward", function() {
- var c = {};
- c[s.id] = r == null ? eS(s.shape) : r, GL(c, o, function(u) {
- return a.tidy(u);
- }, tS);
- var l = e.map(function(u) {
- return c[u.id];
+ var accumulatedGradientMap = {};
+ accumulatedGradientMap[y.id] = dy == null ? ones(y.shape) : dy;
+ backpropagateGradients(accumulatedGradientMap, filteredTape, function(f2) {
+ return _this.tidy(f2);
+ }, add);
+ var grads2 = xs.map(function(x) {
+ return accumulatedGradientMap[x.id];
});
- return a.state.gradientDepth === 0 && (a.state.activeTape.forEach(function(u) {
- for (var h = 0, d = u.saved; h < d.length; h++) {
- var p = d[h];
- p.dispose();
- }
- }), a.state.activeTape = null), {value: s, grads: l};
+ if (_this.state.gradientDepth === 0) {
+ _this.state.activeTape.forEach(function(node) {
+ for (var _i2 = 0, _a = node.saved; _i2 < _a.length; _i2++) {
+ var tensor2 = _a[_i2];
+ tensor2.dispose();
+ }
+ });
+ _this.state.activeTape = null;
+ }
+ return {value: y, grads: grads2};
});
- }, n.prototype.customGrad = function(t) {
- var e = this;
- return E(cr(t), function() {
+ };
+ Engine2.prototype.customGrad = function(f) {
+ var _this = this;
+ assert(isFunction(f), function() {
return "The f passed in customGrad(f) must be a function.";
- }), function() {
- for (var r = [], i = 0; i < arguments.length; i++)
- r[i] = arguments[i];
- E(r.every(function(o) {
- return o instanceof K;
+ });
+ return function() {
+ var inputs = [];
+ for (var _i2 = 0; _i2 < arguments.length; _i2++) {
+ inputs[_i2] = arguments[_i2];
+ }
+ assert(inputs.every(function(t) {
+ return t instanceof Tensor;
}), function() {
return "The args passed in customGrad(f)(x1, x2,...) must all be tensors";
});
- var a, s = {};
- return r.forEach(function(o, c) {
- s[c] = o;
- }), e.runKernelFunc(function(o, c) {
- return a = t.apply(void 0, r.concat([c])), E(a.value instanceof K, function() {
+ var res;
+ var inputMap = {};
+ inputs.forEach(function(input, i) {
+ inputMap[i] = input;
+ });
+ return _this.runKernelFunc(function(_, save) {
+ res = f.apply(void 0, inputs.concat([save]));
+ assert(res.value instanceof Tensor, function() {
return "The function f passed in customGrad(f) must return an object where `obj.value` is a tensor";
- }), E(cr(a.gradFunc), function() {
+ });
+ assert(isFunction(res.gradFunc), function() {
return "The function f passed in customGrad(f) must return an object where `obj.gradFunc` is a function.";
- }), a.value;
- }, s, function(o, c) {
- var l = a.gradFunc(o, c), u = Array.isArray(l) ? l : [l];
- E(u.length === r.length, function() {
+ });
+ return res.value;
+ }, inputMap, function(dy, saved) {
+ var gradRes = res.gradFunc(dy, saved);
+ var grads2 = Array.isArray(gradRes) ? gradRes : [gradRes];
+ assert(grads2.length === inputs.length, function() {
return "The function f passed in customGrad(f) must return an object where `obj.gradFunc` is a function that returns the same number of tensors as inputs passed to f(...).";
- }), E(u.every(function(d) {
- return d instanceof K;
+ });
+ assert(grads2.every(function(t) {
+ return t instanceof Tensor;
}), function() {
return "The function f passed in customGrad(f) must return an object where `obj.gradFunc` is a function that returns a list of only tensors.";
});
- var h = {};
- return u.forEach(function(d, p) {
- h[p] = function() {
- return d;
+ var gradMap = {};
+ grads2.forEach(function(grad2, i) {
+ gradMap[i] = function() {
+ return grad2;
};
- }), h;
+ });
+ return gradMap;
});
};
- }, n.prototype.readSync = function(t) {
- var e = this.state.tensorInfo.get(t);
- return e.backend.readSync(t);
- }, n.prototype.read = function(t) {
- var e = this.state.tensorInfo.get(t);
- return e.backend.read(t);
- }, n.prototype.time = function(t) {
- return pe(this, void 0, void 0, function() {
- var e, r;
- return fe(this, function(i) {
- switch (i.label) {
+ };
+ Engine2.prototype.readSync = function(dataId) {
+ var info = this.state.tensorInfo.get(dataId);
+ return info.backend.readSync(dataId);
+ };
+ Engine2.prototype.read = function(dataId) {
+ var info = this.state.tensorInfo.get(dataId);
+ return info.backend.read(dataId);
+ };
+ Engine2.prototype.time = function(query) {
+ return __awaiter(this, void 0, void 0, function() {
+ var start, timingInfo;
+ return __generator(this, function(_a) {
+ switch (_a.label) {
case 0:
- return e = pu(), [4, this.backend.time(t)];
+ start = now2();
+ return [4, this.backend.time(query)];
case 1:
- return r = i.sent(), r.wallMs = pu() - e, [2, r];
+ timingInfo = _a.sent();
+ timingInfo.wallMs = now2() - start;
+ return [2, timingInfo];
}
});
});
- }, n.prototype.track = function(t) {
- return this.state.activeScope != null && (t.scopeId = this.state.activeScope.id, this.state.activeScope.track.push(t)), t;
- }, Object.defineProperty(n.prototype, "registeredVariables", {get: function() {
- return this.state.registeredVariables;
- }, enumerable: true, configurable: true}), n.prototype.reset = function() {
- this.pendingBackendInitId++, this.state.dispose(), this.ENV.reset(), this.state = new _m();
- for (var t in this.registry)
- this.disposeRegisteredKernels(t), this.registry[t].dispose(), delete this.registry[t];
- this.backendName = null, this.backendInstance = null, this.pendingBackendInit = null;
- }, n.nextTensorId = 0, n.nextVariableId = 0, n;
+ };
+ Engine2.prototype.track = function(result) {
+ if (this.state.activeScope != null) {
+ result.scopeId = this.state.activeScope.id;
+ this.state.activeScope.track.push(result);
+ }
+ return result;
+ };
+ Object.defineProperty(Engine2.prototype, "registeredVariables", {
+ get: function() {
+ return this.state.registeredVariables;
+ },
+ enumerable: true,
+ configurable: true
+ });
+ Engine2.prototype.reset = function() {
+ this.pendingBackendInitId++;
+ this.state.dispose();
+ this.ENV.reset();
+ this.state = new EngineState();
+ for (var backendName in this.registry) {
+ this.disposeRegisteredKernels(backendName);
+ this.registry[backendName].dispose();
+ delete this.registry[backendName];
+ }
+ this.backendName = null;
+ this.backendInstance = null;
+ this.pendingBackendInit = null;
+ };
+ Engine2.nextTensorId = 0;
+ Engine2.nextVariableId = 0;
+ return Engine2;
}();
- function eS(n) {
- var t = yc(pt(n), "float32");
- return z.makeTensor(t, n, "float32");
+ function ones(shape) {
+ var values = makeOnesTypedArray(sizeFromShape(shape), "float32");
+ return ENGINE.makeTensor(values, shape, "float32");
}
- function Cm() {
- var n = Sf();
- if (n._tfengine == null) {
- var t = new Lf(n);
- n._tfengine = new nS(t);
+ function getOrMakeEngine() {
+ var ns = getGlobalNamespace();
+ if (ns._tfengine == null) {
+ var environment = new Environment(ns);
+ ns._tfengine = new Engine(environment);
}
- return RL(n._tfengine.ENV), KL(function() {
- return n._tfengine;
- }), n._tfengine;
+ setEnvironmentGlobal(ns._tfengine.ENV);
+ setTensorTracker(function() {
+ return ns._tfengine;
+ });
+ return ns._tfengine;
}
- var z = Cm();
- function tS(n, t) {
- var e = {a: n, b: t};
- return z.runKernelFunc(function(r, i) {
- var a = r.add(n, t);
- return i([n, t]), a;
- }, e, null, Cs);
+ var ENGINE = getOrMakeEngine();
+ function add(a, b) {
+ var inputs = {a, b};
+ return ENGINE.runKernelFunc(function(backend2, save) {
+ var res = backend2.add(a, b);
+ save([a, b]);
+ return res;
+ }, inputs, null, Add);
}
- function rS() {
- return typeof navigator != "undefined" && navigator != null;
+ /**
+ * @license
+ * Copyright 2017 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ function _isNavigatorDefined() {
+ return typeof navigator !== "undefined" && navigator != null;
}
- function iS() {
- if (rS()) {
- var n = navigator.userAgent || navigator.vendor || window.opera;
- return /(android|bb\d+|meego).+mobile|avantgo|bada\/|blackberry|blazer|compal|elaine|fennec|hiptop|iemobile|ip(hone|od)|iris|kindle|lge |maemo|midp|mmp|mobile.+firefox|netfront|opera m(ob|in)i|palm( os)?|phone|p(ixi|re)\/|plucker|pocket|psp|series(4|6)0|symbian|treo|up\.(browser|link)|vodafone|wap|windows ce|xda|xiino/i.test(n) || /1207|6310|6590|3gso|4thp|50[1-6]i|770s|802s|a wa|abac|ac(er|oo|s\-)|ai(ko|rn)|al(av|ca|co)|amoi|an(ex|ny|yw)|aptu|ar(ch|go)|as(te|us)|attw|au(di|\-m|r |s )|avan|be(ck|ll|nq)|bi(lb|rd)|bl(ac|az)|br(e|v)w|bumb|bw\-(n|u)|c55\/|capi|ccwa|cdm\-|cell|chtm|cldc|cmd\-|co(mp|nd)|craw|da(it|ll|ng)|dbte|dc\-s|devi|dica|dmob|do(c|p)o|ds(12|\-d)|el(49|ai)|em(l2|ul)|er(ic|k0)|esl8|ez([4-7]0|os|wa|ze)|fetc|fly(\-|_)|g1 u|g560|gene|gf\-5|g\-mo|go(\.w|od)|gr(ad|un)|haie|hcit|hd\-(m|p|t)|hei\-|hi(pt|ta)|hp( i|ip)|hs\-c|ht(c(\-| |_|a|g|p|s|t)|tp)|hu(aw|tc)|i\-(20|go|ma)|i230|iac( |\-|\/)|ibro|idea|ig01|ikom|im1k|inno|ipaq|iris|ja(t|v)a|jbro|jemu|jigs|kddi|keji|kgt( |\/)|klon|kpt |kwc\-|kyo(c|k)|le(no|xi)|lg( g|\/(k|l|u)|50|54|\-[a-w])|libw|lynx|m1\-w|m3ga|m50\/|ma(te|ui|xo)|mc(01|21|ca)|m\-cr|me(rc|ri)|mi(o8|oa|ts)|mmef|mo(01|02|bi|de|do|t(\-| |o|v)|zz)|mt(50|p1|v )|mwbp|mywa|n10[0-2]|n20[2-3]|n30(0|2)|n50(0|2|5)|n7(0(0|1)|10)|ne((c|m)\-|on|tf|wf|wg|wt)|nok(6|i)|nzph|o2im|op(ti|wv)|oran|owg1|p800|pan(a|d|t)|pdxg|pg(13|\-([1-8]|c))|phil|pire|pl(ay|uc)|pn\-2|po(ck|rt|se)|prox|psio|pt\-g|qa\-a|qc(07|12|21|32|60|\-[2-7]|i\-)|qtek|r380|r600|raks|rim9|ro(ve|zo)|s55\/|sa(ge|ma|mm|ms|ny|va)|sc(01|h\-|oo|p\-)|sdk\/|se(c(\-|0|1)|47|mc|nd|ri)|sgh\-|shar|sie(\-|m)|sk\-0|sl(45|id)|sm(al|ar|b3|it|t5)|so(ft|ny)|sp(01|h\-|v\-|v )|sy(01|mb)|t2(18|50)|t6(00|10|18)|ta(gt|lk)|tcl\-|tdg\-|tel(i|m)|tim\-|t\-mo|to(pl|sh)|ts(70|m\-|m3|m5)|tx\-9|up(\.b|g1|si)|utst|v400|v750|veri|vi(rg|te)|vk(40|5[0-3]|\-v)|vm40|voda|vulc|vx(52|53|60|61|70|80|81|83|85|98)|w3c(\-| )|webc|whit|wi(g |nc|nw)|wmlb|wonu|x700|yas\-|your|zeto|zte\-/i.test(n.substr(0, 4));
+ function isMobile() {
+ if (_isNavigatorDefined()) {
+ var a = navigator.userAgent || navigator.vendor || window.opera;
+ return /(android|bb\d+|meego).+mobile|avantgo|bada\/|blackberry|blazer|compal|elaine|fennec|hiptop|iemobile|ip(hone|od)|iris|kindle|lge |maemo|midp|mmp|mobile.+firefox|netfront|opera m(ob|in)i|palm( os)?|phone|p(ixi|re)\/|plucker|pocket|psp|series(4|6)0|symbian|treo|up\.(browser|link)|vodafone|wap|windows ce|xda|xiino/i.test(a) || /1207|6310|6590|3gso|4thp|50[1-6]i|770s|802s|a wa|abac|ac(er|oo|s\-)|ai(ko|rn)|al(av|ca|co)|amoi|an(ex|ny|yw)|aptu|ar(ch|go)|as(te|us)|attw|au(di|\-m|r |s )|avan|be(ck|ll|nq)|bi(lb|rd)|bl(ac|az)|br(e|v)w|bumb|bw\-(n|u)|c55\/|capi|ccwa|cdm\-|cell|chtm|cldc|cmd\-|co(mp|nd)|craw|da(it|ll|ng)|dbte|dc\-s|devi|dica|dmob|do(c|p)o|ds(12|\-d)|el(49|ai)|em(l2|ul)|er(ic|k0)|esl8|ez([4-7]0|os|wa|ze)|fetc|fly(\-|_)|g1 u|g560|gene|gf\-5|g\-mo|go(\.w|od)|gr(ad|un)|haie|hcit|hd\-(m|p|t)|hei\-|hi(pt|ta)|hp( i|ip)|hs\-c|ht(c(\-| |_|a|g|p|s|t)|tp)|hu(aw|tc)|i\-(20|go|ma)|i230|iac( |\-|\/)|ibro|idea|ig01|ikom|im1k|inno|ipaq|iris|ja(t|v)a|jbro|jemu|jigs|kddi|keji|kgt( |\/)|klon|kpt |kwc\-|kyo(c|k)|le(no|xi)|lg( g|\/(k|l|u)|50|54|\-[a-w])|libw|lynx|m1\-w|m3ga|m50\/|ma(te|ui|xo)|mc(01|21|ca)|m\-cr|me(rc|ri)|mi(o8|oa|ts)|mmef|mo(01|02|bi|de|do|t(\-| |o|v)|zz)|mt(50|p1|v )|mwbp|mywa|n10[0-2]|n20[2-3]|n30(0|2)|n50(0|2|5)|n7(0(0|1)|10)|ne((c|m)\-|on|tf|wf|wg|wt)|nok(6|i)|nzph|o2im|op(ti|wv)|oran|owg1|p800|pan(a|d|t)|pdxg|pg(13|\-([1-8]|c))|phil|pire|pl(ay|uc)|pn\-2|po(ck|rt|se)|prox|psio|pt\-g|qa\-a|qc(07|12|21|32|60|\-[2-7]|i\-)|qtek|r380|r600|raks|rim9|ro(ve|zo)|s55\/|sa(ge|ma|mm|ms|ny|va)|sc(01|h\-|oo|p\-)|sdk\/|se(c(\-|0|1)|47|mc|nd|ri)|sgh\-|shar|sie(\-|m)|sk\-0|sl(45|id)|sm(al|ar|b3|it|t5)|so(ft|ny)|sp(01|h\-|v\-|v )|sy(01|mb)|t2(18|50)|t6(00|10|18)|ta(gt|lk)|tcl\-|tdg\-|tel(i|m)|tim\-|t\-mo|to(pl|sh)|ts(70|m\-|m3|m5)|tx\-9|up(\.b|g1|si)|utst|v400|v750|veri|vi(rg|te)|vk(40|5[0-3]|\-v)|vm40|voda|vulc|vx(52|53|60|61|70|80|81|83|85|98)|w3c(\-| )|webc|whit|wi(g |nc|nw)|wmlb|wonu|x700|yas\-|your|zeto|zte\-/i.test(a.substr(0, 4));
}
return false;
}
- function Rm() {
- return typeof window != "undefined" && window.document != null || typeof WorkerGlobalScope != "undefined";
+ function isBrowser() {
+ return typeof window !== "undefined" && window.document != null || typeof WorkerGlobalScope !== "undefined";
}
- var aS = {__proto__: null, isMobile: iS, isBrowser: Rm};
- var Yn = Ge();
- Yn.registerFlag("DEBUG", function() {
+ var device_util = {
+ __proto__: null,
+ isMobile,
+ isBrowser
+ };
+ /**
+ * @license
+ * Copyright 2019 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ var ENV = env();
+ ENV.registerFlag("DEBUG", function() {
return false;
- }, function(n) {
- n && console.warn("Debugging mode is ON. The output of every math call will be downloaded to CPU and checked for NaNs. This significantly impacts performance.");
+ }, function(debugValue) {
+ if (debugValue) {
+ console.warn("Debugging mode is ON. The output of every math call will be downloaded to CPU and checked for NaNs. This significantly impacts performance.");
+ }
});
- Yn.registerFlag("IS_BROWSER", function() {
- return Rm();
+ ENV.registerFlag("IS_BROWSER", function() {
+ return isBrowser();
});
- Yn.registerFlag("IS_NODE", function() {
- return typeof process != "undefined" && typeof process.versions != "undefined" && typeof process.versions.node != "undefined";
+ ENV.registerFlag("IS_NODE", function() {
+ return typeof process !== "undefined" && typeof process.versions !== "undefined" && typeof process.versions.node !== "undefined";
});
- Yn.registerFlag("IS_CHROME", function() {
- return typeof navigator != "undefined" && navigator != null && navigator.userAgent != null && /Chrome/.test(navigator.userAgent) && /Google Inc/.test(navigator.vendor);
+ ENV.registerFlag("IS_CHROME", function() {
+ return typeof navigator !== "undefined" && navigator != null && navigator.userAgent != null && /Chrome/.test(navigator.userAgent) && /Google Inc/.test(navigator.vendor);
});
- Yn.registerFlag("PROD", function() {
+ ENV.registerFlag("PROD", function() {
return false;
});
- Yn.registerFlag("TENSORLIKE_CHECK_SHAPE_CONSISTENCY", function() {
- return Yn.getBool("DEBUG");
+ ENV.registerFlag("TENSORLIKE_CHECK_SHAPE_CONSISTENCY", function() {
+ return ENV.getBool("DEBUG");
});
- Yn.registerFlag("DEPRECATION_WARNINGS_ENABLED", function() {
+ ENV.registerFlag("DEPRECATION_WARNINGS_ENABLED", function() {
return true;
});
- Yn.registerFlag("IS_TEST", function() {
+ ENV.registerFlag("IS_TEST", function() {
return false;
});
- function Rn(n, t) {
- var e = n;
- if (Ft(n))
- return t === "string" ? [] : [n.length];
- if (!Array.isArray(n))
+ /**
+ * @license
+ * Copyright 2018 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ function inferShape(val, dtype) {
+ var firstElem = val;
+ if (isTypedArray(val)) {
+ return dtype === "string" ? [] : [val.length];
+ }
+ if (!Array.isArray(val)) {
return [];
- for (var r = []; Array.isArray(e) || Ft(e) && t !== "string"; )
- r.push(e.length), e = e[0];
- return Array.isArray(n) && Ge().getBool("TENSORLIKE_CHECK_SHAPE_CONSISTENCY") && Om(n, r, []), r;
+ }
+ var shape = [];
+ while (Array.isArray(firstElem) || isTypedArray(firstElem) && dtype !== "string") {
+ shape.push(firstElem.length);
+ firstElem = firstElem[0];
+ }
+ if (Array.isArray(val) && env().getBool("TENSORLIKE_CHECK_SHAPE_CONSISTENCY")) {
+ deepAssertShapeConsistency(val, shape, []);
+ }
+ return shape;
}
- function Om(n, t, e) {
- if (e = e || [], !Array.isArray(n) && !Ft(n)) {
- E(t.length === 0, function() {
- return "Element arr[" + e.join("][") + "] is a primitive, " + ("but should be an array/TypedArray of " + t[0] + " elements");
+ function deepAssertShapeConsistency(val, shape, indices) {
+ indices = indices || [];
+ if (!Array.isArray(val) && !isTypedArray(val)) {
+ assert(shape.length === 0, function() {
+ return "Element arr[" + indices.join("][") + "] is a primitive, " + ("but should be an array/TypedArray of " + shape[0] + " elements");
});
return;
}
- E(t.length > 0, function() {
- return "Element arr[" + e.join("][") + "] should be a primitive, " + ("but is an array of " + n.length + " elements");
- }), E(n.length === t[0], function() {
- return "Element arr[" + e.join("][") + "] should have " + t[0] + " " + ("elements, but has " + n.length + " elements");
+ assert(shape.length > 0, function() {
+ return "Element arr[" + indices.join("][") + "] should be a primitive, " + ("but is an array of " + val.length + " elements");
});
- for (var r = t.slice(1), i = 0; i < n.length; ++i)
- Om(n[i], r, e.concat(i));
- }
- function Em(n, t, e, r) {
- if (n == null)
- return;
- if (n !== "numeric" && n !== t || n === "numeric" && t === "string")
- throw new Error("Argument '" + e + "' passed to '" + r + "' must " + ("be " + n + " tensor, but got " + t + " tensor"));
- }
- function R(n, t, e, r) {
- if (r === void 0 && (r = "numeric"), n instanceof K)
- return Em(r, n.dtype, t, e), n;
- var i = Ns(n);
- if (i !== "string" && ["bool", "int32", "float32"].indexOf(r) >= 0 && (i = r), Em(r, i, t, e), n == null || !Ft(n) && !Array.isArray(n) && typeof n != "number" && typeof n != "boolean" && typeof n != "string") {
- var a = n == null ? "null" : n.constructor.name;
- throw new Error("Argument '" + t + "' passed to '" + e + "' must be a " + ("Tensor or TensorLike, but got '" + a + "'"));
+ assert(val.length === shape[0], function() {
+ return "Element arr[" + indices.join("][") + "] should have " + shape[0] + " " + ("elements, but has " + val.length + " elements");
+ });
+ var subShape = shape.slice(1);
+ for (var i = 0; i < val.length; ++i) {
+ deepAssertShapeConsistency(val[i], subShape, indices.concat(i));
}
- var s = Rn(n, i);
- !Ft(n) && !Array.isArray(n) && (n = [n]);
- var o = true, c = i !== "string" ? Es(n, i) : Br(n, [], o);
- return z.makeTensor(c, s, i);
}
- function Sa(n, t, e, r) {
- if (r === void 0 && (r = "numeric"), !Array.isArray(n))
- throw new Error("Argument " + t + " passed to " + e + " must be a `Tensor[]` or `TensorLike[]`");
- var i = n;
- return i.map(function(a, s) {
- return R(a, t + "[" + s + "]", e);
- }, r);
+ function assertDtype(expectedDtype, actualDType, argName, functionName) {
+ if (expectedDtype == null) {
+ return;
+ }
+ if (expectedDtype !== "numeric" && expectedDtype !== actualDType || expectedDtype === "numeric" && actualDType === "string") {
+ throw new Error("Argument '" + argName + "' passed to '" + functionName + "' must " + ("be " + expectedDtype + " tensor, but got " + actualDType + " tensor"));
+ }
}
- var Dm = "__op";
- function U(n) {
- var t = Object.keys(n);
- if (t.length !== 1)
- throw new Error("Please provide an object with a single key (operation name) mapping to a function. Got an object with " + (t.length + " keys."));
- var e = t[0], r = n[e];
- e.endsWith("_") && (e = e.substring(0, e.length - 1)), e = e + Dm;
- var i = function() {
- for (var a = [], s = 0; s < arguments.length; s++)
- a[s] = arguments[s];
- z.startScope(e);
+ function convertToTensor(x, argName, functionName, parseAsDtype) {
+ if (parseAsDtype === void 0) {
+ parseAsDtype = "numeric";
+ }
+ if (x instanceof Tensor) {
+ assertDtype(parseAsDtype, x.dtype, argName, functionName);
+ return x;
+ }
+ var inferredDtype = inferDtype(x);
+ if (inferredDtype !== "string" && ["bool", "int32", "float32"].indexOf(parseAsDtype) >= 0) {
+ inferredDtype = parseAsDtype;
+ }
+ assertDtype(parseAsDtype, inferredDtype, argName, functionName);
+ if (x == null || !isTypedArray(x) && !Array.isArray(x) && typeof x !== "number" && typeof x !== "boolean" && typeof x !== "string") {
+ var type = x == null ? "null" : x.constructor.name;
+ throw new Error("Argument '" + argName + "' passed to '" + functionName + "' must be a " + ("Tensor or TensorLike, but got '" + type + "'"));
+ }
+ var inferredShape = inferShape(x, inferredDtype);
+ if (!isTypedArray(x) && !Array.isArray(x)) {
+ x = [x];
+ }
+ var skipTypedArray = true;
+ var values = inferredDtype !== "string" ? toTypedArray(x, inferredDtype) : flatten(x, [], skipTypedArray);
+ return ENGINE.makeTensor(values, inferredShape, inferredDtype);
+ }
+ function convertToTensorArray(arg, argName, functionName, parseAsDtype) {
+ if (parseAsDtype === void 0) {
+ parseAsDtype = "numeric";
+ }
+ if (!Array.isArray(arg)) {
+ throw new Error("Argument " + argName + " passed to " + functionName + " must be a `Tensor[]` or `TensorLike[]`");
+ }
+ var tensors = arg;
+ return tensors.map(function(t, i) {
+ return convertToTensor(t, argName + "[" + i + "]", functionName);
+ }, parseAsDtype);
+ }
+ /**
+ * @license
+ * Copyright 2018 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ var OP_SCOPE_SUFFIX = "__op";
+ function op(f) {
+ var keys = Object.keys(f);
+ if (keys.length !== 1) {
+ throw new Error("Please provide an object with a single key (operation name) mapping to a function. Got an object with " + (keys.length + " keys."));
+ }
+ var opName = keys[0];
+ var fn = f[opName];
+ if (opName.endsWith("_")) {
+ opName = opName.substring(0, opName.length - 1);
+ }
+ opName = opName + OP_SCOPE_SUFFIX;
+ var f2 = function() {
+ var args = [];
+ for (var _i2 = 0; _i2 < arguments.length; _i2++) {
+ args[_i2] = arguments[_i2];
+ }
+ ENGINE.startScope(opName);
try {
- var o = r.apply(void 0, a);
- return wc(o) && console.error("Cannot return a Promise inside of tidy."), z.endScope(o), o;
- } catch (c) {
- throw z.endScope(null), c;
+ var result = fn.apply(void 0, args);
+ if (isPromise(result)) {
+ console.error("Cannot return a Promise inside of tidy.");
+ }
+ ENGINE.endScope(result);
+ return result;
+ } catch (ex) {
+ ENGINE.endScope(null);
+ throw ex;
}
};
- return Object.defineProperty(i, "name", {value: e, configurable: true}), i;
+ Object.defineProperty(f2, "name", {value: opName, configurable: true});
+ return f2;
}
- function sS(n, t) {
- var e = R(n, "real", "complex"), r = R(t, "imag", "complex");
- Pe(e.shape, r.shape, "real and imag shapes, " + e.shape + " and " + r.shape + ", must match in call to tf.complex().");
- var i = function(s) {
- return s.complex(e, r);
- }, a = {real: e, imag: r};
- return z.runKernelFunc(i, a, null, Cf);
+ /**
+ * @license
+ * Copyright 2020 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ function complex_(real2, imag2) {
+ var $real = convertToTensor(real2, "real", "complex");
+ var $imag = convertToTensor(imag2, "imag", "complex");
+ assertShapesMatch($real.shape, $imag.shape, "real and imag shapes, " + $real.shape + " and " + $imag.shape + ", must match in call to tf.complex().");
+ var forward = function(backend2) {
+ return backend2.complex($real, $imag);
+ };
+ var inputs = {real: $real, imag: $imag};
+ return ENGINE.runKernelFunc(forward, inputs, null, Complex);
}
- var lr = U({complex_: sS});
- function ur(n, t, e, r) {
- if (r == null && (r = Ns(n)), r === "complex64")
+ var complex = op({complex_});
+ /**
+ * @license
+ * Copyright 2018 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ function makeTensor(values, shape, inferredShape, dtype) {
+ if (dtype == null) {
+ dtype = inferDtype(values);
+ }
+ if (dtype === "complex64") {
throw new Error("Cannot construct a complex64 tensor directly. Please use tf.complex(real, imag).");
- if (!Ft(n) && !Array.isArray(n) && typeof n != "number" && typeof n != "boolean" && typeof n != "string")
+ }
+ if (!isTypedArray(values) && !Array.isArray(values) && typeof values !== "number" && typeof values !== "boolean" && typeof values !== "string") {
throw new Error("values passed to tensor(values) must be a number/boolean/string or an array of numbers/booleans/strings, or a TypedArray");
- if (t != null) {
- vc(t);
- var i = pt(t), a = pt(e);
- E(i === a, function() {
- return "Based on the provided shape, [" + t + "], the tensor should have " + (i + " values but has " + a);
+ }
+ if (shape != null) {
+ assertNonNegativeIntegerDimensions(shape);
+ var providedSize_1 = sizeFromShape(shape);
+ var inferredSize_1 = sizeFromShape(inferredShape);
+ assert(providedSize_1 === inferredSize_1, function() {
+ return "Based on the provided shape, [" + shape + "], the tensor should have " + (providedSize_1 + " values but has " + inferredSize_1);
});
- for (var s = 0; s < e.length; ++s) {
- var o = e[s], c = s === e.length - 1 ? o !== pt(t.slice(s)) : true;
- E(e[s] === t[s] || !c, function() {
- return "Error creating a new Tensor. Inferred shape " + ("(" + e + ") does not match the provided ") + ("shape (" + t + "). ");
+ for (var i = 0; i < inferredShape.length; ++i) {
+ var inferred = inferredShape[i];
+ var flatDimsDontMatch = i === inferredShape.length - 1 ? inferred !== sizeFromShape(shape.slice(i)) : true;
+ assert(inferredShape[i] === shape[i] || !flatDimsDontMatch, function() {
+ return "Error creating a new Tensor. Inferred shape " + ("(" + inferredShape + ") does not match the provided ") + ("shape (" + shape + "). ");
});
}
}
- return !Ft(n) && !Array.isArray(n) && (n = [n]), t = t || e, n = r !== "string" ? Es(n, r) : Br(n, [], true), z.makeTensor(n, t, r);
+ if (!isTypedArray(values) && !Array.isArray(values)) {
+ values = [values];
+ }
+ shape = shape || inferredShape;
+ values = dtype !== "string" ? toTypedArray(values, dtype) : flatten(values, [], true);
+ return ENGINE.makeTensor(values, shape, dtype);
}
- function hr(n, t, e) {
- var r = Rn(n, e);
- return ur(n, t, r, e);
+ /**
+ * @license
+ * Copyright 2018 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ function tensor(values, shape, dtype) {
+ var inferredShape = inferShape(values, dtype);
+ return makeTensor(values, shape, inferredShape, dtype);
}
- var xu = {float32: 4, float16: 2, int32: 4, uint16: 2, uint8: 1, bool: 1, complex64: 8};
- var Ws = 4;
- function cS(n, t) {
- return pe(this, void 0, void 0, function() {
- var e, r, i, a, s, o, c = this;
- return fe(this, function(l) {
- switch (l.label) {
+ /**
+ * @license
+ * Copyright 2018 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ var DTYPE_VALUE_SIZE_MAP = {
+ float32: 4,
+ float16: 2,
+ int32: 4,
+ uint16: 2,
+ uint8: 1,
+ bool: 1,
+ complex64: 8
+ };
+ /**
+ * @license
+ * Copyright 2018 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ var NUM_BYTES_STRING_LENGTH = 4;
+ function encodeWeights(tensors, group) {
+ return __awaiter(this, void 0, void 0, function() {
+ var specs, dataPromises, names, _loop_1, i, tensorValues;
+ var _this = this;
+ return __generator(this, function(_a) {
+ switch (_a.label) {
case 0:
- for (e = [], r = [], i = Array.isArray(n) ? n.map(function(u) {
- return u.name;
- }) : Object.keys(n), a = function(u) {
- var h = i[u], d = Array.isArray(n) ? n[u].tensor : n[h];
- if (d.dtype !== "float32" && d.dtype !== "int32" && d.dtype !== "bool" && d.dtype !== "string" && d.dtype !== "complex64")
- throw new Error("Unsupported dtype in weight '" + h + "': " + d.dtype);
- var p = {name: h, shape: d.shape, dtype: d.dtype};
- if (d.dtype === "string") {
- var f = new Promise(function(m) {
- return pe(c, void 0, void 0, function() {
- var g, y, w, b, L, x, N;
- return fe(this, function(I) {
- switch (I.label) {
+ specs = [];
+ dataPromises = [];
+ names = Array.isArray(tensors) ? tensors.map(function(tensor2) {
+ return tensor2.name;
+ }) : Object.keys(tensors);
+ _loop_1 = function(i2) {
+ var name_1 = names[i2];
+ var t = Array.isArray(tensors) ? tensors[i2].tensor : tensors[name_1];
+ if (t.dtype !== "float32" && t.dtype !== "int32" && t.dtype !== "bool" && t.dtype !== "string" && t.dtype !== "complex64") {
+ throw new Error("Unsupported dtype in weight '" + name_1 + "': " + t.dtype);
+ }
+ var spec = {name: name_1, shape: t.shape, dtype: t.dtype};
+ if (t.dtype === "string") {
+ var utf8bytes = new Promise(function(resolve) {
+ return __awaiter(_this, void 0, void 0, function() {
+ var vals, totalNumBytes, bytes, offset, i_1, val, bytesOfLength;
+ return __generator(this, function(_a2) {
+ switch (_a2.label) {
case 0:
- return [4, d.bytes()];
+ return [4, t.bytes()];
case 1:
- for (g = I.sent(), y = g.reduce(function(C, O) {
- return C + O.length;
- }, 0) + Ws * g.length, w = new Uint8Array(y), b = 0, L = 0; L < g.length; L++)
- x = g[L], N = new Uint8Array(new Uint32Array([x.length]).buffer), w.set(N, b), b += Ws, w.set(x, b), b += x.length;
- return m(w), [2];
+ vals = _a2.sent();
+ totalNumBytes = vals.reduce(function(p, c) {
+ return p + c.length;
+ }, 0) + NUM_BYTES_STRING_LENGTH * vals.length;
+ bytes = new Uint8Array(totalNumBytes);
+ offset = 0;
+ for (i_1 = 0; i_1 < vals.length; i_1++) {
+ val = vals[i_1];
+ bytesOfLength = new Uint8Array(new Uint32Array([val.length]).buffer);
+ bytes.set(bytesOfLength, offset);
+ offset += NUM_BYTES_STRING_LENGTH;
+ bytes.set(val, offset);
+ offset += val.length;
+ }
+ resolve(bytes);
+ return [2];
}
});
});
});
- r.push(f);
- } else
- r.push(d.data());
- t != null && (p.group = t), e.push(p);
- }, s = 0; s < i.length; ++s)
- a(s);
- return [4, Promise.all(r)];
+ dataPromises.push(utf8bytes);
+ } else {
+ dataPromises.push(t.data());
+ }
+ if (group != null) {
+ spec.group = group;
+ }
+ specs.push(spec);
+ };
+ for (i = 0; i < names.length; ++i) {
+ _loop_1(i);
+ }
+ return [4, Promise.all(dataPromises)];
case 1:
- return o = l.sent(), [2, {data: oS(o), specs: e}];
+ tensorValues = _a.sent();
+ return [2, {data: concatenateTypedArrays(tensorValues), specs}];
}
});
});
}
- function km(n, t) {
- for (var e = {}, r, i = 0, a = 0, s = t; a < s.length; a++) {
- var o = s[a], c = o.name, l = o.dtype, u = o.shape, h = pt(u), d = void 0;
- if ("quantization" in o) {
- var p = o.quantization;
- if (p.dtype === "uint8" || p.dtype === "uint16") {
- if (!("min" in p && "scale" in p))
- throw new Error("Weight " + o.name + " with quantization " + p.dtype + " doesn't have corresponding metadata min and scale.");
- } else if (p.dtype === "float16") {
- if (l !== "float32")
- throw new Error("Weight " + o.name + " is quantized with " + p.dtype + " " + ("which only supports weights of type float32 not " + l + "."));
- } else
- throw new Error("Weight " + o.name + " has unknown " + ("quantization dtype " + p.dtype + ". ") + "Supported quantization dtypes are: 'uint8', 'uint16', and 'float16'.");
- var f = xu[p.dtype], m = n.slice(i, i + h * f), g = p.dtype === "uint8" ? new Uint8Array(m) : new Uint16Array(m);
- if (l === "float32")
- if (p.dtype === "uint8" || p.dtype === "uint16") {
- d = new Float32Array(g.length);
- for (var y = 0; y < g.length; y++) {
- var w = g[y];
- d[y] = w * p.scale + p.min;
- }
- } else if (p.dtype === "float16")
- r === void 0 && (r = lS()), d = r(g);
- else
- throw new Error("Unsupported quantization type " + p.dtype + " for weight type float32.");
- else if (l === "int32") {
- if (p.dtype !== "uint8" && p.dtype !== "uint16")
- throw new Error("Unsupported quantization type " + p.dtype + " for weight type int32.");
- d = new Int32Array(g.length);
- for (var y = 0; y < g.length; y++) {
- var w = g[y];
- d[y] = Math.round(w * p.scale + p.min);
+ function decodeWeights(buffer2, specs) {
+ var out = {};
+ var float16Decode;
+ var offset = 0;
+ for (var _i2 = 0, specs_1 = specs; _i2 < specs_1.length; _i2++) {
+ var spec = specs_1[_i2];
+ var name_2 = spec.name;
+ var dtype = spec.dtype;
+ var shape = spec.shape;
+ var size = sizeFromShape(shape);
+ var values = void 0;
+ if ("quantization" in spec) {
+ var quantization = spec.quantization;
+ if (quantization.dtype === "uint8" || quantization.dtype === "uint16") {
+ if (!("min" in quantization && "scale" in quantization)) {
+ throw new Error("Weight " + spec.name + " with quantization " + quantization.dtype + " doesn't have corresponding metadata min and scale.");
}
- } else
- throw new Error("Unsupported dtype in weight '" + c + "': " + l);
- i += h * f;
- } else if (l === "string") {
- var b = pt(o.shape);
- d = [];
- for (var y = 0; y < b; y++) {
- var L = new Uint32Array(n.slice(i, i + Ws))[0];
- i += Ws;
- var x = new Uint8Array(n.slice(i, i + L));
- d.push(x), i += L;
+ } else if (quantization.dtype === "float16") {
+ if (dtype !== "float32") {
+ throw new Error("Weight " + spec.name + " is quantized with " + quantization.dtype + " " + ("which only supports weights of type float32 not " + dtype + "."));
+ }
+ } else {
+ throw new Error("Weight " + spec.name + " has unknown " + ("quantization dtype " + quantization.dtype + ". ") + "Supported quantization dtypes are: 'uint8', 'uint16', and 'float16'.");
+ }
+ var quantizationSizeFactor = DTYPE_VALUE_SIZE_MAP[quantization.dtype];
+ var byteBuffer = buffer2.slice(offset, offset + size * quantizationSizeFactor);
+ var quantizedArray = quantization.dtype === "uint8" ? new Uint8Array(byteBuffer) : new Uint16Array(byteBuffer);
+ if (dtype === "float32") {
+ if (quantization.dtype === "uint8" || quantization.dtype === "uint16") {
+ values = new Float32Array(quantizedArray.length);
+ for (var i = 0; i < quantizedArray.length; i++) {
+ var v = quantizedArray[i];
+ values[i] = v * quantization.scale + quantization.min;
+ }
+ } else if (quantization.dtype === "float16") {
+ if (float16Decode === void 0) {
+ float16Decode = getFloat16Decoder();
+ }
+ values = float16Decode(quantizedArray);
+ } else {
+ throw new Error("Unsupported quantization type " + quantization.dtype + " for weight type float32.");
+ }
+ } else if (dtype === "int32") {
+ if (quantization.dtype !== "uint8" && quantization.dtype !== "uint16") {
+ throw new Error("Unsupported quantization type " + quantization.dtype + " for weight type int32.");
+ }
+ values = new Int32Array(quantizedArray.length);
+ for (var i = 0; i < quantizedArray.length; i++) {
+ var v = quantizedArray[i];
+ values[i] = Math.round(v * quantization.scale + quantization.min);
+ }
+ } else {
+ throw new Error("Unsupported dtype in weight '" + name_2 + "': " + dtype);
+ }
+ offset += size * quantizationSizeFactor;
+ } else if (dtype === "string") {
+ var size_1 = sizeFromShape(spec.shape);
+ values = [];
+ for (var i = 0; i < size_1; i++) {
+ var byteLength = new Uint32Array(buffer2.slice(offset, offset + NUM_BYTES_STRING_LENGTH))[0];
+ offset += NUM_BYTES_STRING_LENGTH;
+ var bytes = new Uint8Array(buffer2.slice(offset, offset + byteLength));
+ values.push(bytes);
+ offset += byteLength;
}
} else {
- var N = xu[l], m = n.slice(i, i + h * N);
- if (l === "float32")
- d = new Float32Array(m);
- else if (l === "int32")
- d = new Int32Array(m);
- else if (l === "bool")
- d = new Uint8Array(m);
- else if (l === "complex64") {
- d = new Float32Array(m);
- for (var I = new Float32Array(d.length / 2), C = new Float32Array(d.length / 2), y = 0; y < I.length; y++)
- I[y] = d[y * 2], C[y] = d[y * 2 + 1];
- var O = hr(I, u, "float32"), D = hr(C, u, "float32");
- e[c] = lr(O, D), O.dispose(), D.dispose();
- } else
- throw new Error("Unsupported dtype in weight '" + c + "': " + l);
- i += h * N;
+ var dtypeFactor = DTYPE_VALUE_SIZE_MAP[dtype];
+ var byteBuffer = buffer2.slice(offset, offset + size * dtypeFactor);
+ if (dtype === "float32") {
+ values = new Float32Array(byteBuffer);
+ } else if (dtype === "int32") {
+ values = new Int32Array(byteBuffer);
+ } else if (dtype === "bool") {
+ values = new Uint8Array(byteBuffer);
+ } else if (dtype === "complex64") {
+ values = new Float32Array(byteBuffer);
+ var real2 = new Float32Array(values.length / 2);
+ var image3 = new Float32Array(values.length / 2);
+ for (var i = 0; i < real2.length; i++) {
+ real2[i] = values[i * 2];
+ image3[i] = values[i * 2 + 1];
+ }
+ var realTensor = tensor(real2, shape, "float32");
+ var imageTensor = tensor(image3, shape, "float32");
+ out[name_2] = complex(realTensor, imageTensor);
+ realTensor.dispose();
+ imageTensor.dispose();
+ } else {
+ throw new Error("Unsupported dtype in weight '" + name_2 + "': " + dtype);
+ }
+ offset += size * dtypeFactor;
+ }
+ if (dtype !== "complex64") {
+ out[name_2] = tensor(values, shape, dtype);
}
- l !== "complex64" && (e[c] = hr(d, u, l));
}
- return e;
+ return out;
}
- function oS(n) {
- if (n === null)
- throw new Error("Invalid input value: " + JSON.stringify(n));
- var t = 0, e = [];
- n.forEach(function(a) {
- if (t += a.byteLength, e.push(a.byteLength === a.buffer.byteLength ? a : new a.constructor(a)), !(a instanceof Float32Array || a instanceof Int32Array || a instanceof Uint8Array))
- throw new Error("Unsupported TypedArray subtype: " + a.constructor.name);
- });
- var r = new Uint8Array(t), i = 0;
- return e.forEach(function(a) {
- r.set(new Uint8Array(a.buffer), i), i += a.byteLength;
- }), r.buffer;
- }
- var Lu = typeof Buffer != "undefined" && (typeof Blob == "undefined" || typeof atob == "undefined" || typeof btoa == "undefined");
- function Fm(n) {
- return Lu ? Buffer.byteLength(n) : new Blob([n]).size;
- }
- function uS(n) {
- if (Lu)
- return Buffer.from(n).toString("base64");
- for (var t = new Uint8Array(n), e = "", r = 0, i = t.length; r < i; r++)
- e += String.fromCharCode(t[r]);
- return btoa(e);
- }
- function hS(n) {
- if (Lu) {
- var t = Buffer.from(n, "base64");
- return t.buffer.slice(t.byteOffset, t.byteOffset + t.byteLength);
+ function concatenateTypedArrays(xs) {
+ if (xs === null) {
+ throw new Error("Invalid input value: " + JSON.stringify(xs));
}
- for (var e = atob(n), r = new Uint8Array(e.length), i = 0; i < e.length; ++i)
- r.set([e.charCodeAt(i)], i);
- return r.buffer;
- }
- function Su(n) {
- if (n.length === 1)
- return n[0];
- var t = 0;
- n.forEach(function(i) {
- t += i.byteLength;
+ var totalByteLength = 0;
+ var normalizedXs = [];
+ xs.forEach(function(x) {
+ totalByteLength += x.byteLength;
+ normalizedXs.push(x.byteLength === x.buffer.byteLength ? x : new x.constructor(x));
+ if (!(x instanceof Float32Array || x instanceof Int32Array || x instanceof Uint8Array)) {
+ throw new Error("Unsupported TypedArray subtype: " + x.constructor.name);
+ }
});
- var e = new Uint8Array(t), r = 0;
- return n.forEach(function(i) {
- e.set(new Uint8Array(i), r), r += i.byteLength;
- }), e.buffer;
+ var y = new Uint8Array(totalByteLength);
+ var offset = 0;
+ normalizedXs.forEach(function(x) {
+ y.set(new Uint8Array(x.buffer), offset);
+ offset += x.byteLength;
+ });
+ return y.buffer;
}
- function Wm(n) {
- var t = "/";
- for (n = n.trim(); n.endsWith(t); )
- n = n.slice(0, n.length - 1);
- var e = n.split(t);
- return e[e.length - 1];
+ var useNodeBuffer = typeof Buffer !== "undefined" && (typeof Blob === "undefined" || typeof atob === "undefined" || typeof btoa === "undefined");
+ function stringByteLength(str2) {
+ if (useNodeBuffer) {
+ return Buffer.byteLength(str2);
+ }
+ return new Blob([str2]).size;
}
- function Ia(n) {
- if (n.modelTopology instanceof ArrayBuffer)
+ function arrayBufferToBase64String(buffer2) {
+ if (useNodeBuffer) {
+ return Buffer.from(buffer2).toString("base64");
+ }
+ var buf = new Uint8Array(buffer2);
+ var s = "";
+ for (var i = 0, l = buf.length; i < l; i++) {
+ s += String.fromCharCode(buf[i]);
+ }
+ return btoa(s);
+ }
+ function base64StringToArrayBuffer(str2) {
+ if (useNodeBuffer) {
+ var buf = Buffer.from(str2, "base64");
+ return buf.buffer.slice(buf.byteOffset, buf.byteOffset + buf.byteLength);
+ }
+ var s = atob(str2);
+ var buffer2 = new Uint8Array(s.length);
+ for (var i = 0; i < s.length; ++i) {
+ buffer2.set([s.charCodeAt(i)], i);
+ }
+ return buffer2.buffer;
+ }
+ function concatenateArrayBuffers(buffers) {
+ if (buffers.length === 1) {
+ return buffers[0];
+ }
+ var totalByteLength = 0;
+ buffers.forEach(function(buffer2) {
+ totalByteLength += buffer2.byteLength;
+ });
+ var temp = new Uint8Array(totalByteLength);
+ var offset = 0;
+ buffers.forEach(function(buffer2) {
+ temp.set(new Uint8Array(buffer2), offset);
+ offset += buffer2.byteLength;
+ });
+ return temp.buffer;
+ }
+ function basename(path) {
+ var SEPARATOR = "/";
+ path = path.trim();
+ while (path.endsWith(SEPARATOR)) {
+ path = path.slice(0, path.length - 1);
+ }
+ var items = path.split(SEPARATOR);
+ return items[items.length - 1];
+ }
+ function getModelArtifactsInfoForJSON(modelArtifacts) {
+ if (modelArtifacts.modelTopology instanceof ArrayBuffer) {
throw new Error("Expected JSON model topology, received ArrayBuffer.");
- return {dateSaved: new Date(), modelTopologyType: "JSON", modelTopologyBytes: n.modelTopology == null ? 0 : Fm(JSON.stringify(n.modelTopology)), weightSpecsBytes: n.weightSpecs == null ? 0 : Fm(JSON.stringify(n.weightSpecs)), weightDataBytes: n.weightData == null ? 0 : n.weightData.byteLength};
- }
- function dS() {
- var n = function(r) {
- for (var i = r << 13, a = 0; (i & 8388608) === 0; )
- a -= 8388608, i <<= 1;
- return i &= ~8388608, a += 947912704, i | a;
- }, t = new Uint32Array(2048);
- t[0] = 0;
- for (var e = 1; e < 1024; e++)
- t[e] = n(e);
- for (var e = 1024; e < 2048; e++)
- t[e] = 939524096 + (e - 1024 << 13);
- return t;
- }
- function pS() {
- var n = new Uint32Array(64);
- n[0] = 0, n[31] = 1199570944, n[32] = 2147483648, n[63] = 3347054592;
- for (var t = 1; t < 31; t++)
- n[t] = t << 23;
- for (var t = 33; t < 63; t++)
- n[t] = 2147483648 + (t - 32 << 23);
- return n;
- }
- function fS() {
- for (var n = new Uint32Array(64), t = 0; t < 64; t++)
- n[t] = 1024;
- return n[0] = n[32] = 0, n;
- }
- function lS() {
- var n = dS(), t = pS(), e = fS();
- return function(r) {
- for (var i = new ArrayBuffer(4 * r.length), a = new Uint32Array(i), s = 0; s < r.length; s++) {
- var o = r[s], c = n[e[o >> 10] + (o & 1023)] + t[o >> 10];
- a[s] = c;
- }
- return new Float32Array(i);
+ }
+ return {
+ dateSaved: new Date(),
+ modelTopologyType: "JSON",
+ modelTopologyBytes: modelArtifacts.modelTopology == null ? 0 : stringByteLength(JSON.stringify(modelArtifacts.modelTopology)),
+ weightSpecsBytes: modelArtifacts.weightSpecs == null ? 0 : stringByteLength(JSON.stringify(modelArtifacts.weightSpecs)),
+ weightDataBytes: modelArtifacts.weightData == null ? 0 : modelArtifacts.weightData.byteLength
};
}
- var rn = function() {
- function n() {
- this.saveRouters = [], this.loadRouters = [];
+ function computeFloat16MantisaTable() {
+ var convertMantissa = function(i2) {
+ var m = i2 << 13;
+ var e = 0;
+ while ((m & 8388608) === 0) {
+ e -= 8388608;
+ m <<= 1;
+ }
+ m &= ~8388608;
+ e += 947912704;
+ return m | e;
+ };
+ var mantisaTable = new Uint32Array(2048);
+ mantisaTable[0] = 0;
+ for (var i = 1; i < 1024; i++) {
+ mantisaTable[i] = convertMantissa(i);
}
- return n.getInstance = function() {
- return n.instance == null && (n.instance = new n()), n.instance;
- }, n.registerSaveRouter = function(t) {
- n.getInstance().saveRouters.push(t);
- }, n.registerLoadRouter = function(t) {
- n.getInstance().loadRouters.push(t);
- }, n.getSaveHandlers = function(t) {
- return n.getHandlers(t, "save");
- }, n.getLoadHandlers = function(t, e) {
- return n.getHandlers(t, "load", e);
- }, n.getHandlers = function(t, e, r) {
- var i = [], a = e === "load" ? n.getInstance().loadRouters : n.getInstance().saveRouters;
- return a.forEach(function(s) {
- var o = s(t, r);
- o !== null && i.push(o);
- }), i;
- }, n;
- }(), mS = function(n) {
- return rn.registerSaveRouter(n);
- }, gS = function(n) {
- return rn.registerLoadRouter(n);
- }, yS = function(n) {
- return rn.getSaveHandlers(n);
- }, vS = function(n, t) {
- return rn.getLoadHandlers(n, t);
+ for (var i = 1024; i < 2048; i++) {
+ mantisaTable[i] = 939524096 + (i - 1024 << 13);
+ }
+ return mantisaTable;
+ }
+ function computeFloat16ExponentTable() {
+ var exponentTable = new Uint32Array(64);
+ exponentTable[0] = 0;
+ exponentTable[31] = 1199570944;
+ exponentTable[32] = 2147483648;
+ exponentTable[63] = 3347054592;
+ for (var i = 1; i < 31; i++) {
+ exponentTable[i] = i << 23;
+ }
+ for (var i = 33; i < 63; i++) {
+ exponentTable[i] = 2147483648 + (i - 32 << 23);
+ }
+ return exponentTable;
+ }
+ function computeFloat16OffsetTable() {
+ var offsetTable = new Uint32Array(64);
+ for (var i = 0; i < 64; i++) {
+ offsetTable[i] = 1024;
+ }
+ offsetTable[0] = offsetTable[32] = 0;
+ return offsetTable;
+ }
+ function getFloat16Decoder() {
+ var mantisaTable = computeFloat16MantisaTable();
+ var exponentTable = computeFloat16ExponentTable();
+ var offsetTable = computeFloat16OffsetTable();
+ return function(quantizedArray) {
+ var buffer2 = new ArrayBuffer(4 * quantizedArray.length);
+ var bufferUint32View = new Uint32Array(buffer2);
+ for (var index = 0; index < quantizedArray.length; index++) {
+ var float16Bits = quantizedArray[index];
+ var float32Bits = mantisaTable[offsetTable[float16Bits >> 10] + (float16Bits & 1023)] + exponentTable[float16Bits >> 10];
+ bufferUint32View[index] = float32Bits;
+ }
+ return new Float32Array(buffer2);
+ };
+ }
+ /**
+ * @license
+ * Copyright 2018 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ var IORouterRegistry = function() {
+ function IORouterRegistry2() {
+ this.saveRouters = [];
+ this.loadRouters = [];
+ }
+ IORouterRegistry2.getInstance = function() {
+ if (IORouterRegistry2.instance == null) {
+ IORouterRegistry2.instance = new IORouterRegistry2();
+ }
+ return IORouterRegistry2.instance;
+ };
+ IORouterRegistry2.registerSaveRouter = function(saveRouter) {
+ IORouterRegistry2.getInstance().saveRouters.push(saveRouter);
+ };
+ IORouterRegistry2.registerLoadRouter = function(loadRouter) {
+ IORouterRegistry2.getInstance().loadRouters.push(loadRouter);
+ };
+ IORouterRegistry2.getSaveHandlers = function(url) {
+ return IORouterRegistry2.getHandlers(url, "save");
+ };
+ IORouterRegistry2.getLoadHandlers = function(url, loadOptions) {
+ return IORouterRegistry2.getHandlers(url, "load", loadOptions);
+ };
+ IORouterRegistry2.getHandlers = function(url, handlerType, loadOptions) {
+ var validHandlers = [];
+ var routers = handlerType === "load" ? IORouterRegistry2.getInstance().loadRouters : IORouterRegistry2.getInstance().saveRouters;
+ routers.forEach(function(router) {
+ var handler = router(url, loadOptions);
+ if (handler !== null) {
+ validHandlers.push(handler);
+ }
+ });
+ return validHandlers;
+ };
+ return IORouterRegistry2;
+ }();
+ var registerSaveRouter = function(loudRouter) {
+ return IORouterRegistry.registerSaveRouter(loudRouter);
};
- var Iu = "tensorflowjs", Au = 1, zr = "models_store", dr = "model_info_store";
- function Um() {
- if (!Ge().getBool("IS_BROWSER"))
+ var registerLoadRouter = function(loudRouter) {
+ return IORouterRegistry.registerLoadRouter(loudRouter);
+ };
+ var getSaveHandlers = function(url) {
+ return IORouterRegistry.getSaveHandlers(url);
+ };
+ var getLoadHandlers = function(url, loadOptions) {
+ return IORouterRegistry.getLoadHandlers(url, loadOptions);
+ };
+ /**
+ * @license
+ * Copyright 2018 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ var DATABASE_NAME = "tensorflowjs";
+ var DATABASE_VERSION = 1;
+ var MODEL_STORE_NAME = "models_store";
+ var INFO_STORE_NAME = "model_info_store";
+ function getIndexedDBFactory() {
+ if (!env().getBool("IS_BROWSER")) {
throw new Error("Failed to obtain IndexedDB factory because the current environmentis not a web browser.");
- var n = typeof window == "undefined" ? self : window, t = n.indexedDB || n.mozIndexedDB || n.webkitIndexedDB || n.msIndexedDB || n.shimIndexedDB;
- if (t == null)
- throw new Error("The current browser does not appear to support IndexedDB.");
- return t;
- }
- function Tu(n) {
- var t = n.result;
- t.createObjectStore(zr, {keyPath: "modelPath"}), t.createObjectStore(dr, {keyPath: "modelPath"});
- }
- var Ai = function() {
- function n(t) {
- if (this.indexedDB = Um(), t == null || !t)
- throw new Error("For IndexedDB, modelPath must not be null, undefined or empty.");
- this.modelPath = t;
}
- return n.prototype.save = function(t) {
- return pe(this, void 0, void 0, function() {
- return fe(this, function(e) {
- if (t.modelTopology instanceof ArrayBuffer)
+ var theWindow = typeof window === "undefined" ? self : window;
+ var factory = theWindow.indexedDB || theWindow.mozIndexedDB || theWindow.webkitIndexedDB || theWindow.msIndexedDB || theWindow.shimIndexedDB;
+ if (factory == null) {
+ throw new Error("The current browser does not appear to support IndexedDB.");
+ }
+ return factory;
+ }
+ function setUpDatabase(openRequest) {
+ var db = openRequest.result;
+ db.createObjectStore(MODEL_STORE_NAME, {keyPath: "modelPath"});
+ db.createObjectStore(INFO_STORE_NAME, {keyPath: "modelPath"});
+ }
+ var BrowserIndexedDB = function() {
+ function BrowserIndexedDB2(modelPath) {
+ this.indexedDB = getIndexedDBFactory();
+ if (modelPath == null || !modelPath) {
+ throw new Error("For IndexedDB, modelPath must not be null, undefined or empty.");
+ }
+ this.modelPath = modelPath;
+ }
+ BrowserIndexedDB2.prototype.save = function(modelArtifacts) {
+ return __awaiter(this, void 0, void 0, function() {
+ return __generator(this, function(_a) {
+ if (modelArtifacts.modelTopology instanceof ArrayBuffer) {
throw new Error("BrowserLocalStorage.save() does not support saving model topology in binary formats yet.");
- return [2, this.databaseAction(this.modelPath, t)];
+ }
+ return [2, this.databaseAction(this.modelPath, modelArtifacts)];
});
});
- }, n.prototype.load = function() {
- return pe(this, void 0, void 0, function() {
- return fe(this, function(t) {
+ };
+ BrowserIndexedDB2.prototype.load = function() {
+ return __awaiter(this, void 0, void 0, function() {
+ return __generator(this, function(_a) {
return [2, this.databaseAction(this.modelPath)];
});
});
- }, n.prototype.databaseAction = function(t, e) {
- var r = this;
- return new Promise(function(i, a) {
- var s = r.indexedDB.open(Iu, Au);
- s.onupgradeneeded = function() {
- return Tu(s);
- }, s.onsuccess = function() {
- var o = s.result;
- if (e == null) {
- var c = o.transaction(zr, "readonly"), l = c.objectStore(zr), u = l.get(r.modelPath);
- u.onsuccess = function() {
- if (u.result == null)
- return o.close(), a(new Error("Cannot find model with path '" + r.modelPath + "' in IndexedDB."));
- i(u.result.modelArtifacts);
- }, u.onerror = function(g) {
- return o.close(), a(u.error);
- }, c.oncomplete = function() {
- return o.close();
+ };
+ BrowserIndexedDB2.prototype.databaseAction = function(modelPath, modelArtifacts) {
+ var _this = this;
+ return new Promise(function(resolve, reject) {
+ var openRequest = _this.indexedDB.open(DATABASE_NAME, DATABASE_VERSION);
+ openRequest.onupgradeneeded = function() {
+ return setUpDatabase(openRequest);
+ };
+ openRequest.onsuccess = function() {
+ var db = openRequest.result;
+ if (modelArtifacts == null) {
+ var modelTx = db.transaction(MODEL_STORE_NAME, "readonly");
+ var modelStore = modelTx.objectStore(MODEL_STORE_NAME);
+ var getRequest_1 = modelStore.get(_this.modelPath);
+ getRequest_1.onsuccess = function() {
+ if (getRequest_1.result == null) {
+ db.close();
+ return reject(new Error("Cannot find model with path '" + _this.modelPath + "' in IndexedDB."));
+ } else {
+ resolve(getRequest_1.result.modelArtifacts);
+ }
+ };
+ getRequest_1.onerror = function(error) {
+ db.close();
+ return reject(getRequest_1.error);
+ };
+ modelTx.oncomplete = function() {
+ return db.close();
};
} else {
- var h = Ia(e), d = o.transaction(dr, "readwrite"), p = d.objectStore(dr), f = p.put({modelPath: r.modelPath, modelArtifactsInfo: h}), m;
- f.onsuccess = function() {
- m = o.transaction(zr, "readwrite");
- var g = m.objectStore(zr), y = g.put({modelPath: r.modelPath, modelArtifacts: e, modelArtifactsInfo: h});
- y.onsuccess = function() {
- return i({modelArtifactsInfo: h});
- }, y.onerror = function(w) {
- p = d.objectStore(dr);
- var b = p.delete(r.modelPath);
- b.onsuccess = function() {
- return o.close(), a(y.error);
- }, b.onerror = function(L) {
- return o.close(), a(y.error);
+ var modelArtifactsInfo_1 = getModelArtifactsInfoForJSON(modelArtifacts);
+ var infoTx_1 = db.transaction(INFO_STORE_NAME, "readwrite");
+ var infoStore_1 = infoTx_1.objectStore(INFO_STORE_NAME);
+ var putInfoRequest_1 = infoStore_1.put({modelPath: _this.modelPath, modelArtifactsInfo: modelArtifactsInfo_1});
+ var modelTx_1;
+ putInfoRequest_1.onsuccess = function() {
+ modelTx_1 = db.transaction(MODEL_STORE_NAME, "readwrite");
+ var modelStore2 = modelTx_1.objectStore(MODEL_STORE_NAME);
+ var putModelRequest = modelStore2.put({
+ modelPath: _this.modelPath,
+ modelArtifacts,
+ modelArtifactsInfo: modelArtifactsInfo_1
+ });
+ putModelRequest.onsuccess = function() {
+ return resolve({modelArtifactsInfo: modelArtifactsInfo_1});
+ };
+ putModelRequest.onerror = function(error) {
+ infoStore_1 = infoTx_1.objectStore(INFO_STORE_NAME);
+ var deleteInfoRequest = infoStore_1.delete(_this.modelPath);
+ deleteInfoRequest.onsuccess = function() {
+ db.close();
+ return reject(putModelRequest.error);
+ };
+ deleteInfoRequest.onerror = function(error2) {
+ db.close();
+ return reject(putModelRequest.error);
};
};
- }, f.onerror = function(g) {
- return o.close(), a(f.error);
- }, d.oncomplete = function() {
- m == null ? o.close() : m.oncomplete = function() {
- return o.close();
- };
+ };
+ putInfoRequest_1.onerror = function(error) {
+ db.close();
+ return reject(putInfoRequest_1.error);
+ };
+ infoTx_1.oncomplete = function() {
+ if (modelTx_1 == null) {
+ db.close();
+ } else {
+ modelTx_1.oncomplete = function() {
+ return db.close();
+ };
+ }
};
}
- }, s.onerror = function(o) {
- return a(s.error);
+ };
+ openRequest.onerror = function(error) {
+ return reject(openRequest.error);
};
});
- }, n.URL_SCHEME = "indexeddb://", n;
- }(), Bm = function(n) {
- return Ge().getBool("IS_BROWSER") && (!Array.isArray(n) && n.startsWith(Ai.URL_SCHEME)) ? wS(n.slice(Ai.URL_SCHEME.length)) : null;
- };
- rn.registerSaveRouter(Bm);
- rn.registerLoadRouter(Bm);
- function wS(n) {
- return new Ai(n);
- }
- function bS(n) {
- return n.startsWith(Ai.URL_SCHEME) ? n.slice(Ai.URL_SCHEME.length) : n;
- }
- var xS = function() {
- function n() {
- this.indexedDB = Um();
- }
- return n.prototype.listModels = function() {
- return pe(this, void 0, void 0, function() {
- var t = this;
- return fe(this, function(e) {
- return [2, new Promise(function(r, i) {
- var a = t.indexedDB.open(Iu, Au);
- a.onupgradeneeded = function() {
- return Tu(a);
- }, a.onsuccess = function() {
- var s = a.result, o = s.transaction(dr, "readonly"), c = o.objectStore(dr), l = c.getAll();
- l.onsuccess = function() {
- for (var u = {}, h = 0, d = l.result; h < d.length; h++) {
- var p = d[h];
- u[p.modelPath] = p.modelArtifactsInfo;
- }
- r(u);
- }, l.onerror = function(u) {
- return s.close(), i(l.error);
- }, o.oncomplete = function() {
- return s.close();
- };
- }, a.onerror = function(s) {
- return i(a.error);
- };
- })];
- });
- });
- }, n.prototype.removeModel = function(t) {
- return pe(this, void 0, void 0, function() {
- var e = this;
- return fe(this, function(r) {
- return t = bS(t), [2, new Promise(function(i, a) {
- var s = e.indexedDB.open(Iu, Au);
- s.onupgradeneeded = function() {
- return Tu(s);
- }, s.onsuccess = function() {
- var o = s.result, c = o.transaction(dr, "readwrite"), l = c.objectStore(dr), u = l.get(t), h;
- u.onsuccess = function() {
- if (u.result == null)
- return o.close(), a(new Error("Cannot find model with path '" + t + "' in IndexedDB."));
- var d = l.delete(t), p = function() {
- h = o.transaction(zr, "readwrite");
- var f = h.objectStore(zr), m = f.delete(t);
- m.onsuccess = function() {
- return i(u.result.modelArtifactsInfo);
- }, m.onerror = function(g) {
- return a(u.error);
- };
- };
- d.onsuccess = p, d.onerror = function(f) {
- return p(), o.close(), a(u.error);
- };
- }, u.onerror = function(d) {
- return o.close(), a(u.error);
- }, c.oncomplete = function() {
- h == null ? o.close() : h.oncomplete = function() {
- return o.close();
- };
- };
- }, s.onerror = function(o) {
- return a(s.error);
- };
- })];
- });
- });
- }, n;
+ };
+ BrowserIndexedDB2.URL_SCHEME = "indexeddb://";
+ return BrowserIndexedDB2;
}();
- var Kn = "/", Ti = "tensorflowjs_models", zm = "info", LS = "model_topology", SS = "weight_specs", IS = "weight_data", AS = "model_metadata";
- function Pm(n) {
- return {info: [Ti, n, zm].join(Kn), topology: [Ti, n, LS].join(Kn), weightSpecs: [Ti, n, SS].join(Kn), weightData: [Ti, n, IS].join(Kn), modelMetadata: [Ti, n, AS].join(Kn)};
- }
- function TS(n) {
- var t = n.split(Kn);
- if (t.length < 3)
- throw new Error("Invalid key format: " + n);
- return t.slice(1, t.length - 1).join(Kn);
- }
- function NS(n) {
- return n.startsWith(Ni.URL_SCHEME) ? n.slice(Ni.URL_SCHEME.length) : n;
- }
- var Ni = function() {
- function n(t) {
- if (!Ge().getBool("IS_BROWSER") || typeof window == "undefined" || typeof window.localStorage == "undefined")
- throw new Error("The current environment does not support local storage.");
- if (this.LS = window.localStorage, t == null || !t)
- throw new Error("For local storage, modelPath must not be null, undefined or empty.");
- this.modelPath = t, this.keys = Pm(this.modelPath);
+ var indexedDBRouter = function(url) {
+ if (!env().getBool("IS_BROWSER")) {
+ return null;
+ } else {
+ if (!Array.isArray(url) && url.startsWith(BrowserIndexedDB.URL_SCHEME)) {
+ return browserIndexedDB(url.slice(BrowserIndexedDB.URL_SCHEME.length));
+ } else {
+ return null;
+ }
}
- return n.prototype.save = function(t) {
- return pe(this, void 0, void 0, function() {
- var e, r, i;
- return fe(this, function(a) {
- if (t.modelTopology instanceof ArrayBuffer)
+ };
+ IORouterRegistry.registerSaveRouter(indexedDBRouter);
+ IORouterRegistry.registerLoadRouter(indexedDBRouter);
+ function browserIndexedDB(modelPath) {
+ return new BrowserIndexedDB(modelPath);
+ }
+ function maybeStripScheme(key) {
+ return key.startsWith(BrowserIndexedDB.URL_SCHEME) ? key.slice(BrowserIndexedDB.URL_SCHEME.length) : key;
+ }
+ var BrowserIndexedDBManager = function() {
+ function BrowserIndexedDBManager2() {
+ this.indexedDB = getIndexedDBFactory();
+ }
+ BrowserIndexedDBManager2.prototype.listModels = function() {
+ return __awaiter(this, void 0, void 0, function() {
+ var _this = this;
+ return __generator(this, function(_a) {
+ return [2, new Promise(function(resolve, reject) {
+ var openRequest = _this.indexedDB.open(DATABASE_NAME, DATABASE_VERSION);
+ openRequest.onupgradeneeded = function() {
+ return setUpDatabase(openRequest);
+ };
+ openRequest.onsuccess = function() {
+ var db = openRequest.result;
+ var tx = db.transaction(INFO_STORE_NAME, "readonly");
+ var store = tx.objectStore(INFO_STORE_NAME);
+ var getAllInfoRequest = store.getAll();
+ getAllInfoRequest.onsuccess = function() {
+ var out = {};
+ for (var _i2 = 0, _a2 = getAllInfoRequest.result; _i2 < _a2.length; _i2++) {
+ var item = _a2[_i2];
+ out[item.modelPath] = item.modelArtifactsInfo;
+ }
+ resolve(out);
+ };
+ getAllInfoRequest.onerror = function(error) {
+ db.close();
+ return reject(getAllInfoRequest.error);
+ };
+ tx.oncomplete = function() {
+ return db.close();
+ };
+ };
+ openRequest.onerror = function(error) {
+ return reject(openRequest.error);
+ };
+ })];
+ });
+ });
+ };
+ BrowserIndexedDBManager2.prototype.removeModel = function(path) {
+ return __awaiter(this, void 0, void 0, function() {
+ var _this = this;
+ return __generator(this, function(_a) {
+ path = maybeStripScheme(path);
+ return [2, new Promise(function(resolve, reject) {
+ var openRequest = _this.indexedDB.open(DATABASE_NAME, DATABASE_VERSION);
+ openRequest.onupgradeneeded = function() {
+ return setUpDatabase(openRequest);
+ };
+ openRequest.onsuccess = function() {
+ var db = openRequest.result;
+ var infoTx = db.transaction(INFO_STORE_NAME, "readwrite");
+ var infoStore = infoTx.objectStore(INFO_STORE_NAME);
+ var getInfoRequest = infoStore.get(path);
+ var modelTx;
+ getInfoRequest.onsuccess = function() {
+ if (getInfoRequest.result == null) {
+ db.close();
+ return reject(new Error("Cannot find model with path '" + path + "' in IndexedDB."));
+ } else {
+ var deleteInfoRequest = infoStore.delete(path);
+ var deleteModelData_1 = function() {
+ modelTx = db.transaction(MODEL_STORE_NAME, "readwrite");
+ var modelStore = modelTx.objectStore(MODEL_STORE_NAME);
+ var deleteModelRequest = modelStore.delete(path);
+ deleteModelRequest.onsuccess = function() {
+ return resolve(getInfoRequest.result.modelArtifactsInfo);
+ };
+ deleteModelRequest.onerror = function(error) {
+ return reject(getInfoRequest.error);
+ };
+ };
+ deleteInfoRequest.onsuccess = deleteModelData_1;
+ deleteInfoRequest.onerror = function(error) {
+ deleteModelData_1();
+ db.close();
+ return reject(getInfoRequest.error);
+ };
+ }
+ };
+ getInfoRequest.onerror = function(error) {
+ db.close();
+ return reject(getInfoRequest.error);
+ };
+ infoTx.oncomplete = function() {
+ if (modelTx == null) {
+ db.close();
+ } else {
+ modelTx.oncomplete = function() {
+ return db.close();
+ };
+ }
+ };
+ };
+ openRequest.onerror = function(error) {
+ return reject(openRequest.error);
+ };
+ })];
+ });
+ });
+ };
+ return BrowserIndexedDBManager2;
+ }();
+ /**
+ * @license
+ * Copyright 2018 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ var PATH_SEPARATOR = "/";
+ var PATH_PREFIX = "tensorflowjs_models";
+ var INFO_SUFFIX = "info";
+ var MODEL_TOPOLOGY_SUFFIX = "model_topology";
+ var WEIGHT_SPECS_SUFFIX = "weight_specs";
+ var WEIGHT_DATA_SUFFIX = "weight_data";
+ var MODEL_METADATA_SUFFIX = "model_metadata";
+ function getModelKeys(path) {
+ return {
+ info: [PATH_PREFIX, path, INFO_SUFFIX].join(PATH_SEPARATOR),
+ topology: [PATH_PREFIX, path, MODEL_TOPOLOGY_SUFFIX].join(PATH_SEPARATOR),
+ weightSpecs: [PATH_PREFIX, path, WEIGHT_SPECS_SUFFIX].join(PATH_SEPARATOR),
+ weightData: [PATH_PREFIX, path, WEIGHT_DATA_SUFFIX].join(PATH_SEPARATOR),
+ modelMetadata: [PATH_PREFIX, path, MODEL_METADATA_SUFFIX].join(PATH_SEPARATOR)
+ };
+ }
+ function getModelPathFromKey(key) {
+ var items = key.split(PATH_SEPARATOR);
+ if (items.length < 3) {
+ throw new Error("Invalid key format: " + key);
+ }
+ return items.slice(1, items.length - 1).join(PATH_SEPARATOR);
+ }
+ function maybeStripScheme$1(key) {
+ return key.startsWith(BrowserLocalStorage.URL_SCHEME) ? key.slice(BrowserLocalStorage.URL_SCHEME.length) : key;
+ }
+ var BrowserLocalStorage = function() {
+ function BrowserLocalStorage2(modelPath) {
+ if (!env().getBool("IS_BROWSER") || typeof window === "undefined" || typeof window.localStorage === "undefined") {
+ throw new Error("The current environment does not support local storage.");
+ }
+ this.LS = window.localStorage;
+ if (modelPath == null || !modelPath) {
+ throw new Error("For local storage, modelPath must not be null, undefined or empty.");
+ }
+ this.modelPath = modelPath;
+ this.keys = getModelKeys(this.modelPath);
+ }
+ BrowserLocalStorage2.prototype.save = function(modelArtifacts) {
+ return __awaiter(this, void 0, void 0, function() {
+ var topology, weightSpecs, modelArtifactsInfo;
+ return __generator(this, function(_a) {
+ if (modelArtifacts.modelTopology instanceof ArrayBuffer) {
throw new Error("BrowserLocalStorage.save() does not support saving model topology in binary formats yet.");
- e = JSON.stringify(t.modelTopology), r = JSON.stringify(t.weightSpecs), i = Ia(t);
- try {
- return this.LS.setItem(this.keys.info, JSON.stringify(i)), this.LS.setItem(this.keys.topology, e), this.LS.setItem(this.keys.weightSpecs, r), this.LS.setItem(this.keys.weightData, uS(t.weightData)), this.LS.setItem(this.keys.modelMetadata, JSON.stringify({format: t.format, generatedBy: t.generatedBy, convertedBy: t.convertedBy, userDefinedMetadata: t.userDefinedMetadata})), [2, {modelArtifactsInfo: i}];
- } catch (s) {
- throw this.LS.removeItem(this.keys.info), this.LS.removeItem(this.keys.topology), this.LS.removeItem(this.keys.weightSpecs), this.LS.removeItem(this.keys.weightData), this.LS.removeItem(this.keys.modelMetadata), new Error("Failed to save model '" + this.modelPath + "' to local storage: size quota being exceeded is a possible cause of this failure: " + ("modelTopologyBytes=" + i.modelTopologyBytes + ", ") + ("weightSpecsBytes=" + i.weightSpecsBytes + ", ") + ("weightDataBytes=" + i.weightDataBytes + "."));
+ } else {
+ topology = JSON.stringify(modelArtifacts.modelTopology);
+ weightSpecs = JSON.stringify(modelArtifacts.weightSpecs);
+ modelArtifactsInfo = getModelArtifactsInfoForJSON(modelArtifacts);
+ try {
+ this.LS.setItem(this.keys.info, JSON.stringify(modelArtifactsInfo));
+ this.LS.setItem(this.keys.topology, topology);
+ this.LS.setItem(this.keys.weightSpecs, weightSpecs);
+ this.LS.setItem(this.keys.weightData, arrayBufferToBase64String(modelArtifacts.weightData));
+ this.LS.setItem(this.keys.modelMetadata, JSON.stringify({
+ format: modelArtifacts.format,
+ generatedBy: modelArtifacts.generatedBy,
+ convertedBy: modelArtifacts.convertedBy,
+ userDefinedMetadata: modelArtifacts.userDefinedMetadata
+ }));
+ return [2, {modelArtifactsInfo}];
+ } catch (err) {
+ this.LS.removeItem(this.keys.info);
+ this.LS.removeItem(this.keys.topology);
+ this.LS.removeItem(this.keys.weightSpecs);
+ this.LS.removeItem(this.keys.weightData);
+ this.LS.removeItem(this.keys.modelMetadata);
+ throw new Error("Failed to save model '" + this.modelPath + "' to local storage: size quota being exceeded is a possible cause of this failure: " + ("modelTopologyBytes=" + modelArtifactsInfo.modelTopologyBytes + ", ") + ("weightSpecsBytes=" + modelArtifactsInfo.weightSpecsBytes + ", ") + ("weightDataBytes=" + modelArtifactsInfo.weightDataBytes + "."));
+ }
}
return [2];
});
});
- }, n.prototype.load = function() {
- return pe(this, void 0, void 0, function() {
- var t, e, r, i, a, s, o;
- return fe(this, function(c) {
- if (t = JSON.parse(this.LS.getItem(this.keys.info)), t == null)
+ };
+ BrowserLocalStorage2.prototype.load = function() {
+ return __awaiter(this, void 0, void 0, function() {
+ var info, out, topology, weightSpecs, metadataString, metadata, weightDataBase64;
+ return __generator(this, function(_a) {
+ info = JSON.parse(this.LS.getItem(this.keys.info));
+ if (info == null) {
throw new Error("In local storage, there is no model with name '" + this.modelPath + "'");
- if (t.modelTopologyType !== "JSON")
+ }
+ if (info.modelTopologyType !== "JSON") {
throw new Error("BrowserLocalStorage does not support loading non-JSON model topology yet.");
- if (e = {}, r = JSON.parse(this.LS.getItem(this.keys.topology)), r == null)
+ }
+ out = {};
+ topology = JSON.parse(this.LS.getItem(this.keys.topology));
+ if (topology == null) {
throw new Error("In local storage, the topology of model '" + this.modelPath + "' is missing.");
- if (e.modelTopology = r, i = JSON.parse(this.LS.getItem(this.keys.weightSpecs)), i == null)
+ }
+ out.modelTopology = topology;
+ weightSpecs = JSON.parse(this.LS.getItem(this.keys.weightSpecs));
+ if (weightSpecs == null) {
throw new Error("In local storage, the weight specs of model '" + this.modelPath + "' are missing.");
- if (e.weightSpecs = i, a = this.LS.getItem(this.keys.modelMetadata), a != null && (s = JSON.parse(a), e.format = s.format, e.generatedBy = s.generatedBy, e.convertedBy = s.convertedBy, e.userDefinedMetadata = s.userDefinedMetadata), o = this.LS.getItem(this.keys.weightData), o == null)
+ }
+ out.weightSpecs = weightSpecs;
+ metadataString = this.LS.getItem(this.keys.modelMetadata);
+ if (metadataString != null) {
+ metadata = JSON.parse(metadataString);
+ out.format = metadata["format"];
+ out.generatedBy = metadata["generatedBy"];
+ out.convertedBy = metadata["convertedBy"];
+ out.userDefinedMetadata = metadata["userDefinedMetadata"];
+ }
+ weightDataBase64 = this.LS.getItem(this.keys.weightData);
+ if (weightDataBase64 == null) {
throw new Error("In local storage, the binary weight values of model " + ("'" + this.modelPath + "' are missing."));
- return e.weightData = hS(o), [2, e];
+ }
+ out.weightData = base64StringToArrayBuffer(weightDataBase64);
+ return [2, out];
});
});
- }, n.URL_SCHEME = "localstorage://", n;
- }(), Mm = function(n) {
- return Ge().getBool("IS_BROWSER") && (!Array.isArray(n) && n.startsWith(Ni.URL_SCHEME)) ? _S(n.slice(Ni.URL_SCHEME.length)) : null;
- };
- rn.registerSaveRouter(Mm);
- rn.registerLoadRouter(Mm);
- function _S(n) {
- return new Ni(n);
- }
- var CS = function() {
- function n() {
- E(Ge().getBool("IS_BROWSER"), function() {
- return "Current environment is not a web browser";
- }), E(typeof window == "undefined" || typeof window.localStorage != "undefined", function() {
- return "Current browser does not appear to support localStorage";
- }), this.LS = window.localStorage;
- }
- return n.prototype.listModels = function() {
- return pe(this, void 0, void 0, function() {
- var t, e, r, i, a, s;
- return fe(this, function(o) {
- for (t = {}, e = Ti + Kn, r = Kn + zm, i = 0; i < this.LS.length; ++i)
- a = this.LS.key(i), a.startsWith(e) && a.endsWith(r) && (s = TS(a), t[s] = JSON.parse(this.LS.getItem(a)));
- return [2, t];
- });
- });
- }, n.prototype.removeModel = function(t) {
- return pe(this, void 0, void 0, function() {
- var e, r;
- return fe(this, function(i) {
- if (t = NS(t), e = Pm(t), this.LS.getItem(e.info) == null)
- throw new Error("Cannot find model at path '" + t + "'");
- return r = JSON.parse(this.LS.getItem(e.info)), this.LS.removeItem(e.info), this.LS.removeItem(e.topology), this.LS.removeItem(e.weightSpecs), this.LS.removeItem(e.weightData), [2, r];
- });
- });
- }, n;
+ };
+ BrowserLocalStorage2.URL_SCHEME = "localstorage://";
+ return BrowserLocalStorage2;
}();
- var _i = "://", pr = function() {
- function n() {
+ var localStorageRouter = function(url) {
+ if (!env().getBool("IS_BROWSER")) {
+ return null;
+ } else {
+ if (!Array.isArray(url) && url.startsWith(BrowserLocalStorage.URL_SCHEME)) {
+ return browserLocalStorage(url.slice(BrowserLocalStorage.URL_SCHEME.length));
+ } else {
+ return null;
+ }
+ }
+ };
+ IORouterRegistry.registerSaveRouter(localStorageRouter);
+ IORouterRegistry.registerLoadRouter(localStorageRouter);
+ function browserLocalStorage(modelPath) {
+ return new BrowserLocalStorage(modelPath);
+ }
+ var BrowserLocalStorageManager = function() {
+ function BrowserLocalStorageManager2() {
+ assert(env().getBool("IS_BROWSER"), function() {
+ return "Current environment is not a web browser";
+ });
+ assert(typeof window === "undefined" || typeof window.localStorage !== "undefined", function() {
+ return "Current browser does not appear to support localStorage";
+ });
+ this.LS = window.localStorage;
+ }
+ BrowserLocalStorageManager2.prototype.listModels = function() {
+ return __awaiter(this, void 0, void 0, function() {
+ var out, prefix, suffix, i, key, modelPath;
+ return __generator(this, function(_a) {
+ out = {};
+ prefix = PATH_PREFIX + PATH_SEPARATOR;
+ suffix = PATH_SEPARATOR + INFO_SUFFIX;
+ for (i = 0; i < this.LS.length; ++i) {
+ key = this.LS.key(i);
+ if (key.startsWith(prefix) && key.endsWith(suffix)) {
+ modelPath = getModelPathFromKey(key);
+ out[modelPath] = JSON.parse(this.LS.getItem(key));
+ }
+ }
+ return [2, out];
+ });
+ });
+ };
+ BrowserLocalStorageManager2.prototype.removeModel = function(path) {
+ return __awaiter(this, void 0, void 0, function() {
+ var keys, info;
+ return __generator(this, function(_a) {
+ path = maybeStripScheme$1(path);
+ keys = getModelKeys(path);
+ if (this.LS.getItem(keys.info) == null) {
+ throw new Error("Cannot find model at path '" + path + "'");
+ }
+ info = JSON.parse(this.LS.getItem(keys.info));
+ this.LS.removeItem(keys.info);
+ this.LS.removeItem(keys.topology);
+ this.LS.removeItem(keys.weightSpecs);
+ this.LS.removeItem(keys.weightData);
+ return [2, info];
+ });
+ });
+ };
+ return BrowserLocalStorageManager2;
+ }();
+ /**
+ * @license
+ * Copyright 2018 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ var URL_SCHEME_SUFFIX = "://";
+ var ModelStoreManagerRegistry = function() {
+ function ModelStoreManagerRegistry2() {
this.managers = {};
}
- return n.getInstance = function() {
- return n.instance == null && (n.instance = new n()), n.instance;
- }, n.registerManager = function(t, e) {
- E(t != null, function() {
+ ModelStoreManagerRegistry2.getInstance = function() {
+ if (ModelStoreManagerRegistry2.instance == null) {
+ ModelStoreManagerRegistry2.instance = new ModelStoreManagerRegistry2();
+ }
+ return ModelStoreManagerRegistry2.instance;
+ };
+ ModelStoreManagerRegistry2.registerManager = function(scheme, manager) {
+ assert(scheme != null, function() {
return "scheme must not be undefined or null.";
- }), t.endsWith(_i) && (t = t.slice(0, t.indexOf(_i))), E(t.length > 0, function() {
+ });
+ if (scheme.endsWith(URL_SCHEME_SUFFIX)) {
+ scheme = scheme.slice(0, scheme.indexOf(URL_SCHEME_SUFFIX));
+ }
+ assert(scheme.length > 0, function() {
return "scheme must not be an empty string.";
});
- var r = n.getInstance();
- E(r.managers[t] == null, function() {
- return "A model store manager is already registered for scheme '" + t + "'.";
- }), r.managers[t] = e;
- }, n.getManager = function(t) {
- var e = this.getInstance().managers[t];
- if (e == null)
- throw new Error("Cannot find model manager for scheme '" + t + "'");
- return e;
- }, n.getSchemes = function() {
+ var registry = ModelStoreManagerRegistry2.getInstance();
+ assert(registry.managers[scheme] == null, function() {
+ return "A model store manager is already registered for scheme '" + scheme + "'.";
+ });
+ registry.managers[scheme] = manager;
+ };
+ ModelStoreManagerRegistry2.getManager = function(scheme) {
+ var manager = this.getInstance().managers[scheme];
+ if (manager == null) {
+ throw new Error("Cannot find model manager for scheme '" + scheme + "'");
+ }
+ return manager;
+ };
+ ModelStoreManagerRegistry2.getSchemes = function() {
return Object.keys(this.getInstance().managers);
- }, n;
+ };
+ return ModelStoreManagerRegistry2;
}();
- function Us(n) {
- if (n.indexOf(_i) === -1)
- throw new Error("The url string provided does not contain a scheme. Supported schemes are: " + ("" + pr.getSchemes().join(",")));
- return {scheme: n.split(_i)[0], path: n.split(_i)[1]};
+ function parseURL(url) {
+ if (url.indexOf(URL_SCHEME_SUFFIX) === -1) {
+ throw new Error("The url string provided does not contain a scheme. Supported schemes are: " + ("" + ModelStoreManagerRegistry.getSchemes().join(",")));
+ }
+ return {
+ scheme: url.split(URL_SCHEME_SUFFIX)[0],
+ path: url.split(URL_SCHEME_SUFFIX)[1]
+ };
}
- function Hm(n, t, e) {
- return e === void 0 && (e = false), pe(this, void 0, void 0, function() {
- var r, i, a, s, o, c, l, u, h;
- return fe(this, function(d) {
- switch (d.label) {
+ function cloneModelInternal(sourceURL, destURL, deleteSource) {
+ if (deleteSource === void 0) {
+ deleteSource = false;
+ }
+ return __awaiter(this, void 0, void 0, function() {
+ var loadHandlers, loadHandler, saveHandlers, saveHandler, sourceScheme, sourcePath, sameMedium, modelArtifacts, saveResult;
+ return __generator(this, function(_a) {
+ switch (_a.label) {
case 0:
- return E(n !== t, function() {
- return "Old path and new path are the same: '" + n + "'";
- }), r = rn.getLoadHandlers(n), E(r.length > 0, function() {
- return "Copying failed because no load handler is found for source URL " + n + ".";
- }), E(r.length < 2, function() {
- return "Copying failed because more than one (" + r.length + ") " + ("load handlers for source URL " + n + ".");
- }), i = r[0], a = rn.getSaveHandlers(t), E(a.length > 0, function() {
- return "Copying failed because no save handler is found for destination " + ("URL " + t + ".");
- }), E(a.length < 2, function() {
- return "Copying failed because more than one (" + r.length + ") " + ("save handlers for destination URL " + t + ".");
- }), s = a[0], o = Us(n).scheme, c = Us(n).path, l = o === Us(n).scheme, [4, i.load()];
+ assert(sourceURL !== destURL, function() {
+ return "Old path and new path are the same: '" + sourceURL + "'";
+ });
+ loadHandlers = IORouterRegistry.getLoadHandlers(sourceURL);
+ assert(loadHandlers.length > 0, function() {
+ return "Copying failed because no load handler is found for source URL " + sourceURL + ".";
+ });
+ assert(loadHandlers.length < 2, function() {
+ return "Copying failed because more than one (" + loadHandlers.length + ") " + ("load handlers for source URL " + sourceURL + ".");
+ });
+ loadHandler = loadHandlers[0];
+ saveHandlers = IORouterRegistry.getSaveHandlers(destURL);
+ assert(saveHandlers.length > 0, function() {
+ return "Copying failed because no save handler is found for destination " + ("URL " + destURL + ".");
+ });
+ assert(saveHandlers.length < 2, function() {
+ return "Copying failed because more than one (" + loadHandlers.length + ") " + ("save handlers for destination URL " + destURL + ".");
+ });
+ saveHandler = saveHandlers[0];
+ sourceScheme = parseURL(sourceURL).scheme;
+ sourcePath = parseURL(sourceURL).path;
+ sameMedium = sourceScheme === parseURL(sourceURL).scheme;
+ return [4, loadHandler.load()];
case 1:
- return u = d.sent(), e && l ? [4, pr.getManager(o).removeModel(c)] : [3, 3];
+ modelArtifacts = _a.sent();
+ if (!(deleteSource && sameMedium))
+ return [3, 3];
+ return [4, ModelStoreManagerRegistry.getManager(sourceScheme).removeModel(sourcePath)];
case 2:
- d.sent(), d.label = 3;
+ _a.sent();
+ _a.label = 3;
case 3:
- return [4, s.save(u)];
+ return [4, saveHandler.save(modelArtifacts)];
case 4:
- return h = d.sent(), e && !l ? [4, pr.getManager(o).removeModel(c)] : [3, 6];
+ saveResult = _a.sent();
+ if (!(deleteSource && !sameMedium))
+ return [3, 6];
+ return [4, ModelStoreManagerRegistry.getManager(sourceScheme).removeModel(sourcePath)];
case 5:
- d.sent(), d.label = 6;
+ _a.sent();
+ _a.label = 6;
case 6:
- return [2, h.modelArtifactsInfo];
+ return [2, saveResult.modelArtifactsInfo];
}
});
});
}
- function RS() {
- return pe(this, void 0, void 0, function() {
- var n, t, e, r, i, a, s, o;
- return fe(this, function(c) {
- switch (c.label) {
+ function listModels() {
+ return __awaiter(this, void 0, void 0, function() {
+ var schemes, out, _i2, schemes_1, scheme, schemeOut, path, url;
+ return __generator(this, function(_a) {
+ switch (_a.label) {
case 0:
- n = pr.getSchemes(), t = {}, e = 0, r = n, c.label = 1;
+ schemes = ModelStoreManagerRegistry.getSchemes();
+ out = {};
+ _i2 = 0, schemes_1 = schemes;
+ _a.label = 1;
case 1:
- return e < r.length ? (i = r[e], [4, pr.getManager(i).listModels()]) : [3, 4];
+ if (!(_i2 < schemes_1.length))
+ return [3, 4];
+ scheme = schemes_1[_i2];
+ return [4, ModelStoreManagerRegistry.getManager(scheme).listModels()];
case 2:
- a = c.sent();
- for (s in a)
- o = i + _i + s, t[o] = a[s];
- c.label = 3;
+ schemeOut = _a.sent();
+ for (path in schemeOut) {
+ url = scheme + URL_SCHEME_SUFFIX + path;
+ out[url] = schemeOut[path];
+ }
+ _a.label = 3;
case 3:
- return e++, [3, 1];
+ _i2++;
+ return [3, 1];
case 4:
- return [2, t];
+ return [2, out];
}
});
});
}
- function OS(n) {
- return pe(this, void 0, void 0, function() {
- var t, e;
- return fe(this, function(r) {
- return t = Us(n), e = pr.getManager(t.scheme), [2, e.removeModel(t.path)];
+ function removeModel(url) {
+ return __awaiter(this, void 0, void 0, function() {
+ var schemeAndPath, manager;
+ return __generator(this, function(_a) {
+ schemeAndPath = parseURL(url);
+ manager = ModelStoreManagerRegistry.getManager(schemeAndPath.scheme);
+ return [2, manager.removeModel(schemeAndPath.path)];
});
});
}
- function ES(n, t) {
- return pe(this, void 0, void 0, function() {
- var e;
- return fe(this, function(r) {
- return e = false, [2, Hm(n, t, e)];
+ function copyModel(sourceURL, destURL) {
+ return __awaiter(this, void 0, void 0, function() {
+ var deleteSource;
+ return __generator(this, function(_a) {
+ deleteSource = false;
+ return [2, cloneModelInternal(sourceURL, destURL, deleteSource)];
});
});
}
- function DS(n, t) {
- return pe(this, void 0, void 0, function() {
- var e;
- return fe(this, function(r) {
- return e = true, [2, Hm(n, t, e)];
+ function moveModel(sourceURL, destURL) {
+ return __awaiter(this, void 0, void 0, function() {
+ var deleteSource;
+ return __generator(this, function(_a) {
+ deleteSource = true;
+ return [2, cloneModelInternal(sourceURL, destURL, deleteSource)];
});
});
}
- var kS = function() {
- function n() {
+ /**
+ * @license
+ * Copyright 2019 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ var PlatformBrowser = function() {
+ function PlatformBrowser2() {
}
- return n.prototype.fetch = function(t, e) {
- return fetch(t, e);
- }, n.prototype.now = function() {
+ PlatformBrowser2.prototype.fetch = function(path, init) {
+ return fetch(path, init);
+ };
+ PlatformBrowser2.prototype.now = function() {
return performance.now();
- }, n.prototype.encode = function(t, e) {
- if (e !== "utf-8" && e !== "utf8")
- throw new Error("Browser's encoder only supports utf-8, but got " + e);
- return this.textEncoder == null && (this.textEncoder = new TextEncoder()), this.textEncoder.encode(t);
- }, n.prototype.decode = function(t, e) {
- return new TextDecoder(e).decode(t);
- }, n;
+ };
+ PlatformBrowser2.prototype.encode = function(text, encoding) {
+ if (encoding !== "utf-8" && encoding !== "utf8") {
+ throw new Error("Browser's encoder only supports utf-8, but got " + encoding);
+ }
+ if (this.textEncoder == null) {
+ this.textEncoder = new TextEncoder();
+ }
+ return this.textEncoder.encode(text);
+ };
+ PlatformBrowser2.prototype.decode = function(bytes, encoding) {
+ return new TextDecoder(encoding).decode(bytes);
+ };
+ return PlatformBrowser2;
}();
- if (Ge().get("IS_BROWSER")) {
- Ge().setPlatform("browser", new kS());
+ if (env().get("IS_BROWSER")) {
+ env().setPlatform("browser", new PlatformBrowser());
try {
- pr.registerManager(Ni.URL_SCHEME, new CS());
- } catch (n) {
+ ModelStoreManagerRegistry.registerManager(BrowserLocalStorage.URL_SCHEME, new BrowserLocalStorageManager());
+ } catch (err) {
}
try {
- pr.registerManager(Ai.URL_SCHEME, new xS());
- } catch (n) {
+ ModelStoreManagerRegistry.registerManager(BrowserIndexedDB.URL_SCHEME, new BrowserIndexedDBManager());
+ } catch (err) {
}
}
- var FS = {importFetch: function() {
- return sf();
- }}, Nu, WS = function() {
- function n() {
- this.util = of(), this.textEncoder = new this.util.TextEncoder();
+ /**
+ * @license
+ * Copyright 2019 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ var getNodeFetch = {
+ importFetch: function() {
+ return require_browser();
}
- return n.prototype.fetch = function(t, e) {
- return Ge().global.fetch != null ? Ge().global.fetch(t, e) : (Nu == null && (Nu = FS.importFetch()), Nu(t, e));
- }, n.prototype.now = function() {
- var t = process.hrtime();
- return t[0] * 1e3 + t[1] / 1e6;
- }, n.prototype.encode = function(t, e) {
- if (e !== "utf-8" && e !== "utf8")
- throw new Error("Node built-in encoder only supports utf-8, but got " + e);
- return this.textEncoder.encode(t);
- }, n.prototype.decode = function(t, e) {
- return t.length === 0 ? "" : new this.util.TextDecoder(e).decode(t);
- }, n;
+ };
+ var systemFetch;
+ var PlatformNode = function() {
+ function PlatformNode2() {
+ this.util = require_util();
+ this.textEncoder = new this.util.TextEncoder();
+ }
+ PlatformNode2.prototype.fetch = function(path, requestInits) {
+ if (env().global.fetch != null) {
+ return env().global.fetch(path, requestInits);
+ }
+ if (systemFetch == null) {
+ systemFetch = getNodeFetch.importFetch();
+ }
+ return systemFetch(path, requestInits);
+ };
+ PlatformNode2.prototype.now = function() {
+ var time2 = process.hrtime();
+ return time2[0] * 1e3 + time2[1] / 1e6;
+ };
+ PlatformNode2.prototype.encode = function(text, encoding) {
+ if (encoding !== "utf-8" && encoding !== "utf8") {
+ throw new Error("Node built-in encoder only supports utf-8, but got " + encoding);
+ }
+ return this.textEncoder.encode(text);
+ };
+ PlatformNode2.prototype.decode = function(bytes, encoding) {
+ if (bytes.length === 0) {
+ return "";
+ }
+ return new this.util.TextDecoder(encoding).decode(bytes);
+ };
+ return PlatformNode2;
}();
- Ge().get("IS_NODE") && Ge().setPlatform("node", new WS());
- function On(n, t, e) {
- return t === void 0 && (t = "float32"), t = t || "float32", vc(n), new ks(n, t, e);
+ if (env().get("IS_NODE")) {
+ env().setPlatform("node", new PlatformNode());
}
- function US(n, t) {
- var e = R(n, "x", "cast");
- if (!ff(t))
- throw new Error("Failed to cast to unknown dtype " + t);
- if (t === "string" && e.dtype !== "string" || t !== "string" && e.dtype === "string")
+ /**
+ * @license
+ * Copyright 2020 Google Inc. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ function buffer(shape, dtype, values) {
+ if (dtype === void 0) {
+ dtype = "float32";
+ }
+ dtype = dtype || "float32";
+ assertNonNegativeIntegerDimensions(shape);
+ return new TensorBuffer(shape, dtype, values);
+ }
+ /**
+ * @license
+ * Copyright 2020 Google Inc. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ function cast_(x, dtype) {
+ var $x = convertToTensor(x, "x", "cast");
+ if (!isValidDtype(dtype)) {
+ throw new Error("Failed to cast to unknown dtype " + dtype);
+ }
+ if (dtype === "string" && $x.dtype !== "string" || dtype !== "string" && $x.dtype === "string") {
throw new Error("Only strings can be casted to strings");
- var r = {x: e}, i = {dtype: t};
- return z.runKernelFunc(function(a) {
- return a.cast(e, t);
- }, r, null, Rs, i);
- }
- var he = U({cast_: US});
- function BS(n) {
- var t = R(n, "x", "clone", null), e = function() {
- return z.makeTensorFromDataId(t.dataId, t.shape, t.dtype);
- }, r = {x: t};
- return z.runKernelFunc(e, r, null, il);
- }
- var Pr = U({clone_: BS});
- function Vm(n, t) {
- t === void 0 && (t = false), console.log(n.toString(t));
- }
- Cm();
- var zS = {buffer: On, cast: he, clone: Pr, print: Vm};
- jL(zS);
- var PS = "model", MS = ".json", HS = ".weights.bin";
- function Gm(n) {
- return new Promise(function(t) {
- return setTimeout(t);
- }).then(n);
- }
- var _u = function() {
- function n(t) {
- if (!Ge().getBool("IS_BROWSER"))
- throw new Error("browserDownloads() cannot proceed because the current environment is not a browser.");
- t.startsWith(n.URL_SCHEME) && (t = t.slice(n.URL_SCHEME.length)), (t == null || t.length === 0) && (t = PS), this.modelTopologyFileName = t + MS, this.weightDataFileName = t + HS;
}
- return n.prototype.save = function(t) {
- return pe(this, void 0, void 0, function() {
- var e, r, i, a, s, o;
- return fe(this, function(c) {
- switch (c.label) {
+ var inputs = {x: $x};
+ var attrs = {dtype};
+ return ENGINE.runKernelFunc(function(backend2) {
+ return backend2.cast($x, dtype);
+ }, inputs, null, Cast, attrs);
+ }
+ var cast = op({cast_});
+ /**
+ * @license
+ * Copyright 2020 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ function clone_(x) {
+ var $x = convertToTensor(x, "x", "clone", null);
+ var forward = function() {
+ return ENGINE.makeTensorFromDataId($x.dataId, $x.shape, $x.dtype);
+ };
+ var inputs = {x: $x};
+ return ENGINE.runKernelFunc(forward, inputs, null, Identity);
+ }
+ var clone = op({clone_});
+ /**
+ * @license
+ * Copyright 2020 Google Inc. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ function print(x, verbose) {
+ if (verbose === void 0) {
+ verbose = false;
+ }
+ console.log(x.toString(verbose));
+ }
+ /**
+ * @license
+ * Copyright 2020 Google Inc. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ getOrMakeEngine();
+ var opHandler$1 = {
+ buffer,
+ cast,
+ clone,
+ print
+ };
+ setOpHandler(opHandler$1);
+ /**
+ * @license
+ * Copyright 2018 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ var DEFAULT_FILE_NAME_PREFIX = "model";
+ var DEFAULT_JSON_EXTENSION_NAME = ".json";
+ var DEFAULT_WEIGHT_DATA_EXTENSION_NAME = ".weights.bin";
+ function defer(f) {
+ return new Promise(function(resolve) {
+ return setTimeout(resolve);
+ }).then(f);
+ }
+ var BrowserDownloads = function() {
+ function BrowserDownloads2(fileNamePrefix) {
+ if (!env().getBool("IS_BROWSER")) {
+ throw new Error("browserDownloads() cannot proceed because the current environment is not a browser.");
+ }
+ if (fileNamePrefix.startsWith(BrowserDownloads2.URL_SCHEME)) {
+ fileNamePrefix = fileNamePrefix.slice(BrowserDownloads2.URL_SCHEME.length);
+ }
+ if (fileNamePrefix == null || fileNamePrefix.length === 0) {
+ fileNamePrefix = DEFAULT_FILE_NAME_PREFIX;
+ }
+ this.modelTopologyFileName = fileNamePrefix + DEFAULT_JSON_EXTENSION_NAME;
+ this.weightDataFileName = fileNamePrefix + DEFAULT_WEIGHT_DATA_EXTENSION_NAME;
+ }
+ BrowserDownloads2.prototype.save = function(modelArtifacts) {
+ return __awaiter(this, void 0, void 0, function() {
+ var weightsURL, weightsManifest, modelTopologyAndWeightManifest, modelTopologyAndWeightManifestURL, jsonAnchor_1, weightDataAnchor_1;
+ return __generator(this, function(_a) {
+ switch (_a.label) {
case 0:
- if (typeof document == "undefined")
+ if (typeof document === "undefined") {
throw new Error("Browser downloads are not supported in this environment since `document` is not present");
- if (e = window.URL.createObjectURL(new Blob([t.weightData], {type: "application/octet-stream"})), !(t.modelTopology instanceof ArrayBuffer))
+ }
+ weightsURL = window.URL.createObjectURL(new Blob([modelArtifacts.weightData], {type: "application/octet-stream"}));
+ if (!(modelArtifacts.modelTopology instanceof ArrayBuffer))
return [3, 1];
throw new Error("BrowserDownloads.save() does not support saving model topology in binary formats yet.");
case 1:
- return r = [{paths: ["./" + this.weightDataFileName], weights: t.weightSpecs}], i = {modelTopology: t.modelTopology, format: t.format, generatedBy: t.generatedBy, convertedBy: t.convertedBy, weightsManifest: r}, a = window.URL.createObjectURL(new Blob([JSON.stringify(i)], {type: "application/json"})), s = this.jsonAnchor == null ? document.createElement("a") : this.jsonAnchor, s.download = this.modelTopologyFileName, s.href = a, [4, Gm(function() {
- return s.dispatchEvent(new MouseEvent("click"));
+ weightsManifest = [{
+ paths: ["./" + this.weightDataFileName],
+ weights: modelArtifacts.weightSpecs
+ }];
+ modelTopologyAndWeightManifest = {
+ modelTopology: modelArtifacts.modelTopology,
+ format: modelArtifacts.format,
+ generatedBy: modelArtifacts.generatedBy,
+ convertedBy: modelArtifacts.convertedBy,
+ weightsManifest
+ };
+ modelTopologyAndWeightManifestURL = window.URL.createObjectURL(new Blob([JSON.stringify(modelTopologyAndWeightManifest)], {type: "application/json"}));
+ jsonAnchor_1 = this.jsonAnchor == null ? document.createElement("a") : this.jsonAnchor;
+ jsonAnchor_1.download = this.modelTopologyFileName;
+ jsonAnchor_1.href = modelTopologyAndWeightManifestURL;
+ return [4, defer(function() {
+ return jsonAnchor_1.dispatchEvent(new MouseEvent("click"));
})];
case 2:
- return c.sent(), t.weightData != null ? (o = this.weightDataAnchor == null ? document.createElement("a") : this.weightDataAnchor, o.download = this.weightDataFileName, o.href = e, [4, Gm(function() {
- return o.dispatchEvent(new MouseEvent("click"));
- })]) : [3, 4];
+ _a.sent();
+ if (!(modelArtifacts.weightData != null))
+ return [3, 4];
+ weightDataAnchor_1 = this.weightDataAnchor == null ? document.createElement("a") : this.weightDataAnchor;
+ weightDataAnchor_1.download = this.weightDataFileName;
+ weightDataAnchor_1.href = weightsURL;
+ return [4, defer(function() {
+ return weightDataAnchor_1.dispatchEvent(new MouseEvent("click"));
+ })];
case 3:
- c.sent(), c.label = 4;
+ _a.sent();
+ _a.label = 4;
case 4:
- return [2, {modelArtifactsInfo: Ia(t)}];
+ return [2, {modelArtifactsInfo: getModelArtifactsInfoForJSON(modelArtifacts)}];
}
});
});
- }, n.URL_SCHEME = "downloads://", n;
- }(), VS = function() {
- function n(t) {
- if (t == null || t.length < 1)
- throw new Error("When calling browserFiles, at least 1 file is required, " + ("but received " + t));
- this.files = t;
+ };
+ BrowserDownloads2.URL_SCHEME = "downloads://";
+ return BrowserDownloads2;
+ }();
+ var BrowserFiles = function() {
+ function BrowserFiles2(files) {
+ if (files == null || files.length < 1) {
+ throw new Error("When calling browserFiles, at least 1 file is required, " + ("but received " + files));
+ }
+ this.files = files;
}
- return n.prototype.load = function() {
- return pe(this, void 0, void 0, function() {
- var t, e, r = this;
- return fe(this, function(i) {
- return t = this.files[0], e = this.files.slice(1), [2, new Promise(function(a, s) {
- var o = new FileReader();
- o.onload = function(c) {
- var l = JSON.parse(c.target.result), u = l.modelTopology;
- if (u == null) {
- s(new Error("modelTopology field is missing from file " + t.name));
+ BrowserFiles2.prototype.load = function() {
+ return __awaiter(this, void 0, void 0, function() {
+ var jsonFile, weightFiles;
+ var _this = this;
+ return __generator(this, function(_a) {
+ jsonFile = this.files[0];
+ weightFiles = this.files.slice(1);
+ return [2, new Promise(function(resolve, reject) {
+ var jsonReader = new FileReader();
+ jsonReader.onload = function(event) {
+ var modelJSON = JSON.parse(event.target.result);
+ var modelTopology = modelJSON.modelTopology;
+ if (modelTopology == null) {
+ reject(new Error("modelTopology field is missing from file " + jsonFile.name));
return;
}
- e.length === 0 && a({modelTopology: u});
- var h = l.weightsManifest;
- if (h == null) {
- s(new Error("weightManifest field is missing from file " + t.name));
+ if (weightFiles.length === 0) {
+ resolve({modelTopology});
+ }
+ var weightsManifest = modelJSON.weightsManifest;
+ if (weightsManifest == null) {
+ reject(new Error("weightManifest field is missing from file " + jsonFile.name));
return;
}
- var d;
+ var pathToFile;
try {
- d = r.checkManifestAndWeightFiles(h, e);
- } catch (g) {
- s(g);
+ pathToFile = _this.checkManifestAndWeightFiles(weightsManifest, weightFiles);
+ } catch (err) {
+ reject(err);
return;
}
- var p = [], f = [], m = [];
- h.forEach(function(g) {
- g.paths.forEach(function(y) {
- f.push(y), m.push(null);
- }), p.push.apply(p, g.weights);
- }), h.forEach(function(g) {
- g.paths.forEach(function(y) {
- var w = new FileReader();
- w.onload = function(b) {
- var L = b.target.result, x = f.indexOf(y);
- m[x] = L, m.indexOf(null) === -1 && a({modelTopology: u, weightSpecs: p, weightData: Su(m), format: l.format, generatedBy: l.generatedBy, convertedBy: l.convertedBy, userDefinedMetadata: l.userDefinedMetadata});
- }, w.onerror = function(b) {
- return s("Failed to weights data from file of path '" + y + "'.");
- }, w.readAsArrayBuffer(d[y]);
+ var weightSpecs = [];
+ var paths = [];
+ var perFileBuffers = [];
+ weightsManifest.forEach(function(weightsGroup) {
+ weightsGroup.paths.forEach(function(path) {
+ paths.push(path);
+ perFileBuffers.push(null);
+ });
+ weightSpecs.push.apply(weightSpecs, weightsGroup.weights);
+ });
+ weightsManifest.forEach(function(weightsGroup) {
+ weightsGroup.paths.forEach(function(path) {
+ var weightFileReader = new FileReader();
+ weightFileReader.onload = function(event2) {
+ var weightData = event2.target.result;
+ var index = paths.indexOf(path);
+ perFileBuffers[index] = weightData;
+ if (perFileBuffers.indexOf(null) === -1) {
+ resolve({
+ modelTopology,
+ weightSpecs,
+ weightData: concatenateArrayBuffers(perFileBuffers),
+ format: modelJSON.format,
+ generatedBy: modelJSON.generatedBy,
+ convertedBy: modelJSON.convertedBy,
+ userDefinedMetadata: modelJSON.userDefinedMetadata
+ });
+ }
+ };
+ weightFileReader.onerror = function(error) {
+ return reject("Failed to weights data from file of path '" + path + "'.");
+ };
+ weightFileReader.readAsArrayBuffer(pathToFile[path]);
});
});
- }, o.onerror = function(c) {
- return s("Failed to read model topology and weights manifest JSON " + ("from file '" + t.name + "'. BrowserFiles supports loading ") + "Keras-style tf.Model artifacts only.");
- }, o.readAsText(t);
+ };
+ jsonReader.onerror = function(error) {
+ return reject("Failed to read model topology and weights manifest JSON " + ("from file '" + jsonFile.name + "'. BrowserFiles supports loading ") + "Keras-style tf.Model artifacts only.");
+ };
+ jsonReader.readAsText(jsonFile);
})];
});
});
- }, n.prototype.checkManifestAndWeightFiles = function(t, e) {
- for (var r = [], i = e.map(function(l) {
- return Wm(l.name);
- }), a = {}, s = 0, o = t; s < o.length; s++) {
- var c = o[s];
- c.paths.forEach(function(l) {
- var u = Wm(l);
- if (r.indexOf(u) !== -1)
- throw new Error("Duplicate file basename found in weights manifest: " + ("'" + u + "'"));
- if (r.push(u), i.indexOf(u) === -1)
- throw new Error("Weight file with basename '" + u + "' is not provided.");
- a[l] = e[i.indexOf(u)];
+ };
+ BrowserFiles2.prototype.checkManifestAndWeightFiles = function(manifest, files) {
+ var basenames = [];
+ var fileNames = files.map(function(file) {
+ return basename(file.name);
+ });
+ var pathToFile = {};
+ for (var _i2 = 0, manifest_1 = manifest; _i2 < manifest_1.length; _i2++) {
+ var group = manifest_1[_i2];
+ group.paths.forEach(function(path) {
+ var pathBasename = basename(path);
+ if (basenames.indexOf(pathBasename) !== -1) {
+ throw new Error("Duplicate file basename found in weights manifest: " + ("'" + pathBasename + "'"));
+ }
+ basenames.push(pathBasename);
+ if (fileNames.indexOf(pathBasename) === -1) {
+ throw new Error("Weight file with basename '" + pathBasename + "' is not provided.");
+ } else {
+ pathToFile[path] = files[fileNames.indexOf(pathBasename)];
+ }
});
}
- if (r.length !== e.length)
- throw new Error("Mismatch in the number of files in weights manifest " + ("(" + r.length + ") and the number of weight files provided ") + ("(" + e.length + ")."));
- return a;
- }, n;
- }(), qS = function(n) {
- return Ge().getBool("IS_BROWSER") && (!Array.isArray(n) && n.startsWith(_u.URL_SCHEME)) ? GS(n.slice(_u.URL_SCHEME.length)) : null;
- };
- rn.registerSaveRouter(qS);
- function GS(n) {
- return n === void 0 && (n = "model"), new _u(n);
- }
- function YS(n) {
- return new VS(n);
- }
- function qm(n, t, e, r) {
- s(n), e = e == null ? 0 : e, r = r == null ? 1 : r, o(e, r);
- var i = 0, a = function(c) {
- return c.then(function(l) {
- var u = e + ++i / n.length * (r - e);
- return t(u), l;
- }), c;
+ if (basenames.length !== files.length) {
+ throw new Error("Mismatch in the number of files in weights manifest " + ("(" + basenames.length + ") and the number of weight files provided ") + ("(" + files.length + ")."));
+ }
+ return pathToFile;
};
- function s(c) {
- E(c != null && Array.isArray(c) && c.length > 0, function() {
+ return BrowserFiles2;
+ }();
+ var browserDownloadsRouter = function(url) {
+ if (!env().getBool("IS_BROWSER")) {
+ return null;
+ } else {
+ if (!Array.isArray(url) && url.startsWith(BrowserDownloads.URL_SCHEME)) {
+ return browserDownloads(url.slice(BrowserDownloads.URL_SCHEME.length));
+ } else {
+ return null;
+ }
+ }
+ };
+ IORouterRegistry.registerSaveRouter(browserDownloadsRouter);
+ function browserDownloads(fileNamePrefix) {
+ if (fileNamePrefix === void 0) {
+ fileNamePrefix = "model";
+ }
+ return new BrowserDownloads(fileNamePrefix);
+ }
+ function browserFiles(files) {
+ return new BrowserFiles(files);
+ }
+ /**
+ * @license
+ * Copyright 2019 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ function monitorPromisesProgress(promises, onProgress, startFraction, endFraction) {
+ checkPromises(promises);
+ startFraction = startFraction == null ? 0 : startFraction;
+ endFraction = endFraction == null ? 1 : endFraction;
+ checkFraction(startFraction, endFraction);
+ var resolvedPromise = 0;
+ var registerMonitor = function(promise) {
+ promise.then(function(value) {
+ var fraction = startFraction + ++resolvedPromise / promises.length * (endFraction - startFraction);
+ onProgress(fraction);
+ return value;
+ });
+ return promise;
+ };
+ function checkPromises(promises2) {
+ assert(promises2 != null && Array.isArray(promises2) && promises2.length > 0, function() {
return "promises must be a none empty array";
});
}
- function o(c, l) {
- E(c >= 0 && c <= 1, function() {
- return "Progress fraction must be in range [0, 1], but " + ("got startFraction " + c);
- }), E(l >= 0 && l <= 1, function() {
- return "Progress fraction must be in range [0, 1], but " + ("got endFraction " + l);
- }), E(l >= c, function() {
- return "startFraction must be no more than endFraction, but " + ("got startFraction " + c + " and endFraction ") + ("" + l);
+ function checkFraction(startFraction2, endFraction2) {
+ assert(startFraction2 >= 0 && startFraction2 <= 1, function() {
+ return "Progress fraction must be in range [0, 1], but " + ("got startFraction " + startFraction2);
+ });
+ assert(endFraction2 >= 0 && endFraction2 <= 1, function() {
+ return "Progress fraction must be in range [0, 1], but " + ("got endFraction " + endFraction2);
+ });
+ assert(endFraction2 >= startFraction2, function() {
+ return "startFraction must be no more than endFraction, but " + ("got startFraction " + startFraction2 + " and endFraction ") + ("" + endFraction2);
});
}
- return Promise.all(n.map(a));
+ return Promise.all(promises.map(registerMonitor));
}
- function Ym(n, t) {
- return pe(this, void 0, void 0, function() {
- var e, r, i, a, s, o, c, l, u, h, d;
- return fe(this, function(p) {
- switch (p.label) {
+ /**
+ * @license
+ * Copyright 2018 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ function loadWeightsAsArrayBuffer(fetchURLs, loadOptions) {
+ return __awaiter(this, void 0, void 0, function() {
+ var fetchFunc, requests, fetchStartFraction, fetchEndFraction, responses, _a, bufferPromises, bufferStartFraction, bufferEndFraction, buffers, _b;
+ return __generator(this, function(_c) {
+ switch (_c.label) {
case 0:
- return t == null && (t = {}), e = t.fetchFunc == null ? Ge().platform.fetch : t.fetchFunc, r = n.map(function(f) {
- return e(f, t.requestInit, {isBinary: true});
- }), i = 0, a = 0.5, t.onProgress == null ? [4, Promise.all(r)] : [3, 2];
+ if (loadOptions == null) {
+ loadOptions = {};
+ }
+ fetchFunc = loadOptions.fetchFunc == null ? env().platform.fetch : loadOptions.fetchFunc;
+ requests = fetchURLs.map(function(fetchURL) {
+ return fetchFunc(fetchURL, loadOptions.requestInit, {isBinary: true});
+ });
+ fetchStartFraction = 0;
+ fetchEndFraction = 0.5;
+ if (!(loadOptions.onProgress == null))
+ return [3, 2];
+ return [4, Promise.all(requests)];
case 1:
- return o = p.sent(), [3, 4];
+ _a = _c.sent();
+ return [3, 4];
case 2:
- return [4, qm(r, t.onProgress, i, a)];
+ return [4, monitorPromisesProgress(requests, loadOptions.onProgress, fetchStartFraction, fetchEndFraction)];
case 3:
- o = p.sent(), p.label = 4;
+ _a = _c.sent();
+ _c.label = 4;
case 4:
- return s = o, c = s.map(function(f) {
- return f.arrayBuffer();
- }), l = 0.5, u = 1, t.onProgress == null ? [4, Promise.all(c)] : [3, 6];
+ responses = _a;
+ bufferPromises = responses.map(function(response) {
+ return response.arrayBuffer();
+ });
+ bufferStartFraction = 0.5;
+ bufferEndFraction = 1;
+ if (!(loadOptions.onProgress == null))
+ return [3, 6];
+ return [4, Promise.all(bufferPromises)];
case 5:
- return d = p.sent(), [3, 8];
+ _b = _c.sent();
+ return [3, 8];
case 6:
- return [4, qm(c, t.onProgress, l, u)];
+ return [4, monitorPromisesProgress(bufferPromises, loadOptions.onProgress, bufferStartFraction, bufferEndFraction)];
case 7:
- d = p.sent(), p.label = 8;
+ _b = _c.sent();
+ _c.label = 8;
case 8:
- return h = d, [2, h];
+ buffers = _b;
+ return [2, buffers];
}
});
});
}
- function KS(n, t, e, r) {
- return t === void 0 && (t = ""), pe(this, void 0, void 0, function() {
- var i, a;
- return fe(this, function(s) {
- return i = function(o) {
- return Ym(o, {requestInit: r});
- }, a = Km(i), [2, a(n, t, e)];
+ function loadWeights(manifest, filePathPrefix, weightNames, requestInit) {
+ if (filePathPrefix === void 0) {
+ filePathPrefix = "";
+ }
+ return __awaiter(this, void 0, void 0, function() {
+ var fetchWeights, loadWeights2;
+ return __generator(this, function(_a) {
+ fetchWeights = function(fetchUrls) {
+ return loadWeightsAsArrayBuffer(fetchUrls, {requestInit});
+ };
+ loadWeights2 = weightsLoaderFactory(fetchWeights);
+ return [2, loadWeights2(manifest, filePathPrefix, weightNames)];
});
});
}
- function Km(n) {
- var t = this;
- return function(e, r, i) {
- return r === void 0 && (r = ""), pe(t, void 0, void 0, function() {
- var a, s, o, c, l, u, h, d, p, f;
- return fe(this, function(m) {
- switch (m.label) {
+ function weightsLoaderFactory(fetchWeightsFunction) {
+ var _this = this;
+ return function(manifest, filePathPrefix, weightNames) {
+ if (filePathPrefix === void 0) {
+ filePathPrefix = "";
+ }
+ return __awaiter(_this, void 0, void 0, function() {
+ var groupIndicesToFetchMap, groupWeightsToFetch, weightsFound, allManifestWeightNames, weightsNotFound, groupIndicesToFetch, fetchUrls, buffers, weightsTensorMap, bufferIndexOffset;
+ return __generator(this, function(_a) {
+ switch (_a.label) {
case 0:
- if (a = e.map(function() {
+ groupIndicesToFetchMap = manifest.map(function() {
return false;
- }), s = {}, o = i != null ? i.map(function() {
+ });
+ groupWeightsToFetch = {};
+ weightsFound = weightNames != null ? weightNames.map(function() {
return false;
- }) : [], c = [], e.forEach(function(g, y) {
- var w = 0;
- g.weights.forEach(function(b) {
- var L = "quantization" in b ? b.quantization.dtype : b.dtype, x = xu[L] * pt(b.shape), N = function() {
- a[y] = true, s[y] == null && (s[y] = []), s[y].push({manifestEntry: b, groupOffset: w, sizeBytes: x});
+ }) : [];
+ allManifestWeightNames = [];
+ manifest.forEach(function(manifestGroupConfig, groupIndex) {
+ var groupOffset = 0;
+ manifestGroupConfig.weights.forEach(function(weightsEntry) {
+ var rawDtype = "quantization" in weightsEntry ? weightsEntry.quantization.dtype : weightsEntry.dtype;
+ var weightsBytes = DTYPE_VALUE_SIZE_MAP[rawDtype] * sizeFromShape(weightsEntry.shape);
+ var enqueueWeightsForFetchingFn = function() {
+ groupIndicesToFetchMap[groupIndex] = true;
+ if (groupWeightsToFetch[groupIndex] == null) {
+ groupWeightsToFetch[groupIndex] = [];
+ }
+ groupWeightsToFetch[groupIndex].push({
+ manifestEntry: weightsEntry,
+ groupOffset,
+ sizeBytes: weightsBytes
+ });
};
- i != null ? i.forEach(function(I, C) {
- I === b.name && (N(), o[C] = true);
- }) : N(), c.push(b.name), w += x;
+ if (weightNames != null) {
+ weightNames.forEach(function(weightName, weightIndex) {
+ if (weightName === weightsEntry.name) {
+ enqueueWeightsForFetchingFn();
+ weightsFound[weightIndex] = true;
+ }
+ });
+ } else {
+ enqueueWeightsForFetchingFn();
+ }
+ allManifestWeightNames.push(weightsEntry.name);
+ groupOffset += weightsBytes;
});
- }), !o.every(function(g) {
- return g;
- }))
- throw l = i.filter(function(g, y) {
- return !o[y];
- }), new Error("Could not find weights in manifest with names: " + (l.join(", ") + `.
-`) + "Manifest JSON has weights with names: " + (c.join(", ") + "."));
- return u = a.reduce(function(g, y, w) {
- return y && g.push(w), g;
- }, []), h = [], u.forEach(function(g) {
- e[g].paths.forEach(function(y) {
- var w = r + (r.endsWith("/") ? "" : "/") + y;
- h.push(w);
+ });
+ if (!weightsFound.every(function(found) {
+ return found;
+ })) {
+ weightsNotFound = weightNames.filter(function(_, i) {
+ return !weightsFound[i];
});
- }), [4, n(h)];
- case 1:
- return d = m.sent(), p = {}, f = 0, u.forEach(function(g) {
- for (var y = e[g].paths.length, w = 0, b = 0; b < y; b++)
- w += d[f + b].byteLength;
- for (var L = new ArrayBuffer(w), x = new Uint8Array(L), N = 0, I = 0; I < y; I++) {
- var C = new Uint8Array(d[f + I]);
- x.set(C, N), N += C.byteLength;
+ throw new Error("Could not find weights in manifest with names: " + (weightsNotFound.join(", ") + ". \n") + "Manifest JSON has weights with names: " + (allManifestWeightNames.join(", ") + "."));
+ }
+ groupIndicesToFetch = groupIndicesToFetchMap.reduce(function(accumulator, shouldFetch, i) {
+ if (shouldFetch) {
+ accumulator.push(i);
}
- var O = s[g];
- O.forEach(function(D) {
- var F = L.slice(D.groupOffset, D.groupOffset + D.sizeBytes), k = km(F, [D.manifestEntry]);
- for (var B in k)
- p[B] = k[B];
- }), f += y;
- }), [2, p];
+ return accumulator;
+ }, []);
+ fetchUrls = [];
+ groupIndicesToFetch.forEach(function(i) {
+ manifest[i].paths.forEach(function(filepath) {
+ var fetchUrl = filePathPrefix + (!filePathPrefix.endsWith("/") ? "/" : "") + filepath;
+ fetchUrls.push(fetchUrl);
+ });
+ });
+ return [4, fetchWeightsFunction(fetchUrls)];
+ case 1:
+ buffers = _a.sent();
+ weightsTensorMap = {};
+ bufferIndexOffset = 0;
+ groupIndicesToFetch.forEach(function(i) {
+ var numBuffers = manifest[i].paths.length;
+ var groupBytes = 0;
+ for (var i_1 = 0; i_1 < numBuffers; i_1++) {
+ groupBytes += buffers[bufferIndexOffset + i_1].byteLength;
+ }
+ var groupBuffer = new ArrayBuffer(groupBytes);
+ var groupByteBuffer = new Uint8Array(groupBuffer);
+ var groupBufferOffset = 0;
+ for (var i_2 = 0; i_2 < numBuffers; i_2++) {
+ var buffer2 = new Uint8Array(buffers[bufferIndexOffset + i_2]);
+ groupByteBuffer.set(buffer2, groupBufferOffset);
+ groupBufferOffset += buffer2.byteLength;
+ }
+ var weightsEntries = groupWeightsToFetch[i];
+ weightsEntries.forEach(function(weightsEntry) {
+ var byteBuffer = groupBuffer.slice(weightsEntry.groupOffset, weightsEntry.groupOffset + weightsEntry.sizeBytes);
+ var nameToTensorMap = decodeWeights(byteBuffer, [weightsEntry.manifestEntry]);
+ for (var name_1 in nameToTensorMap) {
+ weightsTensorMap[name_1] = nameToTensorMap[name_1];
+ }
+ });
+ bufferIndexOffset += numBuffers;
+ });
+ return [2, weightsTensorMap];
}
});
});
};
}
- var jS = "application/octet-stream", $S = "application/json", jm = function() {
- function n(t, e) {
- if (this.DEFAULT_METHOD = "POST", e == null && (e = {}), this.weightPathPrefix = e.weightPathPrefix, this.onProgress = e.onProgress, this.weightUrlConverter = e.weightUrlConverter, e.fetchFunc != null ? (E(typeof e.fetchFunc == "function", function() {
- return "Must pass a function that matches the signature of `fetch` (see https://developer.mozilla.org/en-US/docs/Web/API/Fetch_API)";
- }), this.fetch = e.fetchFunc) : this.fetch = Ge().platform.fetch, E(t != null && t.length > 0, function() {
+ /**
+ * @license
+ * Copyright 2018 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ var OCTET_STREAM_MIME_TYPE = "application/octet-stream";
+ var JSON_TYPE = "application/json";
+ var HTTPRequest = function() {
+ function HTTPRequest2(path, loadOptions) {
+ this.DEFAULT_METHOD = "POST";
+ if (loadOptions == null) {
+ loadOptions = {};
+ }
+ this.weightPathPrefix = loadOptions.weightPathPrefix;
+ this.onProgress = loadOptions.onProgress;
+ this.weightUrlConverter = loadOptions.weightUrlConverter;
+ if (loadOptions.fetchFunc != null) {
+ assert(typeof loadOptions.fetchFunc === "function", function() {
+ return "Must pass a function that matches the signature of `fetch` (see https://developer.mozilla.org/en-US/docs/Web/API/Fetch_API)";
+ });
+ this.fetch = loadOptions.fetchFunc;
+ } else {
+ this.fetch = env().platform.fetch;
+ }
+ assert(path != null && path.length > 0, function() {
return "URL path for http must not be null, undefined or empty.";
- }), Array.isArray(t) && E(t.length === 2, function() {
- return "URL paths for http must have a length of 2, " + ("(actual length is " + t.length + ").");
- }), this.path = t, e.requestInit != null && e.requestInit.body != null)
+ });
+ if (Array.isArray(path)) {
+ assert(path.length === 2, function() {
+ return "URL paths for http must have a length of 2, " + ("(actual length is " + path.length + ").");
+ });
+ }
+ this.path = path;
+ if (loadOptions.requestInit != null && loadOptions.requestInit.body != null) {
throw new Error("requestInit is expected to have no pre-existing body, but has one.");
- this.requestInit = e.requestInit || {};
+ }
+ this.requestInit = loadOptions.requestInit || {};
}
- return n.prototype.save = function(t) {
- return pe(this, void 0, void 0, function() {
- var e, r, i, a;
- return fe(this, function(s) {
- switch (s.label) {
+ HTTPRequest2.prototype.save = function(modelArtifacts) {
+ return __awaiter(this, void 0, void 0, function() {
+ var init, weightsManifest, modelTopologyAndWeightManifest, response;
+ return __generator(this, function(_a) {
+ switch (_a.label) {
case 0:
- if (t.modelTopology instanceof ArrayBuffer)
+ if (modelArtifacts.modelTopology instanceof ArrayBuffer) {
throw new Error("BrowserHTTPRequest.save() does not support saving model topology in binary formats yet.");
- return e = Object.assign({method: this.DEFAULT_METHOD}, this.requestInit), e.body = new FormData(), r = [{paths: ["./model.weights.bin"], weights: t.weightSpecs}], i = {modelTopology: t.modelTopology, format: t.format, generatedBy: t.generatedBy, convertedBy: t.convertedBy, userDefinedMetadata: t.userDefinedMetadata, weightsManifest: r}, e.body.append("model.json", new Blob([JSON.stringify(i)], {type: $S}), "model.json"), t.weightData != null && e.body.append("model.weights.bin", new Blob([t.weightData], {type: jS}), "model.weights.bin"), [4, this.fetch(this.path, e)];
+ }
+ init = Object.assign({method: this.DEFAULT_METHOD}, this.requestInit);
+ init.body = new FormData();
+ weightsManifest = [{
+ paths: ["./model.weights.bin"],
+ weights: modelArtifacts.weightSpecs
+ }];
+ modelTopologyAndWeightManifest = {
+ modelTopology: modelArtifacts.modelTopology,
+ format: modelArtifacts.format,
+ generatedBy: modelArtifacts.generatedBy,
+ convertedBy: modelArtifacts.convertedBy,
+ userDefinedMetadata: modelArtifacts.userDefinedMetadata,
+ weightsManifest
+ };
+ init.body.append("model.json", new Blob([JSON.stringify(modelTopologyAndWeightManifest)], {type: JSON_TYPE}), "model.json");
+ if (modelArtifacts.weightData != null) {
+ init.body.append("model.weights.bin", new Blob([modelArtifacts.weightData], {type: OCTET_STREAM_MIME_TYPE}), "model.weights.bin");
+ }
+ return [4, this.fetch(this.path, init)];
case 1:
- if (a = s.sent(), a.ok)
- return [2, {modelArtifactsInfo: Ia(t), responses: [a]}];
- throw new Error("BrowserHTTPRequest.save() failed due to HTTP response status " + (a.status + "."));
+ response = _a.sent();
+ if (response.ok) {
+ return [2, {
+ modelArtifactsInfo: getModelArtifactsInfoForJSON(modelArtifacts),
+ responses: [response]
+ }];
+ } else {
+ throw new Error("BrowserHTTPRequest.save() failed due to HTTP response status " + (response.status + "."));
+ }
}
});
});
- }, n.prototype.load = function() {
- return pe(this, void 0, void 0, function() {
- var t, e, r, i, a, s, o, c, l, u, h, d, p, f, m;
- return fe(this, function(g) {
- switch (g.label) {
+ };
+ HTTPRequest2.prototype.load = function() {
+ return __awaiter(this, void 0, void 0, function() {
+ var modelConfigRequest, modelConfig, e_1, message, modelTopology, weightsManifest, generatedBy, convertedBy, format, userDefinedMetadata, weightSpecs, weightData, results, artifacts, initializer;
+ return __generator(this, function(_a) {
+ switch (_a.label) {
case 0:
return [4, this.fetch(this.path, this.requestInit)];
case 1:
- if (t = g.sent(), !t.ok)
- throw new Error("Request to " + this.path + " failed with status code " + (t.status + ". Please verify this URL points to ") + "the model JSON of the model to load.");
- g.label = 2;
+ modelConfigRequest = _a.sent();
+ if (!modelConfigRequest.ok) {
+ throw new Error("Request to " + this.path + " failed with status code " + (modelConfigRequest.status + ". Please verify this URL points to ") + "the model JSON of the model to load.");
+ }
+ _a.label = 2;
case 2:
- return g.trys.push([2, 4, , 5]), [4, t.json()];
+ _a.trys.push([2, 4, , 5]);
+ return [4, modelConfigRequest.json()];
case 3:
- return e = g.sent(), [3, 5];
+ modelConfig = _a.sent();
+ return [3, 5];
case 4:
- throw r = g.sent(), i = "Failed to parse model JSON of response from " + this.path + ".", this.path.endsWith(".pb") ? i += " Your path contains a .pb file extension. Support for .pb models have been removed in TensorFlow.js 1.0 in favor of .json models. You can re-convert your Python TensorFlow model using the TensorFlow.js 1.0 conversion scripts or you can convert your.pb models with the 'pb2json'NPM script in the tensorflow/tfjs-converter repository." : i += " Please make sure the server is serving valid JSON for this request.", new Error(i);
+ e_1 = _a.sent();
+ message = "Failed to parse model JSON of response from " + this.path + ".";
+ if (this.path.endsWith(".pb")) {
+ message += " Your path contains a .pb file extension. Support for .pb models have been removed in TensorFlow.js 1.0 in favor of .json models. You can re-convert your Python TensorFlow model using the TensorFlow.js 1.0 conversion scripts or you can convert your.pb models with the 'pb2json'NPM script in the tensorflow/tfjs-converter repository.";
+ } else {
+ message += " Please make sure the server is serving valid JSON for this request.";
+ }
+ throw new Error(message);
case 5:
- if (a = e.modelTopology, s = e.weightsManifest, o = e.generatedBy, c = e.convertedBy, l = e.format, u = e.userDefinedMetadata, a == null && s == null)
+ modelTopology = modelConfig.modelTopology;
+ weightsManifest = modelConfig.weightsManifest;
+ generatedBy = modelConfig.generatedBy;
+ convertedBy = modelConfig.convertedBy;
+ format = modelConfig.format;
+ userDefinedMetadata = modelConfig.userDefinedMetadata;
+ if (modelTopology == null && weightsManifest == null) {
throw new Error("The JSON from HTTP path " + this.path + " contains neither model topology or manifest for weights.");
- return s != null ? [4, this.loadWeights(s)] : [3, 7];
+ }
+ if (!(weightsManifest != null))
+ return [3, 7];
+ return [4, this.loadWeights(weightsManifest)];
case 6:
- p = g.sent(), h = p[0], d = p[1], g.label = 7;
+ results = _a.sent();
+ weightSpecs = results[0], weightData = results[1];
+ _a.label = 7;
case 7:
- return f = {modelTopology: a, weightSpecs: h, weightData: d, userDefinedMetadata: u, generatedBy: o, convertedBy: c, format: l}, m = e.modelInitializer, m && (f.modelInitializer = m), [2, f];
+ artifacts = {
+ modelTopology,
+ weightSpecs,
+ weightData,
+ userDefinedMetadata,
+ generatedBy,
+ convertedBy,
+ format
+ };
+ initializer = modelConfig.modelInitializer;
+ if (initializer) {
+ artifacts.modelInitializer = initializer;
+ }
+ return [2, artifacts];
}
});
});
- }, n.prototype.loadWeights = function(t) {
- return pe(this, void 0, void 0, function() {
- var e, r, i, a, s, o, c, l, u, h, d, p, f, m, g, y, w, b, L, x, N;
- return fe(this, function(I) {
- switch (I.label) {
+ };
+ HTTPRequest2.prototype.loadWeights = function(weightsManifest) {
+ return __awaiter(this, void 0, void 0, function() {
+ var weightPath, _a, prefix, suffix, pathPrefix, weightSpecs, _i2, weightsManifest_1, entry, fetchURLs, urlPromises, _b, weightsManifest_2, weightsGroup, _c, _d, path, _e, _f, _g, buffers;
+ return __generator(this, function(_h) {
+ switch (_h.label) {
case 0:
- for (e = Array.isArray(this.path) ? this.path[1] : this.path, r = XS(e), i = r[0], a = r[1], s = this.weightPathPrefix || i, o = [], c = 0, l = t; c < l.length; c++)
- u = l[c], o.push.apply(o, u.weights);
- for (h = [], d = [], p = 0, f = t; p < f.length; p++)
- for (m = f[p], g = 0, y = m.paths; g < y.length; g++)
- w = y[g], this.weightUrlConverter != null ? d.push(this.weightUrlConverter(w)) : h.push(s + w + a);
- return this.weightUrlConverter ? (L = (b = h.push).apply, x = [h], [4, Promise.all(d)]) : [3, 2];
+ weightPath = Array.isArray(this.path) ? this.path[1] : this.path;
+ _a = parseUrl(weightPath), prefix = _a[0], suffix = _a[1];
+ pathPrefix = this.weightPathPrefix || prefix;
+ weightSpecs = [];
+ for (_i2 = 0, weightsManifest_1 = weightsManifest; _i2 < weightsManifest_1.length; _i2++) {
+ entry = weightsManifest_1[_i2];
+ weightSpecs.push.apply(weightSpecs, entry.weights);
+ }
+ fetchURLs = [];
+ urlPromises = [];
+ for (_b = 0, weightsManifest_2 = weightsManifest; _b < weightsManifest_2.length; _b++) {
+ weightsGroup = weightsManifest_2[_b];
+ for (_c = 0, _d = weightsGroup.paths; _c < _d.length; _c++) {
+ path = _d[_c];
+ if (this.weightUrlConverter != null) {
+ urlPromises.push(this.weightUrlConverter(path));
+ } else {
+ fetchURLs.push(pathPrefix + path + suffix);
+ }
+ }
+ }
+ if (!this.weightUrlConverter)
+ return [3, 2];
+ _f = (_e = fetchURLs.push).apply;
+ _g = [fetchURLs];
+ return [4, Promise.all(urlPromises)];
case 1:
- L.apply(b, x.concat([I.sent()])), I.label = 2;
+ _f.apply(_e, _g.concat([_h.sent()]));
+ _h.label = 2;
case 2:
- return [4, Ym(h, {requestInit: this.requestInit, fetchFunc: this.fetch, onProgress: this.onProgress})];
+ return [4, loadWeightsAsArrayBuffer(fetchURLs, {
+ requestInit: this.requestInit,
+ fetchFunc: this.fetch,
+ onProgress: this.onProgress
+ })];
case 3:
- return N = I.sent(), [2, [o, Su(N)]];
+ buffers = _h.sent();
+ return [2, [weightSpecs, concatenateArrayBuffers(buffers)]];
}
});
});
- }, n.URL_SCHEME_REGEX = /^https?:\/\//, n;
+ };
+ HTTPRequest2.URL_SCHEME_REGEX = /^https?:\/\//;
+ return HTTPRequest2;
}();
- function XS(n) {
- var t = n.lastIndexOf("/"), e = n.lastIndexOf("?"), r = n.substring(0, t), i = e > t ? n.substring(e) : "";
- return [r + "/", i];
+ function parseUrl(url) {
+ var lastSlash = url.lastIndexOf("/");
+ var lastSearchParam = url.lastIndexOf("?");
+ var prefix = url.substring(0, lastSlash);
+ var suffix = lastSearchParam > lastSlash ? url.substring(lastSearchParam) : "";
+ return [prefix + "/", suffix];
}
- function Cu(n) {
- return n.match(jm.URL_SCHEME_REGEX) != null;
+ function isHTTPScheme(url) {
+ return url.match(HTTPRequest.URL_SCHEME_REGEX) != null;
}
- var $m = function(n, t) {
- if (typeof fetch == "undefined" && (t == null || t.fetchFunc == null))
+ var httpRouter = function(url, loadOptions) {
+ if (typeof fetch === "undefined" && (loadOptions == null || loadOptions.fetchFunc == null)) {
return null;
- var e = true;
- return Array.isArray(n) ? e = n.every(function(r) {
- return Cu(r);
- }) : e = Cu(n), e ? Ru(n, t) : null;
- };
- rn.registerSaveRouter($m);
- rn.registerLoadRouter($m);
- function Ru(n, t) {
- return new jm(n, t);
- }
- function JS(n, t) {
- return Ru(n, t);
- }
- var Ou = function() {
- function n(t) {
- this.modelArtifacts = t;
+ } else {
+ var isHTTP = true;
+ if (Array.isArray(url)) {
+ isHTTP = url.every(function(urlItem) {
+ return isHTTPScheme(urlItem);
+ });
+ } else {
+ isHTTP = isHTTPScheme(url);
+ }
+ if (isHTTP) {
+ return http(url, loadOptions);
+ }
}
- return n.prototype.load = function() {
- return pe(this, void 0, void 0, function() {
- return fe(this, function(t) {
+ return null;
+ };
+ IORouterRegistry.registerSaveRouter(httpRouter);
+ IORouterRegistry.registerLoadRouter(httpRouter);
+ function http(path, loadOptions) {
+ return new HTTPRequest(path, loadOptions);
+ }
+ function browserHTTPRequest(path, loadOptions) {
+ return http(path, loadOptions);
+ }
+ /**
+ * @license
+ * Copyright 2018 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ var PassthroughLoader = function() {
+ function PassthroughLoader2(modelArtifacts) {
+ this.modelArtifacts = modelArtifacts;
+ }
+ PassthroughLoader2.prototype.load = function() {
+ return __awaiter(this, void 0, void 0, function() {
+ return __generator(this, function(_a) {
return [2, this.modelArtifacts];
});
});
- }, n;
- }(), ZS = function() {
- function n(t) {
- this.saveHandler = t;
+ };
+ return PassthroughLoader2;
+ }();
+ var PassthroughSaver = function() {
+ function PassthroughSaver2(saveHandler) {
+ this.saveHandler = saveHandler;
}
- return n.prototype.save = function(t) {
- return pe(this, void 0, void 0, function() {
- return fe(this, function(e) {
- return [2, this.saveHandler(t)];
+ PassthroughSaver2.prototype.save = function(modelArtifacts) {
+ return __awaiter(this, void 0, void 0, function() {
+ return __generator(this, function(_a) {
+ return [2, this.saveHandler(modelArtifacts)];
});
});
- }, n;
- }();
- function QS(n, t, e, r) {
- if (arguments.length === 1) {
- var i = n.modelTopology != null || n.weightSpecs != null;
- return i ? new Ou(n) : (console.warn("Please call tf.io.fromMemory() with only one argument. The argument should be of type ModelArtifacts. The multi-argument signature of tf.io.fromMemory() has been deprecated and will be removed in a future release."), new Ou({modelTopology: n}));
- } else
- return console.warn("Please call tf.io.fromMemory() with only one argument. The argument should be of type ModelArtifacts. The multi-argument signature of tf.io.fromMemory() has been deprecated and will be removed in a future release."), new Ou({modelTopology: n, weightSpecs: t, weightData: e, trainingConfig: r});
- }
- function eI(n) {
- return new ZS(n);
- }
- var tI = {__proto__: null, browserFiles: YS, browserHTTPRequest: JS, concatenateArrayBuffers: Su, decodeWeights: km, encodeWeights: cS, fromMemory: QS, getLoadHandlers: vS, getModelArtifactsInfoForJSON: Ia, getSaveHandlers: yS, http: Ru, isHTTPScheme: Cu, loadWeights: KS, registerLoadRouter: gS, registerSaveRouter: mS, weightsLoaderFactory: Km, withSaveHandler: eI, copyModel: ES, listModels: RS, moveModel: DS, removeModel: OS};
- function nI(n, t) {
- var e = R(n, "x", "reshape", null), r = {x: e}, i = {shape: t}, a = function(s, o) {
- return t = uf(t, e.size), E(e.size === pt(t), function() {
- return "new shape and old shape must have the same number of elements.";
- }), o([e]), s.reshape(e, t);
};
- return z.runKernelFunc(a, r, null, Cl, i);
- }
- var Y = U({reshape_: nI});
- function rI(n, t, e, r) {
- var i;
- e === void 0 && (e = false), r === void 0 && (r = false);
- var a = R(n, "a", "matMul"), s = R(t, "b", "matMul");
- i = ct(a, s), a = i[0], s = i[1];
- var o = function(u, h) {
- h([a, s]);
- var d = e ? a.shape[a.rank - 2] : a.shape[a.rank - 1], p = r ? s.shape[s.rank - 1] : s.shape[s.rank - 2], f = e ? a.shape[a.rank - 1] : a.shape[a.rank - 2], m = r ? s.shape[s.rank - 2] : s.shape[s.rank - 1], g = a.shape.slice(0, -2), y = s.shape.slice(0, -2), w = pt(g), b = pt(y), L = w === b || w === 1 || b === 1;
- E(a.rank >= 2 && s.rank >= 2 && L, function() {
- return "Error in matMul: the input batch dimensions must either be the same or at least one input batch dimension must be 1. Got input " + ("batch dimensions of (" + g + ") and (" + y + ").");
- }), E(d === p, function() {
- return "Error in matMul: inner shapes (" + d + ") and (" + (p + ") of Tensors with shapes " + a.shape + " and ") + (s.shape + " and transposeA=" + e) + (" and transposeB=" + r + " must match.");
+ return PassthroughSaver2;
+ }();
+ function fromMemory(modelArtifacts, weightSpecs, weightData, trainingConfig) {
+ if (arguments.length === 1) {
+ var isModelArtifacts = modelArtifacts.modelTopology != null || modelArtifacts.weightSpecs != null;
+ if (isModelArtifacts) {
+ return new PassthroughLoader(modelArtifacts);
+ } else {
+ console.warn("Please call tf.io.fromMemory() with only one argument. The argument should be of type ModelArtifacts. The multi-argument signature of tf.io.fromMemory() has been deprecated and will be removed in a future release.");
+ return new PassthroughLoader({modelTopology: modelArtifacts});
+ }
+ } else {
+ console.warn("Please call tf.io.fromMemory() with only one argument. The argument should be of type ModelArtifacts. The multi-argument signature of tf.io.fromMemory() has been deprecated and will be removed in a future release.");
+ return new PassthroughLoader({
+ modelTopology: modelArtifacts,
+ weightSpecs,
+ weightData,
+ trainingConfig
});
- var x = w > b ? g : y, N = x.concat([f, m]), I = e ? Y(a, [w, d, f]) : Y(a, [w, f, d]), C = r ? Y(s, [b, m, p]) : Y(s, [b, p, m]), O = u.batchMatMul(I, C, e, r);
- return Y(O, N);
- }, c = {a, b: s}, l = {transposeA: e, transposeB: r};
- return z.runKernelFunc(o, c, null, kc, l);
+ }
}
- var Ue = U({matMul_: rI});
- function iI(n, t, e, r) {
- if (e === void 0 && (e = 1), r === void 0 && (r = 0), t < 2)
- throw new Error("Error in oneHot: depth must be >=2, but it is " + t);
- var i = R(n, "indices", "oneHot", "int32"), a = i.shape.concat([t]), s = function(l, u) {
- return u([i]), Y(l.oneHot(Y(i, [i.size]), t, e, r), a);
- }, o = {indices: i}, c = {depth: t, onValue: e, offValue: r};
- return z.runKernelFunc(s, o, null, Sl, c);
+ function withSaveHandler(saveHandler) {
+ return new PassthroughSaver(saveHandler);
}
- var Bs = U({oneHot_: iI});
- function aI(n, t) {
- var e = R(n, "x", "transpose");
- if (t == null && (t = e.shape.map(function(a, s) {
- return s;
- }).reverse()), E(e.rank === t.length, function() {
- return "Error in transpose: rank of input " + e.rank + " " + ("must match length of perm " + t + ".");
- }), t.forEach(function(a) {
- E(a >= 0 && a < e.rank, function() {
- return "All entries in 'perm' must be between 0 and " + (e.rank - 1) + (" but got " + t);
+ /**
+ * @license
+ * Copyright 2018 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ var io = {
+ __proto__: null,
+ browserFiles,
+ browserHTTPRequest,
+ concatenateArrayBuffers,
+ decodeWeights,
+ encodeWeights,
+ fromMemory,
+ getLoadHandlers,
+ getModelArtifactsInfoForJSON,
+ getSaveHandlers,
+ http,
+ isHTTPScheme,
+ loadWeights,
+ registerLoadRouter,
+ registerSaveRouter,
+ weightsLoaderFactory,
+ withSaveHandler,
+ copyModel,
+ listModels,
+ moveModel,
+ removeModel
+ };
+ /**
+ * @license
+ * Copyright 2020 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ function reshape_(x, shape) {
+ var $x = convertToTensor(x, "x", "reshape", null);
+ var inputs = {x: $x};
+ var attrs = {shape};
+ var forward = function(backend2, save) {
+ shape = inferFromImplicitShape(shape, $x.size);
+ assert($x.size === sizeFromShape(shape), function() {
+ return "new shape and old shape must have the same number of elements.";
});
- }), e.rank <= 1)
- return e.clone();
- var r = {x: e}, i = {perm: t};
- return z.runKernelFunc(function(a) {
- return a.transpose(e, t);
- }, r, null, eu, i);
+ save([$x]);
+ return backend2.reshape($x, shape);
+ };
+ return ENGINE.runKernelFunc(forward, inputs, null, Reshape, attrs);
}
- var Tt = U({transpose_: aI});
- function sI(n, t, e) {
- var r = R(n, "labels", "confusionMatrix"), i = R(t, "predictions", "confusionMatrix");
- E(e == null || e > 0 && Number.isInteger(e), function() {
- return "If provided, numClasses must be a positive integer, " + ("but got " + e);
- }), E(r.rank === 1, function() {
- return "Expected the rank of labels to be 1, but got " + r.rank;
- }), E(i.rank === 1, function() {
- return "Expected the rank of predictions to be 1, " + ("but got " + i.rank);
- }), E(r.shape[0] === i.shape[0], function() {
- return "Mismatch in the number of examples: " + (r.shape[0] + " vs. " + i.shape[0] + ". ") + "Labels and predictions should have the same number of elements.";
- }), E(e > 0 && Number.isInteger(e), function() {
- return "numClasses is required to be a positive integer, but got " + ("" + e);
+ var reshape = op({reshape_});
+ /**
+ * @license
+ * Copyright 2020 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ function matMul_(a, b, transposeA, transposeB) {
+ var _a;
+ if (transposeA === void 0) {
+ transposeA = false;
+ }
+ if (transposeB === void 0) {
+ transposeB = false;
+ }
+ var $a = convertToTensor(a, "a", "matMul");
+ var $b = convertToTensor(b, "b", "matMul");
+ _a = makeTypesMatch($a, $b), $a = _a[0], $b = _a[1];
+ var forward = function(backend2, save) {
+ save([$a, $b]);
+ var innerShapeA = transposeA ? $a.shape[$a.rank - 2] : $a.shape[$a.rank - 1];
+ var innerShapeB = transposeB ? $b.shape[$b.rank - 1] : $b.shape[$b.rank - 2];
+ var outerShapeA = transposeA ? $a.shape[$a.rank - 1] : $a.shape[$a.rank - 2];
+ var outerShapeB = transposeB ? $b.shape[$b.rank - 2] : $b.shape[$b.rank - 1];
+ var outerDimsA = $a.shape.slice(0, -2);
+ var outerDimsB = $b.shape.slice(0, -2);
+ var batchDimA = sizeFromShape(outerDimsA);
+ var batchDimB = sizeFromShape(outerDimsB);
+ var batchDimsCompatible = batchDimA === batchDimB || batchDimA === 1 || batchDimB === 1;
+ assert($a.rank >= 2 && $b.rank >= 2 && batchDimsCompatible, function() {
+ return "Error in matMul: the input batch dimensions must either be the same or at least one input batch dimension must be 1. Got input " + ("batch dimensions of (" + outerDimsA + ") and (" + outerDimsB + ").");
+ });
+ assert(innerShapeA === innerShapeB, function() {
+ return "Error in matMul: inner shapes (" + innerShapeA + ") and (" + (innerShapeB + ") of Tensors with shapes " + $a.shape + " and ") + ($b.shape + " and transposeA=" + transposeA) + (" and transposeB=" + transposeB + " must match.");
+ });
+ var outShapeOuterDims = batchDimA > batchDimB ? outerDimsA : outerDimsB;
+ var outShape = outShapeOuterDims.concat([outerShapeA, outerShapeB]);
+ var a3D = transposeA ? reshape($a, [batchDimA, innerShapeA, outerShapeA]) : reshape($a, [batchDimA, outerShapeA, innerShapeA]);
+ var b3D = transposeB ? reshape($b, [batchDimB, outerShapeB, innerShapeB]) : reshape($b, [batchDimB, innerShapeB, outerShapeB]);
+ var res3d = backend2.batchMatMul(a3D, b3D, transposeA, transposeB);
+ return reshape(res3d, outShape);
+ };
+ var inputs = {a: $a, b: $b};
+ var attrs = {transposeA, transposeB};
+ return ENGINE.runKernelFunc(forward, inputs, null, BatchMatMul, attrs);
+ }
+ var matMul = op({matMul_});
+ /**
+ * @license
+ * Copyright 2020 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ function oneHot_(indices, depth, onValue, offValue) {
+ if (onValue === void 0) {
+ onValue = 1;
+ }
+ if (offValue === void 0) {
+ offValue = 0;
+ }
+ if (depth < 2) {
+ throw new Error("Error in oneHot: depth must be >=2, but it is " + depth);
+ }
+ var $indices = convertToTensor(indices, "indices", "oneHot", "int32");
+ var outShape = $indices.shape.concat([depth]);
+ var forward = function(backend2, save) {
+ save([$indices]);
+ return reshape(backend2.oneHot(reshape($indices, [$indices.size]), depth, onValue, offValue), outShape);
+ };
+ var inputs = {indices: $indices};
+ var attrs = {depth, onValue, offValue};
+ return ENGINE.runKernelFunc(forward, inputs, null, OneHot, attrs);
+ }
+ var oneHot = op({oneHot_});
+ /**
+ * @license
+ * Copyright 2018 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ function transpose_(x, perm) {
+ var $x = convertToTensor(x, "x", "transpose");
+ if (perm == null) {
+ perm = $x.shape.map(function(s, i) {
+ return i;
+ }).reverse();
+ }
+ assert($x.rank === perm.length, function() {
+ return "Error in transpose: rank of input " + $x.rank + " " + ("must match length of perm " + perm + ".");
});
- var a = Bs(he(r, "int32"), e), s = Bs(he(i, "int32"), e), o = Tt(a), c = Ue(o, s);
- return he(c, "int32");
+ perm.forEach(function(axis) {
+ assert(axis >= 0 && axis < $x.rank, function() {
+ return "All entries in 'perm' must be between 0 and " + ($x.rank - 1) + (" but got " + perm);
+ });
+ });
+ if ($x.rank <= 1) {
+ return $x.clone();
+ }
+ var inputs = {x: $x};
+ var attrs = {perm};
+ return ENGINE.runKernelFunc(function(backend2) {
+ return backend2.transpose($x, perm);
+ }, inputs, null, Transpose, attrs);
}
- var oI = U({confusionMatrix_: sI});
- var cI = {__proto__: null, confusionMatrix: oI};
- function Xm(n, t, e) {
- if (Ur(n), t != null && t.length !== 3)
+ var transpose = op({transpose_});
+ /**
+ * @license
+ * Copyright 2018 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ function confusionMatrix_(labels, predictions, numClasses) {
+ var $labels = convertToTensor(labels, "labels", "confusionMatrix");
+ var $predictions = convertToTensor(predictions, "predictions", "confusionMatrix");
+ assert(numClasses == null || numClasses > 0 && Number.isInteger(numClasses), function() {
+ return "If provided, numClasses must be a positive integer, " + ("but got " + numClasses);
+ });
+ assert($labels.rank === 1, function() {
+ return "Expected the rank of labels to be 1, but got " + $labels.rank;
+ });
+ assert($predictions.rank === 1, function() {
+ return "Expected the rank of predictions to be 1, " + ("but got " + $predictions.rank);
+ });
+ assert($labels.shape[0] === $predictions.shape[0], function() {
+ return "Mismatch in the number of examples: " + ($labels.shape[0] + " vs. " + $predictions.shape[0] + ". ") + "Labels and predictions should have the same number of elements.";
+ });
+ assert(numClasses > 0 && Number.isInteger(numClasses), function() {
+ return "numClasses is required to be a positive integer, but got " + ("" + numClasses);
+ });
+ var oneHotLabels = oneHot(cast($labels, "int32"), numClasses);
+ var oneHotPredictions = oneHot(cast($predictions, "int32"), numClasses);
+ var oneHotLabelsT = transpose(oneHotLabels);
+ var product = matMul(oneHotLabelsT, oneHotPredictions);
+ return cast(product, "int32");
+ }
+ var confusionMatrix = op({confusionMatrix_});
+ /**
+ * @license
+ * Copyright 2018 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ var math = {
+ __proto__: null,
+ confusionMatrix
+ };
+ /**
+ * @license
+ * Copyright 2018 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ function tensor3d(values, shape, dtype) {
+ assertNonNull(values);
+ if (shape != null && shape.length !== 3) {
throw new Error("tensor3d() requires shape to have three numbers");
- var r = Rn(n, e);
- if (r.length !== 3 && r.length !== 1)
+ }
+ var inferredShape = inferShape(values, dtype);
+ if (inferredShape.length !== 3 && inferredShape.length !== 1) {
throw new Error("tensor3d() requires values to be number[][][] or flat/TypedArray");
- if (r.length === 1 && t == null)
+ }
+ if (inferredShape.length === 1 && shape == null) {
throw new Error("tensor3d() requires shape to be provided when `values` are a flat array");
- return ur(n, t, r, e);
+ }
+ return makeTensor(values, shape, inferredShape, dtype);
}
- var Ci;
- function lI(n, t) {
- if (t === void 0 && (t = 3), t > 4)
+ /**
+ * @license
+ * Copyright 2019 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ var fromPixels2DContext;
+ function fromPixels_(pixels, numChannels) {
+ if (numChannels === void 0) {
+ numChannels = 3;
+ }
+ if (numChannels > 4) {
throw new Error("Cannot construct Tensor with more than 4 channels from pixels.");
- if (n == null)
+ }
+ if (pixels == null) {
throw new Error("pixels passed to tf.browser.fromPixels() can not be null");
- var e = false, r = false, i = false, a = false, s = false;
- if (n.data instanceof Uint8Array)
- e = true;
- else if (typeof ImageData != "undefined" && n instanceof ImageData)
- r = true;
- else if (typeof HTMLVideoElement != "undefined" && n instanceof HTMLVideoElement)
- i = true;
- else if (typeof HTMLImageElement != "undefined" && n instanceof HTMLImageElement)
- a = true;
- else if (n.getContext != null)
- s = true;
- else
- throw new Error("pixels passed to tf.browser.fromPixels() must be either an HTMLVideoElement, HTMLImageElement, HTMLCanvasElement, ImageData in browser, or OffscreenCanvas, ImageData in webworker or {data: Uint32Array, width: number, height: number}, " + ("but was " + n.constructor.name));
- if (i) {
- var o = 2;
- if (i && n.readyState < o)
+ }
+ var isPixelData = false;
+ var isImageData = false;
+ var isVideo = false;
+ var isImage = false;
+ var isCanvasLike = false;
+ if (pixels.data instanceof Uint8Array) {
+ isPixelData = true;
+ } else if (typeof ImageData !== "undefined" && pixels instanceof ImageData) {
+ isImageData = true;
+ } else if (typeof HTMLVideoElement !== "undefined" && pixels instanceof HTMLVideoElement) {
+ isVideo = true;
+ } else if (typeof HTMLImageElement !== "undefined" && pixels instanceof HTMLImageElement) {
+ isImage = true;
+ } else if (pixels.getContext != null) {
+ isCanvasLike = true;
+ } else {
+ throw new Error("pixels passed to tf.browser.fromPixels() must be either an HTMLVideoElement, HTMLImageElement, HTMLCanvasElement, ImageData in browser, or OffscreenCanvas, ImageData in webworker or {data: Uint32Array, width: number, height: number}, " + ("but was " + pixels.constructor.name));
+ }
+ if (isVideo) {
+ var HAVE_CURRENT_DATA_READY_STATE = 2;
+ if (isVideo && pixels.readyState < HAVE_CURRENT_DATA_READY_STATE) {
throw new Error("The video element has not loaded data yet. Please wait for `loadeddata` event on the