From a61b19ac0c20bbc1e984ad61572ce91b78bad6d2 Mon Sep 17 00:00:00 2001 From: Vladimir Mandic Date: Sun, 8 Nov 2020 09:56:02 -0500 Subject: [PATCH] update hand model --- config.js | 1 + demo/browser.js | 35 +- demo/menu.js | 1 + changelog.js => dev-server/changelog.js | 4 +- dev-server.crt => dev-server/dev-server.crt | 0 dev-server.js => dev-server/dev-server.js | 6 +- dev-server.key => dev-server/dev-server.key | 0 dist/demo-browser-index.js | 116986 ++++++++++++----- dist/demo-browser-index.js.map | 4 +- dist/demo-browser-index.json | 21 +- dist/human.esm-nobundle.js | 2 +- dist/human.esm-nobundle.js.map | 2 +- dist/human.esm-nobundle.json | 6 +- dist/human.esm.js | 95179 +++++++++++++- dist/human.esm.js.map | 4 +- dist/human.esm.json | 124 +- dist/human.js | 2 +- dist/human.js.map | 2 +- dist/human.json | 6 +- dist/human.node-nobundle.js | 2 +- dist/human.node-nobundle.js.map | 2 +- dist/human.node.js | 2 +- dist/human.node.js.map | 2 +- dist/human.node.json | 6 +- package.json | 8 +- src/hand/box.js | 2 +- src/hand/handdetector.js | 43 +- src/hand/handpipeline.js | 116 +- src/hand/handpose.js | 12 +- 29 files changed, 178666 insertions(+), 33914 deletions(-) rename changelog.js => dev-server/changelog.js (94%) rename dev-server.crt => dev-server/dev-server.crt (100%) rename dev-server.js => dev-server/dev-server.js (98%) rename dev-server.key => dev-server/dev-server.key (100%) diff --git a/config.js b/config.js index 716e6766..c49c1999 100644 --- a/config.js +++ b/config.js @@ -113,6 +113,7 @@ export default { scoreThreshold: 0.8, // threshold for deciding when to remove boxes based on score in non-maximum suppression enlargeFactor: 1.65, // empiric tuning as skeleton prediction prefers hand box with some whitespace maxHands: 1, // maximum number of hands detected in the input, should be set to the minimum number for performance + landmarks: true, // detect hand landmarks or just hand boundary box detector: { modelPath: '../models/handdetect.json', }, diff --git a/demo/browser.js b/demo/browser.js index 3fa25bc7..a5e436e1 100644 --- a/demo/browser.js +++ b/demo/browser.js @@ -27,6 +27,10 @@ const ui = { maxFrames: 10, modelsPreload: true, modelsWarmup: true, + menuWidth: 0, + menuHeight: 0, + camera: {}, + fps: [], }; // global variables @@ -34,8 +38,6 @@ let menu; let menuFX; let worker; let timeStamp; -let camera = {}; -const fps = []; // helper function: translates json to human readable string function str(...msg) { @@ -62,17 +64,22 @@ const status = (msg) => { // draws processed results and starts processing of a next frame function drawResults(input, result, canvas) { // update fps data - fps.push(1000 / (performance.now() - timeStamp)); - if (fps.length > ui.maxFrames) fps.shift(); + const elapsed = performance.now() - timeStamp; + ui.fps.push(1000 / elapsed); + if (ui.fps.length > ui.maxFrames) ui.fps.shift(); // enable for continous performance monitoring // console.log(result.performance); - // eslint-disable-next-line no-use-before-define - if (input.srcObject) requestAnimationFrame(() => runHumanDetect(input, canvas)); // immediate loop before we even draw results - + // immediate loop before we even draw results, but limit frame rate to 30 + if (input.srcObject) { + // eslint-disable-next-line no-use-before-define + if (elapsed > 33) requestAnimationFrame(() => runHumanDetect(input, canvas)); + // eslint-disable-next-line no-use-before-define + else setTimeout(() => runHumanDetect(input, canvas), 33 - elapsed); + } // draw fps chart - menu.updateChart('FPS', fps); + menu.updateChart('FPS', ui.fps); // draw image from video const ctx = canvas.getContext('2d'); ctx.fillStyle = ui.baseBackground; @@ -94,9 +101,9 @@ function drawResults(input, result, canvas) { const gpu = engine.backendInstance ? `gpu: ${(engine.backendInstance.numBytesInGPU ? engine.backendInstance.numBytesInGPU : 0).toLocaleString()} bytes` : ''; const memory = `system: ${engine.state.numBytes.toLocaleString()} bytes ${gpu} | tensors: ${engine.state.numTensors.toLocaleString()}`; const processing = result.canvas ? `processing: ${result.canvas.width} x ${result.canvas.height}` : ''; - const avg = Math.trunc(10 * fps.reduce((a, b) => a + b) / fps.length) / 10; + const avg = Math.trunc(10 * ui.fps.reduce((a, b) => a + b) / ui.fps.length) / 10; document.getElementById('log').innerText = ` - video: ${camera.name} | facing: ${camera.facing} | resolution: ${camera.width} x ${camera.height} ${processing} + video: ${ui.camera.name} | facing: ${ui.camera.facing} | resolution: ${ui.camera.width} x ${ui.camera.height} ${processing} backend: ${human.tf.getBackend()} | ${memory} performance: ${str(result.performance)} FPS:${avg} `; @@ -147,7 +154,7 @@ async function setupCamera() { const track = stream.getVideoTracks()[0]; const settings = track.getSettings(); log('camera constraints:', constraints, 'window:', { width: window.innerWidth, height: window.innerHeight }, 'settings:', settings, 'track:', track); - camera = { name: track.label, width: settings.width, height: settings.height, facing: settings.facingMode === 'user' ? 'front' : 'back' }; + ui.camera = { name: track.label, width: settings.width, height: settings.height, facing: settings.facingMode === 'user' ? 'front' : 'back' }; return new Promise((resolve) => { video.onloadeddata = async () => { video.width = video.videoWidth; @@ -156,6 +163,8 @@ async function setupCamera() { canvas.height = video.height; canvas.style.width = canvas.width > canvas.height ? '100vw' : ''; canvas.style.height = canvas.width > canvas.height ? '' : '100vh'; + ui.menuWidth.input.setAttribute('value', video.width); + ui.menuHeight.input.setAttribute('value', video.height); // silly font resizing for paint-on-canvas since viewport can be zoomed const size = 14 + (6 * canvas.width / window.innerWidth); ui.baseFont = ui.baseFontProto.replace(/{size}/, `${size}px`); @@ -351,8 +360,8 @@ function setupMenu() { menuFX.addHTML('
'); menuFX.addLabel('Image Processing'); menuFX.addBool('Enabled', human.config.filter, 'enabled'); - menuFX.addRange('Image width', human.config.filter, 'width', 0, 3840, 10, (val) => human.config.filter.width = parseInt(val)); - menuFX.addRange('Image height', human.config.filter, 'height', 0, 2160, 10, (val) => human.config.filter.height = parseInt(val)); + ui.menuWidth = menuFX.addRange('Image width', human.config.filter, 'width', 0, 3840, 10, (val) => human.config.filter.width = parseInt(val)); + ui.menuHeight = menuFX.addRange('Image height', human.config.filter, 'height', 0, 2160, 10, (val) => human.config.filter.height = parseInt(val)); menuFX.addRange('Brightness', human.config.filter, 'brightness', -1.0, 1.0, 0.05, (val) => human.config.filter.brightness = parseFloat(val)); menuFX.addRange('Contrast', human.config.filter, 'contrast', -1.0, 1.0, 0.05, (val) => human.config.filter.contrast = parseFloat(val)); menuFX.addRange('Sharpness', human.config.filter, 'sharpness', 0, 1.0, 0.05, (val) => human.config.filter.sharpness = parseFloat(val)); diff --git a/demo/menu.js b/demo/menu.js index 2e68ef1a..a5fb0b51 100644 --- a/demo/menu.js +++ b/demo/menu.js @@ -219,6 +219,7 @@ class Menu { evt.target.setAttribute('value', evt.target.value); if (callback) callback(evt.target.value); }); + el.input = el.children[0]; return el; } diff --git a/changelog.js b/dev-server/changelog.js similarity index 94% rename from changelog.js rename to dev-server/changelog.js index 34cc8fff..e64d6b60 100644 --- a/changelog.js +++ b/dev-server/changelog.js @@ -3,7 +3,7 @@ const path = require('path'); const dayjs = require('dayjs'); const simpleGit = require('simple-git/promise'); const logger = require('@vladmandic/pilogger'); -const app = require('./package.json'); +const app = require('../package.json'); const git = simpleGit(); @@ -45,5 +45,5 @@ async function update(f) { exports.update = update; if (!module.parent) { - update('wiki/Change-Log.md'); + update('../wiki/Change-Log.md'); } diff --git a/dev-server.crt b/dev-server/dev-server.crt similarity index 100% rename from dev-server.crt rename to dev-server/dev-server.crt diff --git a/dev-server.js b/dev-server/dev-server.js similarity index 98% rename from dev-server.js rename to dev-server/dev-server.js index ea065d19..bd823b32 100755 --- a/dev-server.js +++ b/dev-server/dev-server.js @@ -26,9 +26,9 @@ const log = require('@vladmandic/pilogger'); const options = { // key: fs.readFileSync('/home/vlado/dev/piproxy/cert/private.pem'), // cert: fs.readFileSync('/home/vlado/dev/piproxy/cert/fullchain.pem'), - key: fs.readFileSync('./dev-server.key'), - cert: fs.readFileSync('./dev-server.crt'), - root: '.', + key: fs.readFileSync('dev-server/dev-server.key'), + cert: fs.readFileSync('dev-server/dev-server.crt'), + root: '..', default: 'demo/index.html', port: 8000, monitor: ['package.json', 'config.js', 'demo', 'src'], diff --git a/dev-server.key b/dev-server/dev-server.key similarity index 100% rename from dev-server.key rename to dev-server/dev-server.key diff --git a/dist/demo-browser-index.js b/dist/demo-browser-index.js index 343f506c..84b65ec7 100644 --- a/dist/demo-browser-index.js +++ b/dist/demo-browser-index.js @@ -1,18071 +1,43578 @@ // dist/human.esm.js -var af = Object.defineProperty; -var hL = (n) => af(n, "__esModule", {value: true}); -var we = (n, t) => () => (t || (t = {exports: {}}, n(t.exports, t)), t.exports); -var Is = (n, t) => { - hL(n); - for (var e in t) - af(n, e, {get: t[e], enumerable: true}); -}; -var sf = we(() => { -}); -var of = we(() => { -}); -var As = we(() => { -}); -var Qr = we((A) => { - "use strict"; - Object.defineProperty(A, "__esModule", {value: true}); - var gc = function(n, t) { - return gc = Object.setPrototypeOf || {__proto__: []} instanceof Array && function(e, r) { - e.__proto__ = r; - } || function(e, r) { - for (var i in r) - r.hasOwnProperty(i) && (e[i] = r[i]); - }, gc(n, t); - }; - function qn(n, t) { - gc(n, t); - function e() { - this.constructor = n; - } - n.prototype = t === null ? Object.create(t) : (e.prototype = t.prototype, new e()); +var __defProp = Object.defineProperty; +var __markAsModule = (target) => __defProp(target, "__esModule", {value: true}); +var __commonJS = (callback, module) => () => { + if (!module) { + module = {exports: {}}; + callback(module.exports, module); } - function pe(n, t, e, r) { - return new (e || (e = Promise))(function(i, a) { - function s(l) { + return module.exports; +}; +var __export = (target, all) => { + __markAsModule(target); + for (var name in all) + __defProp(target, name, {get: all[name], enumerable: true}); +}; +var require_browser = __commonJS(() => { +}); +var require_util = __commonJS(() => { +}); +var require_crypto = __commonJS(() => { +}); +var require_tf_core_node = __commonJS((exports) => { + /** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + "use strict"; + Object.defineProperty(exports, "__esModule", {value: true}); + /*! ***************************************************************************** + Copyright (c) Microsoft Corporation. All rights reserved. + Licensed under the Apache License, Version 2.0 (the "License"); you may not use + this file except in compliance with the License. You may obtain a copy of the + License at http://www.apache.org/licenses/LICENSE-2.0 + + THIS CODE IS PROVIDED ON AN *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED + WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, + MERCHANTABLITY OR NON-INFRINGEMENT. + + See the Apache Version 2.0 License for specific language governing permissions + and limitations under the License. + ***************************************************************************** */ + var extendStatics = function(d, b) { + extendStatics = Object.setPrototypeOf || {__proto__: []} instanceof Array && function(d2, b2) { + d2.__proto__ = b2; + } || function(d2, b2) { + for (var p in b2) + if (b2.hasOwnProperty(p)) + d2[p] = b2[p]; + }; + return extendStatics(d, b); + }; + function __extends(d, b) { + extendStatics(d, b); + function __() { + this.constructor = d; + } + d.prototype = b === null ? Object.create(b) : (__.prototype = b.prototype, new __()); + } + function __awaiter(thisArg, _arguments, P, generator) { + return new (P || (P = Promise))(function(resolve, reject) { + function fulfilled(value) { try { - c(r.next(l)); - } catch (u) { - a(u); + step2(generator.next(value)); + } catch (e) { + reject(e); } } - function o(l) { + function rejected(value) { try { - c(r.throw(l)); - } catch (u) { - a(u); + step2(generator["throw"](value)); + } catch (e) { + reject(e); } } - function c(l) { - l.done ? i(l.value) : new e(function(u) { - u(l.value); - }).then(s, o); + function step2(result) { + result.done ? resolve(result.value) : new P(function(resolve2) { + resolve2(result.value); + }).then(fulfilled, rejected); } - c((r = r.apply(n, t || [])).next()); + step2((generator = generator.apply(thisArg, _arguments || [])).next()); }); } - function fe(n, t) { - var e = {label: 0, sent: function() { - if (a[0] & 1) - throw a[1]; - return a[1]; - }, trys: [], ops: []}, r, i, a, s; - return s = {next: o(0), throw: o(1), return: o(2)}, typeof Symbol == "function" && (s[Symbol.iterator] = function() { + function __generator(thisArg, body) { + var _ = {label: 0, sent: function() { + if (t[0] & 1) + throw t[1]; + return t[1]; + }, trys: [], ops: []}, f, y, t, g; + return g = {next: verb(0), throw: verb(1), return: verb(2)}, typeof Symbol === "function" && (g[Symbol.iterator] = function() { return this; - }), s; - function o(l) { - return function(u) { - return c([l, u]); + }), g; + function verb(n) { + return function(v) { + return step2([n, v]); }; } - function c(l) { - if (r) + function step2(op2) { + if (f) throw new TypeError("Generator is already executing."); - for (; e; ) + while (_) try { - if (r = 1, i && (a = l[0] & 2 ? i.return : l[0] ? i.throw || ((a = i.return) && a.call(i), 0) : i.next) && !(a = a.call(i, l[1])).done) - return a; - (i = 0, a) && (l = [l[0] & 2, a.value]); - switch (l[0]) { + if (f = 1, y && (t = op2[0] & 2 ? y["return"] : op2[0] ? y["throw"] || ((t = y["return"]) && t.call(y), 0) : y.next) && !(t = t.call(y, op2[1])).done) + return t; + if (y = 0, t) + op2 = [op2[0] & 2, t.value]; + switch (op2[0]) { case 0: case 1: - a = l; + t = op2; break; case 4: - return e.label++, {value: l[1], done: false}; + _.label++; + return {value: op2[1], done: false}; case 5: - e.label++, i = l[1], l = [0]; + _.label++; + y = op2[1]; + op2 = [0]; continue; case 7: - l = e.ops.pop(), e.trys.pop(); + op2 = _.ops.pop(); + _.trys.pop(); continue; default: - if (!(a = e.trys, a = a.length > 0 && a[a.length - 1]) && (l[0] === 6 || l[0] === 2)) { - e = 0; + if (!(t = _.trys, t = t.length > 0 && t[t.length - 1]) && (op2[0] === 6 || op2[0] === 2)) { + _ = 0; continue; } - if (l[0] === 3 && (!a || l[1] > a[0] && l[1] < a[3])) { - e.label = l[1]; + if (op2[0] === 3 && (!t || op2[1] > t[0] && op2[1] < t[3])) { + _.label = op2[1]; break; } - if (l[0] === 6 && e.label < a[1]) { - e.label = a[1], a = l; + if (op2[0] === 6 && _.label < t[1]) { + _.label = t[1]; + t = op2; break; } - if (a && e.label < a[2]) { - e.label = a[2], e.ops.push(l); + if (t && _.label < t[2]) { + _.label = t[2]; + _.ops.push(op2); break; } - a[2] && e.ops.pop(), e.trys.pop(); + if (t[2]) + _.ops.pop(); + _.trys.pop(); continue; } - l = t.call(n, e); - } catch (u) { - l = [6, u], i = 0; + op2 = body.call(thisArg, _); + } catch (e) { + op2 = [6, e]; + y = 0; } finally { - r = a = 0; + f = t = 0; } - if (l[0] & 5) - throw l[1]; - return {value: l[0] ? l[1] : void 0, done: true}; + if (op2[0] & 5) + throw op2[1]; + return {value: op2[0] ? op2[1] : void 0, done: true}; } } - var dL = 1e-7, pL = 1e-4, fL = function() { - function n(t, e) { - this.backend = t, this.dataMover = e, this.data = new WeakMap(), this.dataIdsCount = 0; + /** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + var EPSILON_FLOAT32 = 1e-7; + var EPSILON_FLOAT16 = 1e-4; + var DataStorage = function() { + function DataStorage2(backend2, dataMover) { + this.backend = backend2; + this.dataMover = dataMover; + this.data = new WeakMap(); + this.dataIdsCount = 0; } - return n.prototype.get = function(t) { - return this.data.has(t) || this.dataMover.moveData(this.backend, t), this.data.get(t); - }, n.prototype.set = function(t, e) { - this.dataIdsCount++, this.data.set(t, e); - }, n.prototype.has = function(t) { - return this.data.has(t); - }, n.prototype.delete = function(t) { - return this.dataIdsCount--, this.data.delete(t); - }, n.prototype.numDataIds = function() { + DataStorage2.prototype.get = function(dataId) { + if (!this.data.has(dataId)) { + this.dataMover.moveData(this.backend, dataId); + } + return this.data.get(dataId); + }; + DataStorage2.prototype.set = function(dataId, value) { + this.dataIdsCount++; + this.data.set(dataId, value); + }; + DataStorage2.prototype.has = function(dataId) { + return this.data.has(dataId); + }; + DataStorage2.prototype.delete = function(dataId) { + this.dataIdsCount--; + return this.data.delete(dataId); + }; + DataStorage2.prototype.numDataIds = function() { return this.dataIdsCount; - }, n; - }(), cf = function() { - function n() { - } - return n.prototype.time = function(t) { - return X("time"); - }, n.prototype.read = function(t) { - return X("read"); - }, n.prototype.readSync = function(t) { - return X("readSync"); - }, n.prototype.numDataIds = function() { - return X("numDataIds"); - }, n.prototype.disposeData = function(t) { - return X("disposeData"); - }, n.prototype.write = function(t, e, r) { - return X("write"); - }, n.prototype.move = function(t, e, r, i) { - return X("move"); - }, n.prototype.memory = function() { - return X("memory"); - }, n.prototype.floatPrecision = function() { - return X("floatPrecision"); - }, n.prototype.epsilon = function() { - return this.floatPrecision() === 32 ? dL : pL; - }, n.prototype.batchMatMul = function(t, e, r, i) { - return X("batchMatMul"); - }, n.prototype.fusedBatchMatMul = function(t) { - var e = t.a, r = t.b, i = t.transposeA, a = t.transposeB, s = t.bias, o = t.activation, c = t.preluActivationWeights; - return X("fusedBatchMatMul"); - }, n.prototype.slice = function(t, e, r) { - return X("slice"); - }, n.prototype.stridedSlice = function(t, e, r, i) { - return X("stridedSlice"); - }, n.prototype.unstack = function(t, e) { - return X("unstack"); - }, n.prototype.reverse = function(t, e) { - return X("reverse"); - }, n.prototype.concat = function(t, e) { - return X("concat"); - }, n.prototype.neg = function(t) { - return X("neg"); - }, n.prototype.add = function(t, e) { - return X("add"); - }, n.prototype.addN = function(t) { - return X("addN"); - }, n.prototype.subtract = function(t, e) { - return X("subtract"); - }, n.prototype.multiply = function(t, e) { - return X("multiply"); - }, n.prototype.realDivide = function(t, e) { - return X("realDivide"); - }, n.prototype.floorDiv = function(t, e) { - return X("floorDiv"); - }, n.prototype.sum = function(t, e) { - return X("sum"); - }, n.prototype.prod = function(t, e) { - return X("prod"); - }, n.prototype.unsortedSegmentSum = function(t, e, r) { - return X("unsortedSegmentSum"); - }, n.prototype.argMin = function(t, e) { - return X("argMin"); - }, n.prototype.argMax = function(t, e) { - return X("argMax"); - }, n.prototype.equal = function(t, e) { - return X("equal"); - }, n.prototype.notEqual = function(t, e) { - return X("notEqual"); - }, n.prototype.less = function(t, e) { - return X("less"); - }, n.prototype.lessEqual = function(t, e) { - return X("lessEqual"); - }, n.prototype.greater = function(t, e) { - return X("greater"); - }, n.prototype.greaterEqual = function(t, e) { - return X("greaterEqual"); - }, n.prototype.logicalNot = function(t) { - return X("logicalNot"); - }, n.prototype.logicalAnd = function(t, e) { - return X("logicalAnd"); - }, n.prototype.logicalOr = function(t, e) { - return X("logicalOr"); - }, n.prototype.where = function(t) { - return X("where"); - }, n.prototype.select = function(t, e, r) { - return X("select"); - }, n.prototype.topk = function(t, e, r) { - return X("topk"); - }, n.prototype.min = function(t, e) { - return X("min"); - }, n.prototype.minimum = function(t, e) { - return X("minimum"); - }, n.prototype.mod = function(t, e) { - return X("mod"); - }, n.prototype.max = function(t, e) { - return X("max"); - }, n.prototype.maximum = function(t, e) { - return X("maximum"); - }, n.prototype.all = function(t, e) { - return X("all"); - }, n.prototype.any = function(t, e) { - return X("any"); - }, n.prototype.squaredDifference = function(t, e) { - return X("squaredDifference"); - }, n.prototype.ceil = function(t) { - return X("ceil"); - }, n.prototype.floor = function(t) { - return X("floor"); - }, n.prototype.round = function(t) { - return X("round"); - }, n.prototype.sign = function(t) { - return X("sign"); - }, n.prototype.isNaN = function(t) { - return X("isNaN"); - }, n.prototype.isInf = function(t) { - return X("isInf"); - }, n.prototype.isFinite = function(t) { - return X("isFinite"); - }, n.prototype.pow = function(t, e) { - return X("pow"); - }, n.prototype.exp = function(t) { - return X("exp"); - }, n.prototype.expm1 = function(t) { - return X("expm1"); - }, n.prototype.softmax = function(t, e) { - return X("softmax"); - }, n.prototype.log = function(t) { - return X("log"); - }, n.prototype.log1p = function(t) { - return X("log1p"); - }, n.prototype.sqrt = function(t) { - return X("sqrt"); - }, n.prototype.rsqrt = function(t) { - return X("rsqrt"); - }, n.prototype.square = function(t) { - return X("square"); - }, n.prototype.reciprocal = function(t) { - return X("reciprocal"); - }, n.prototype.relu = function(t) { - return X("relu"); - }, n.prototype.relu6 = function(t) { - return X("relu6"); - }, n.prototype.prelu = function(t, e) { - return X("prelu"); - }, n.prototype.elu = function(t) { - return X("elu"); - }, n.prototype.eluDer = function(t, e) { - return X("eluDer"); - }, n.prototype.selu = function(t) { - return X("selu"); - }, n.prototype.int = function(t) { - return X("int"); - }, n.prototype.clip = function(t, e, r) { - return X("clip"); - }, n.prototype.abs = function(t) { - return X("abs"); - }, n.prototype.complexAbs = function(t) { - return X("complexAbs"); - }, n.prototype.sigmoid = function(t) { - return X("sigmoid"); - }, n.prototype.softplus = function(t) { - return X("softplus"); - }, n.prototype.sin = function(t) { - return X("sin"); - }, n.prototype.cos = function(t) { - return X("cos"); - }, n.prototype.tan = function(t) { - return X("tan"); - }, n.prototype.asin = function(t) { - return X("asin"); - }, n.prototype.acos = function(t) { - return X("acos"); - }, n.prototype.atan = function(t) { - return X("atan"); - }, n.prototype.atan2 = function(t, e) { - return X("atan2"); - }, n.prototype.sinh = function(t) { - return X("sinh"); - }, n.prototype.cosh = function(t) { - return X("cosh"); - }, n.prototype.tanh = function(t) { - return X("tanh"); - }, n.prototype.asinh = function(t) { - return X("asinh"); - }, n.prototype.acosh = function(t) { - return X("acosh"); - }, n.prototype.atanh = function(t) { - return X("atanh"); - }, n.prototype.erf = function(t) { - return X("erf"); - }, n.prototype.step = function(t, e) { - return X("step"); - }, n.prototype.fusedConv2d = function(t) { - var e = t.input, r = t.filter, i = t.convInfo, a = t.bias, s = t.activation, o = t.preluActivationWeights; - return X("fusedConv2d"); - }, n.prototype.conv2d = function(t, e, r) { - return X("conv2d"); - }, n.prototype.conv2dDerInput = function(t, e, r) { - return X("conv2dDerInput"); - }, n.prototype.conv2dDerFilter = function(t, e, r) { - return X("conv2dDerFilter"); - }, n.prototype.fusedDepthwiseConv2D = function(t) { - var e = t.input, r = t.filter, i = t.convInfo, a = t.bias, s = t.activation, o = t.preluActivationWeights; - return X("fusedDepthwiseConv2D"); - }, n.prototype.depthwiseConv2D = function(t, e, r) { - return X("depthwiseConv2D"); - }, n.prototype.depthwiseConv2DDerInput = function(t, e, r) { - return X("depthwiseConv2DDerInput"); - }, n.prototype.depthwiseConv2DDerFilter = function(t, e, r) { - return X("depthwiseConv2DDerFilter"); - }, n.prototype.conv3d = function(t, e, r) { - return X("conv3d"); - }, n.prototype.conv3dDerInput = function(t, e, r) { - return X("conv3dDerInput"); - }, n.prototype.conv3dDerFilter = function(t, e, r) { - return X("conv3dDerFilter"); - }, n.prototype.maxPool = function(t, e) { - return X("maxPool"); - }, n.prototype.maxPoolBackprop = function(t, e, r, i) { - return X("maxPoolBackprop"); - }, n.prototype.avgPool = function(t, e) { - return X("avgPool"); - }, n.prototype.avgPoolBackprop = function(t, e, r) { - return X("avgPoolBackprop"); - }, n.prototype.avgPool3d = function(t, e) { - return X("avgPool3d"); - }, n.prototype.avgPool3dBackprop = function(t, e, r) { - return X("avgPool3dBackprop"); - }, n.prototype.maxPool3d = function(t, e) { - return X("maxPool3d"); - }, n.prototype.maxPool3dBackprop = function(t, e, r, i) { - return X("maxPool3dBackprop"); - }, n.prototype.reshape = function(t, e) { - return X("reshape"); - }, n.prototype.cast = function(t, e) { - return X("cast"); - }, n.prototype.tile = function(t, e) { - return X("tile"); - }, n.prototype.pad = function(t, e, r) { - return X("pad"); - }, n.prototype.transpose = function(t, e) { - return X("transpose"); - }, n.prototype.gather = function(t, e, r) { - return X("gather"); - }, n.prototype.gatherND = function(t, e) { - return X("gatherND"); - }, n.prototype.scatterND = function(t, e, r) { - return X("scatterND"); - }, n.prototype.batchToSpaceND = function(t, e, r) { - return X("batchToSpaceND"); - }, n.prototype.spaceToBatchND = function(t, e, r) { - return X("spaceToBatchND"); - }, n.prototype.resizeBilinear = function(t, e, r, i) { - return X("resizeBilinear"); - }, n.prototype.resizeBilinearBackprop = function(t, e, r) { - return X("resizeBilinearBackprop"); - }, n.prototype.resizeNearestNeighbor = function(t, e, r, i) { - return X("resizeNearestNeighbor"); - }, n.prototype.resizeNearestNeighborBackprop = function(t, e, r) { - return X("resizeNearestNeighborBackprop"); - }, n.prototype.batchNorm = function(t, e, r, i, a, s) { - return X("batchNorm"); - }, n.prototype.localResponseNormalization4D = function(t, e, r, i, a) { - return X("localResponseNormalization4D"); - }, n.prototype.LRNGrad = function(t, e, r, i, a, s, o) { - return X("LRNGrad"); - }, n.prototype.multinomial = function(t, e, r, i) { - return X("multinomial"); - }, n.prototype.oneHot = function(t, e, r, i) { - return X("oneHot"); - }, n.prototype.cumsum = function(t, e, r, i) { - return X("cumsum"); - }, n.prototype.nonMaxSuppression = function(t, e, r, i, a) { - return X("nonMaxSuppression"); - }, n.prototype.fft = function(t) { - return X("fft"); - }, n.prototype.ifft = function(t) { - return X("ifft"); - }, n.prototype.complex = function(t, e) { - return X("complex"); - }, n.prototype.real = function(t) { - return X("real"); - }, n.prototype.imag = function(t) { - return X("imag"); - }, n.prototype.cropAndResize = function(t, e, r, i, a, s) { - return X("cropAndResize"); - }, n.prototype.depthToSpace = function(t, e, r) { - return X("depthToSpace"); - }, n.prototype.split = function(t, e, r) { - return X("split"); - }, n.prototype.sparseToDense = function(t, e, r, i) { - return X("sparseToDense"); - }, n.prototype.diag = function(t) { - return X("diag"); - }, n.prototype.fill = function(t, e, r) { - return X("fill"); - }, n.prototype.onesLike = function(t) { - return X("onesLike"); - }, n.prototype.zerosLike = function(t) { - return X("zerosLike"); - }, n.prototype.linspace = function(t, e, r) { - return X("linspace"); - }, n.prototype.dispose = function() { - return X("dispose"); - }, n; + }; + return DataStorage2; }(); - function X(n) { - throw new Error("'" + n + "' not yet implemented or not found in the registry. This kernel may not be supported by the tfjs backend you have chosen"); - } - function lf(n) { - for (var t = n.length, e = 0, r = 0; t > 0; ) - r = Math.random() * t | 0, t--, e = n[t], n[t] = n[r], n[r] = e; - } - function ga(n, t, e) { - return Math.max(n, Math.min(t, e)); - } - function mL(n) { - return n % 2 === 0 ? n : n + 1; - } - function gL(n) { - for (var t = 0, e = 0; e < n.length; e++) - t += n[e]; - return t; - } - function yL(n, t) { - var e = Math.random(); - return t * e + (1 - e) * n; - } - function vL(n, t) { - for (var e = 0, r = 0; r < n.length; r++) { - var i = Number(n[r]) - Number(t[r]); - e += i * i; + var KernelBackend = function() { + function KernelBackend2() { } - return e; + KernelBackend2.prototype.time = function(f) { + return notYetImplemented("time"); + }; + KernelBackend2.prototype.read = function(dataId) { + return notYetImplemented("read"); + }; + KernelBackend2.prototype.readSync = function(dataId) { + return notYetImplemented("readSync"); + }; + KernelBackend2.prototype.numDataIds = function() { + return notYetImplemented("numDataIds"); + }; + KernelBackend2.prototype.disposeData = function(dataId) { + return notYetImplemented("disposeData"); + }; + KernelBackend2.prototype.write = function(values, shape, dtype) { + return notYetImplemented("write"); + }; + KernelBackend2.prototype.move = function(dataId, values, shape, dtype) { + return notYetImplemented("move"); + }; + KernelBackend2.prototype.memory = function() { + return notYetImplemented("memory"); + }; + KernelBackend2.prototype.floatPrecision = function() { + return notYetImplemented("floatPrecision"); + }; + KernelBackend2.prototype.epsilon = function() { + return this.floatPrecision() === 32 ? EPSILON_FLOAT32 : EPSILON_FLOAT16; + }; + KernelBackend2.prototype.batchMatMul = function(a, b, transposeA, transposeB) { + return notYetImplemented("batchMatMul"); + }; + KernelBackend2.prototype.fusedBatchMatMul = function(_a) { + var a = _a.a, b = _a.b, transposeA = _a.transposeA, transposeB = _a.transposeB, bias = _a.bias, activation = _a.activation, preluActivationWeights = _a.preluActivationWeights; + return notYetImplemented("fusedBatchMatMul"); + }; + KernelBackend2.prototype.slice = function(x, begin, size) { + return notYetImplemented("slice"); + }; + KernelBackend2.prototype.stridedSlice = function(x, begin, end, strides) { + return notYetImplemented("stridedSlice"); + }; + KernelBackend2.prototype.unstack = function(x, axis) { + return notYetImplemented("unstack"); + }; + KernelBackend2.prototype.reverse = function(a, axis) { + return notYetImplemented("reverse"); + }; + KernelBackend2.prototype.concat = function(tensors, axis) { + return notYetImplemented("concat"); + }; + KernelBackend2.prototype.neg = function(a) { + return notYetImplemented("neg"); + }; + KernelBackend2.prototype.add = function(a, b) { + return notYetImplemented("add"); + }; + KernelBackend2.prototype.addN = function(tensors) { + return notYetImplemented("addN"); + }; + KernelBackend2.prototype.subtract = function(a, b) { + return notYetImplemented("subtract"); + }; + KernelBackend2.prototype.multiply = function(a, b) { + return notYetImplemented("multiply"); + }; + KernelBackend2.prototype.realDivide = function(a, b) { + return notYetImplemented("realDivide"); + }; + KernelBackend2.prototype.floorDiv = function(a, b) { + return notYetImplemented("floorDiv"); + }; + KernelBackend2.prototype.sum = function(x, axes) { + return notYetImplemented("sum"); + }; + KernelBackend2.prototype.prod = function(x, axes) { + return notYetImplemented("prod"); + }; + KernelBackend2.prototype.unsortedSegmentSum = function(x, segmentIds, numSegments) { + return notYetImplemented("unsortedSegmentSum"); + }; + KernelBackend2.prototype.argMin = function(x, axis) { + return notYetImplemented("argMin"); + }; + KernelBackend2.prototype.argMax = function(x, axis) { + return notYetImplemented("argMax"); + }; + KernelBackend2.prototype.equal = function(a, b) { + return notYetImplemented("equal"); + }; + KernelBackend2.prototype.notEqual = function(a, b) { + return notYetImplemented("notEqual"); + }; + KernelBackend2.prototype.less = function(a, b) { + return notYetImplemented("less"); + }; + KernelBackend2.prototype.lessEqual = function(a, b) { + return notYetImplemented("lessEqual"); + }; + KernelBackend2.prototype.greater = function(a, b) { + return notYetImplemented("greater"); + }; + KernelBackend2.prototype.greaterEqual = function(a, b) { + return notYetImplemented("greaterEqual"); + }; + KernelBackend2.prototype.logicalNot = function(a) { + return notYetImplemented("logicalNot"); + }; + KernelBackend2.prototype.logicalAnd = function(a, b) { + return notYetImplemented("logicalAnd"); + }; + KernelBackend2.prototype.logicalOr = function(a, b) { + return notYetImplemented("logicalOr"); + }; + KernelBackend2.prototype.where = function(condition) { + return notYetImplemented("where"); + }; + KernelBackend2.prototype.select = function(condition, a, b) { + return notYetImplemented("select"); + }; + KernelBackend2.prototype.topk = function(x, k, sorted) { + return notYetImplemented("topk"); + }; + KernelBackend2.prototype.min = function(x, axes) { + return notYetImplemented("min"); + }; + KernelBackend2.prototype.minimum = function(a, b) { + return notYetImplemented("minimum"); + }; + KernelBackend2.prototype.mod = function(a, b) { + return notYetImplemented("mod"); + }; + KernelBackend2.prototype.max = function(x, axes) { + return notYetImplemented("max"); + }; + KernelBackend2.prototype.maximum = function(a, b) { + return notYetImplemented("maximum"); + }; + KernelBackend2.prototype.all = function(x, axes) { + return notYetImplemented("all"); + }; + KernelBackend2.prototype.any = function(x, axes) { + return notYetImplemented("any"); + }; + KernelBackend2.prototype.squaredDifference = function(a, b) { + return notYetImplemented("squaredDifference"); + }; + KernelBackend2.prototype.ceil = function(x) { + return notYetImplemented("ceil"); + }; + KernelBackend2.prototype.floor = function(x) { + return notYetImplemented("floor"); + }; + KernelBackend2.prototype.round = function(x) { + return notYetImplemented("round"); + }; + KernelBackend2.prototype.sign = function(x) { + return notYetImplemented("sign"); + }; + KernelBackend2.prototype.isNaN = function(x) { + return notYetImplemented("isNaN"); + }; + KernelBackend2.prototype.isInf = function(x) { + return notYetImplemented("isInf"); + }; + KernelBackend2.prototype.isFinite = function(x) { + return notYetImplemented("isFinite"); + }; + KernelBackend2.prototype.pow = function(a, b) { + return notYetImplemented("pow"); + }; + KernelBackend2.prototype.exp = function(x) { + return notYetImplemented("exp"); + }; + KernelBackend2.prototype.expm1 = function(x) { + return notYetImplemented("expm1"); + }; + KernelBackend2.prototype.softmax = function(x, dim) { + return notYetImplemented("softmax"); + }; + KernelBackend2.prototype.log = function(x) { + return notYetImplemented("log"); + }; + KernelBackend2.prototype.log1p = function(x) { + return notYetImplemented("log1p"); + }; + KernelBackend2.prototype.sqrt = function(x) { + return notYetImplemented("sqrt"); + }; + KernelBackend2.prototype.rsqrt = function(x) { + return notYetImplemented("rsqrt"); + }; + KernelBackend2.prototype.square = function(x) { + return notYetImplemented("square"); + }; + KernelBackend2.prototype.reciprocal = function(x) { + return notYetImplemented("reciprocal"); + }; + KernelBackend2.prototype.relu = function(x) { + return notYetImplemented("relu"); + }; + KernelBackend2.prototype.relu6 = function(x) { + return notYetImplemented("relu6"); + }; + KernelBackend2.prototype.prelu = function(x, a) { + return notYetImplemented("prelu"); + }; + KernelBackend2.prototype.elu = function(x) { + return notYetImplemented("elu"); + }; + KernelBackend2.prototype.eluDer = function(dy, y) { + return notYetImplemented("eluDer"); + }; + KernelBackend2.prototype.selu = function(x) { + return notYetImplemented("selu"); + }; + KernelBackend2.prototype.int = function(x) { + return notYetImplemented("int"); + }; + KernelBackend2.prototype.clip = function(x, min2, max2) { + return notYetImplemented("clip"); + }; + KernelBackend2.prototype.abs = function(x) { + return notYetImplemented("abs"); + }; + KernelBackend2.prototype.complexAbs = function(x) { + return notYetImplemented("complexAbs"); + }; + KernelBackend2.prototype.sigmoid = function(x) { + return notYetImplemented("sigmoid"); + }; + KernelBackend2.prototype.softplus = function(x) { + return notYetImplemented("softplus"); + }; + KernelBackend2.prototype.sin = function(x) { + return notYetImplemented("sin"); + }; + KernelBackend2.prototype.cos = function(x) { + return notYetImplemented("cos"); + }; + KernelBackend2.prototype.tan = function(x) { + return notYetImplemented("tan"); + }; + KernelBackend2.prototype.asin = function(x) { + return notYetImplemented("asin"); + }; + KernelBackend2.prototype.acos = function(x) { + return notYetImplemented("acos"); + }; + KernelBackend2.prototype.atan = function(x) { + return notYetImplemented("atan"); + }; + KernelBackend2.prototype.atan2 = function(a, b) { + return notYetImplemented("atan2"); + }; + KernelBackend2.prototype.sinh = function(x) { + return notYetImplemented("sinh"); + }; + KernelBackend2.prototype.cosh = function(x) { + return notYetImplemented("cosh"); + }; + KernelBackend2.prototype.tanh = function(x) { + return notYetImplemented("tanh"); + }; + KernelBackend2.prototype.asinh = function(x) { + return notYetImplemented("asinh"); + }; + KernelBackend2.prototype.acosh = function(x) { + return notYetImplemented("acosh"); + }; + KernelBackend2.prototype.atanh = function(x) { + return notYetImplemented("atanh"); + }; + KernelBackend2.prototype.erf = function(x) { + return notYetImplemented("erf"); + }; + KernelBackend2.prototype.step = function(x, alpha) { + return notYetImplemented("step"); + }; + KernelBackend2.prototype.fusedConv2d = function(_a) { + var input = _a.input, filter = _a.filter, convInfo = _a.convInfo, bias = _a.bias, activation = _a.activation, preluActivationWeights = _a.preluActivationWeights; + return notYetImplemented("fusedConv2d"); + }; + KernelBackend2.prototype.conv2d = function(x, filter, convInfo) { + return notYetImplemented("conv2d"); + }; + KernelBackend2.prototype.conv2dDerInput = function(dy, filter, convInfo) { + return notYetImplemented("conv2dDerInput"); + }; + KernelBackend2.prototype.conv2dDerFilter = function(x, dY, convInfo) { + return notYetImplemented("conv2dDerFilter"); + }; + KernelBackend2.prototype.fusedDepthwiseConv2D = function(_a) { + var input = _a.input, filter = _a.filter, convInfo = _a.convInfo, bias = _a.bias, activation = _a.activation, preluActivationWeights = _a.preluActivationWeights; + return notYetImplemented("fusedDepthwiseConv2D"); + }; + KernelBackend2.prototype.depthwiseConv2D = function(input, filter, convInfo) { + return notYetImplemented("depthwiseConv2D"); + }; + KernelBackend2.prototype.depthwiseConv2DDerInput = function(dy, filter, convInfo) { + return notYetImplemented("depthwiseConv2DDerInput"); + }; + KernelBackend2.prototype.depthwiseConv2DDerFilter = function(x, dY, convInfo) { + return notYetImplemented("depthwiseConv2DDerFilter"); + }; + KernelBackend2.prototype.conv3d = function(x, filter, convInfo) { + return notYetImplemented("conv3d"); + }; + KernelBackend2.prototype.conv3dDerInput = function(dy, filter, convInfo) { + return notYetImplemented("conv3dDerInput"); + }; + KernelBackend2.prototype.conv3dDerFilter = function(x, dY, convInfo) { + return notYetImplemented("conv3dDerFilter"); + }; + KernelBackend2.prototype.maxPool = function(x, convInfo) { + return notYetImplemented("maxPool"); + }; + KernelBackend2.prototype.maxPoolBackprop = function(dy, x, y, convInfo) { + return notYetImplemented("maxPoolBackprop"); + }; + KernelBackend2.prototype.avgPool = function(x, convInfo) { + return notYetImplemented("avgPool"); + }; + KernelBackend2.prototype.avgPoolBackprop = function(dy, x, convInfo) { + return notYetImplemented("avgPoolBackprop"); + }; + KernelBackend2.prototype.avgPool3d = function(x, convInfo) { + return notYetImplemented("avgPool3d"); + }; + KernelBackend2.prototype.avgPool3dBackprop = function(dy, x, convInfo) { + return notYetImplemented("avgPool3dBackprop"); + }; + KernelBackend2.prototype.maxPool3d = function(x, convInfo) { + return notYetImplemented("maxPool3d"); + }; + KernelBackend2.prototype.maxPool3dBackprop = function(dy, x, y, convInfo) { + return notYetImplemented("maxPool3dBackprop"); + }; + KernelBackend2.prototype.reshape = function(x, shape) { + return notYetImplemented("reshape"); + }; + KernelBackend2.prototype.cast = function(x, dtype) { + return notYetImplemented("cast"); + }; + KernelBackend2.prototype.tile = function(x, reps) { + return notYetImplemented("tile"); + }; + KernelBackend2.prototype.pad = function(x, paddings, constantValue) { + return notYetImplemented("pad"); + }; + KernelBackend2.prototype.transpose = function(x, perm) { + return notYetImplemented("transpose"); + }; + KernelBackend2.prototype.gather = function(x, indices, axis) { + return notYetImplemented("gather"); + }; + KernelBackend2.prototype.gatherND = function(x, indices) { + return notYetImplemented("gatherND"); + }; + KernelBackend2.prototype.scatterND = function(indices, updates, shape) { + return notYetImplemented("scatterND"); + }; + KernelBackend2.prototype.batchToSpaceND = function(x, blockShape, crops) { + return notYetImplemented("batchToSpaceND"); + }; + KernelBackend2.prototype.spaceToBatchND = function(x, blockShape, paddings) { + return notYetImplemented("spaceToBatchND"); + }; + KernelBackend2.prototype.resizeBilinear = function(x, newHeight, newWidth, alignCorners) { + return notYetImplemented("resizeBilinear"); + }; + KernelBackend2.prototype.resizeBilinearBackprop = function(dy, x, alignCorners) { + return notYetImplemented("resizeBilinearBackprop"); + }; + KernelBackend2.prototype.resizeNearestNeighbor = function(x, newHEight, newWidth, alignCorners) { + return notYetImplemented("resizeNearestNeighbor"); + }; + KernelBackend2.prototype.resizeNearestNeighborBackprop = function(dy, x, alignCorners) { + return notYetImplemented("resizeNearestNeighborBackprop"); + }; + KernelBackend2.prototype.batchNorm = function(x, mean2, variance, offset, scale, varianceEpsilon) { + return notYetImplemented("batchNorm"); + }; + KernelBackend2.prototype.localResponseNormalization4D = function(x, radius, bias, alpha, beta) { + return notYetImplemented("localResponseNormalization4D"); + }; + KernelBackend2.prototype.LRNGrad = function(dy, inputImage, outputImage, radius, bias, alpha, beta) { + return notYetImplemented("LRNGrad"); + }; + KernelBackend2.prototype.multinomial = function(logits, normalized, numSamples, seed) { + return notYetImplemented("multinomial"); + }; + KernelBackend2.prototype.oneHot = function(indices, depth, onValue, offValue) { + return notYetImplemented("oneHot"); + }; + KernelBackend2.prototype.cumsum = function(x, axis, exclusive, reverse2) { + return notYetImplemented("cumsum"); + }; + KernelBackend2.prototype.nonMaxSuppression = function(boxes, scores, maxOutputSize, iouThreshold, scoreThreshold) { + return notYetImplemented("nonMaxSuppression"); + }; + KernelBackend2.prototype.fft = function(x) { + return notYetImplemented("fft"); + }; + KernelBackend2.prototype.ifft = function(x) { + return notYetImplemented("ifft"); + }; + KernelBackend2.prototype.complex = function(real2, imag2) { + return notYetImplemented("complex"); + }; + KernelBackend2.prototype.real = function(input) { + return notYetImplemented("real"); + }; + KernelBackend2.prototype.imag = function(input) { + return notYetImplemented("imag"); + }; + KernelBackend2.prototype.cropAndResize = function(image3, boxes, boxIndex, cropSize, method, extrapolationValue) { + return notYetImplemented("cropAndResize"); + }; + KernelBackend2.prototype.depthToSpace = function(x, blockSize, dataFormat) { + return notYetImplemented("depthToSpace"); + }; + KernelBackend2.prototype.split = function(value, sizeSplits, axis) { + return notYetImplemented("split"); + }; + KernelBackend2.prototype.sparseToDense = function(sparseIndices, sparseValues, outputShape, defaultValue) { + return notYetImplemented("sparseToDense"); + }; + KernelBackend2.prototype.diag = function(x) { + return notYetImplemented("diag"); + }; + KernelBackend2.prototype.fill = function(shape, value, dtype) { + return notYetImplemented("fill"); + }; + KernelBackend2.prototype.onesLike = function(x) { + return notYetImplemented("onesLike"); + }; + KernelBackend2.prototype.zerosLike = function(x) { + return notYetImplemented("zerosLike"); + }; + KernelBackend2.prototype.linspace = function(start, stop, num) { + return notYetImplemented("linspace"); + }; + KernelBackend2.prototype.dispose = function() { + return notYetImplemented("dispose"); + }; + return KernelBackend2; + }(); + function notYetImplemented(kernelName) { + throw new Error("'" + kernelName + "' not yet implemented or not found in the registry. This kernel may not be supported by the tfjs backend you have chosen"); } - function E(n, t) { - if (!n) - throw new Error(typeof t == "string" ? t : t()); + /** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + function shuffle(array) { + var counter = array.length; + var temp = 0; + var index = 0; + while (counter > 0) { + index = Math.random() * counter | 0; + counter--; + temp = array[counter]; + array[counter] = array[index]; + array[index] = temp; + } } - function Pe(n, t, e) { - e === void 0 && (e = ""), E(pn(n, t), function() { - return e + (" Shapes " + n + " and " + t + " must match"); + function clamp(min2, x, max2) { + return Math.max(min2, Math.min(x, max2)); + } + function nearestLargerEven(val) { + return val % 2 === 0 ? val : val + 1; + } + function sum(arr) { + var sum2 = 0; + for (var i = 0; i < arr.length; i++) { + sum2 += arr[i]; + } + return sum2; + } + function randUniform(a, b) { + var r = Math.random(); + return b * r + (1 - r) * a; + } + function distSquared(a, b) { + var result = 0; + for (var i = 0; i < a.length; i++) { + var diff = Number(a[i]) - Number(b[i]); + result += diff * diff; + } + return result; + } + function assert(expr, msg) { + if (!expr) { + throw new Error(typeof msg === "string" ? msg : msg()); + } + } + function assertShapesMatch(shapeA, shapeB, errorMessagePrefix) { + if (errorMessagePrefix === void 0) { + errorMessagePrefix = ""; + } + assert(arraysEqual(shapeA, shapeB), function() { + return errorMessagePrefix + (" Shapes " + shapeA + " and " + shapeB + " must match"); }); } - function Ur(n) { - E(n != null, function() { + function assertNonNull(a) { + assert(a != null, function() { return "The input to the tensor constructor must be a non-null value."; }); } - function Br(n, t, e) { - if (t === void 0 && (t = []), e === void 0 && (e = false), t == null && (t = []), Array.isArray(n) || Ft(n) && !e) - for (var r = 0; r < n.length; ++r) - Br(n[r], t, e); - else - t.push(n); - return t; + function flatten(arr, result, skipTypedArray) { + if (result === void 0) { + result = []; + } + if (skipTypedArray === void 0) { + skipTypedArray = false; + } + if (result == null) { + result = []; + } + if (Array.isArray(arr) || isTypedArray(arr) && !skipTypedArray) { + for (var i = 0; i < arr.length; ++i) { + flatten(arr[i], result, skipTypedArray); + } + } else { + result.push(arr); + } + return result; } - function pt(n) { - if (n.length === 0) + function sizeFromShape(shape) { + if (shape.length === 0) { return 1; - for (var t = n[0], e = 1; e < n.length; e++) - t *= n[e]; - return t; + } + var size = shape[0]; + for (var i = 1; i < shape.length; i++) { + size *= shape[i]; + } + return size; } - function wL(n) { - return n.length === 0; + function isScalarShape(shape) { + return shape.length === 0; } - function pn(n, t) { - if (n === t) + function arraysEqual(n1, n2) { + if (n1 === n2) { return true; - if (n == null || t == null) + } + if (n1 == null || n2 == null) { return false; - if (n.length !== t.length) + } + if (n1.length !== n2.length) { return false; - for (var e = 0; e < n.length; e++) - if (n[e] !== t[e]) + } + for (var i = 0; i < n1.length; i++) { + if (n1[i] !== n2[i]) { return false; + } + } return true; } - function ot(n) { - return n % 1 === 0; + function isInt(a) { + return a % 1 === 0; } - function bL(n) { - if (Math.tanh != null) - return Math.tanh(n); - if (n === Infinity) + function tanh(x) { + if (Math.tanh != null) { + return Math.tanh(x); + } + if (x === Infinity) { return 1; - if (n === -Infinity) + } else if (x === -Infinity) { return -1; - var t = Math.exp(2 * n); - return (t - 1) / (t + 1); + } else { + var e2x = Math.exp(2 * x); + return (e2x - 1) / (e2x + 1); + } } - function xL(n) { - var t = Math.ceil(Math.sqrt(n)); - return [t, Math.ceil(n / t)]; + function sizeToSquarishShape(size) { + var width = Math.ceil(Math.sqrt(size)); + return [width, Math.ceil(size / width)]; } - function LL(n) { - for (var t = new Uint32Array(n), e = 0; e < n; ++e) - t[e] = e; - return lf(t), t; + function createShuffledIndices(n) { + var shuffledIndices = new Uint32Array(n); + for (var i = 0; i < n; ++i) { + shuffledIndices[i] = i; + } + shuffle(shuffledIndices); + return shuffledIndices; } - function ya(n, t) { - return t <= n.length ? n : n + " ".repeat(t - n.length); + function rightPad(a, size) { + if (size <= a.length) { + return a; + } + return a + " ".repeat(size - a.length); } - function SL(n, t, e) { - return t === void 0 && (t = function(r) { - return 0; - }), new Promise(function(r, i) { - var a = 0, s = function() { - if (n()) { - r(); - return; - } - a++; - var o = t(a); - if (e != null && a >= e) { - i(); - return; - } - setTimeout(s, o); + function repeatedTry(checkFn, delayFn, maxCounter) { + if (delayFn === void 0) { + delayFn = function(counter) { + return 0; }; - s(); - }); - } - function uf(n, t) { - for (var e = 1, r = -1, i = 0; i < n.length; ++i) - if (n[i] >= 0) - e *= n[i]; - else if (n[i] === -1) { - if (r !== -1) - throw Error("Shapes can only have 1 implicit size. " + ("Found -1 at dim " + r + " and dim " + i)); - r = i; - } else if (n[i] < 0) - throw Error("Shapes can not be < 0. Found " + n[i] + " at dim " + i); - if (r === -1) { - if (t > 0 && t !== e) - throw Error("Size(" + t + ") must match the product of shape " + n); - return n; } - if (e === 0) - throw Error("Cannot infer the missing size in [" + n + "] when there are 0 elements"); - if (t % e !== 0) - throw Error("The implicit shape can't be a fractional number. " + ("Got " + t + " / " + e)); - var a = n.slice(); - return a[r] = t / e, a; - } - function rt(n, t) { - var e = t.length; - return n = n == null ? t.map(function(r, i) { - return i; - }) : [].concat(n), E(n.every(function(r) { - return r >= -e && r < e; - }), function() { - return "All values in axis param must be in range [-" + e + ", " + e + ") but " + ("got axis " + n); - }), E(n.every(function(r) { - return ot(r); - }), function() { - return "All values in axis param must be integers but " + ("got axis " + n); - }), n.map(function(r) { - return r < 0 ? e + r : r; + return new Promise(function(resolve, reject) { + var tryCount = 0; + var tryFn = function() { + if (checkFn()) { + resolve(); + return; + } + tryCount++; + var nextBackoff = delayFn(tryCount); + if (maxCounter != null && tryCount >= maxCounter) { + reject(); + return; + } + setTimeout(tryFn, nextBackoff); + }; + tryFn(); }); } - function hf(n, t) { - for (var e = [], r = [], i = t != null && Array.isArray(t) && t.length === 0, a = t == null || i ? null : rt(t, n).sort(), s = 0, o = 0; o < n.length; ++o) { - if (a != null) { - if (a[s] === o && n[o] !== 1) - throw new Error("Can't squeeze axis " + o + " since its dim '" + n[o] + "' is not 1"); - (a[s] == null || a[s] > o) && n[o] === 1 && (e.push(n[o]), r.push(o)), a[s] <= o && s++; + function inferFromImplicitShape(shape, size) { + var shapeProd = 1; + var implicitIdx = -1; + for (var i = 0; i < shape.length; ++i) { + if (shape[i] >= 0) { + shapeProd *= shape[i]; + } else if (shape[i] === -1) { + if (implicitIdx !== -1) { + throw Error("Shapes can only have 1 implicit size. " + ("Found -1 at dim " + implicitIdx + " and dim " + i)); + } + implicitIdx = i; + } else if (shape[i] < 0) { + throw Error("Shapes can not be < 0. Found " + shape[i] + " at dim " + i); } - n[o] !== 1 && (e.push(n[o]), r.push(o)); } - return {newShape: e, keptDims: r}; - } - function Ts(n, t) { - var e = null; - if (n == null || n === "float32") - e = new Float32Array(t); - else if (n === "int32") - e = new Int32Array(t); - else if (n === "bool") - e = new Uint8Array(t); - else - throw new Error("Unknown data type " + n); - return e; - } - function df(n, t) { - var e = null; - if (n == null || n === "float32") - e = new Float32Array(t); - else if (n === "int32") - e = new Int32Array(t); - else if (n === "bool") - e = new Uint8Array(t); - else if (n === "string") - e = new Array(t); - else - throw new Error("Unknown data type " + n); - return e; - } - function pf(n, t) { - for (var e = 0; e < n.length; e++) { - var r = n[e]; - if (isNaN(r) || !isFinite(r)) - throw Error("A tensor of type " + t + " being uploaded contains " + r + "."); + if (implicitIdx === -1) { + if (size > 0 && size !== shapeProd) { + throw Error("Size(" + size + ") must match the product of shape " + shape); + } + return shape; } + if (shapeProd === 0) { + throw Error("Cannot infer the missing size in [" + shape + "] when there are 0 elements"); + } + if (size % shapeProd !== 0) { + throw Error("The implicit shape can't be a fractional number. " + ("Got " + size + " / " + shapeProd)); + } + var newShape = shape.slice(); + newShape[implicitIdx] = size / shapeProd; + return newShape; } - function ff(n) { - return n === "bool" || n === "complex64" || n === "float32" || n === "int32" || n === "string"; - } - function mf(n, t) { - return t === "complex64" || (t === "float32" && n !== "complex64" || t === "int32" && n !== "float32" && n !== "complex64") ? false : !(t === "bool" && n === "bool"); - } - function Ft(n) { - return n instanceof Float32Array || n instanceof Int32Array || n instanceof Uint8Array; - } - function gf(n) { - if (n === "float32" || n === "int32") - return 4; - if (n === "complex64") - return 8; - if (n === "bool") - return 1; - throw new Error("Unknown dtype " + n); - } - function yf(n) { - if (n == null) - return 0; - var t = 0; - return n.forEach(function(e) { - return t += e.length; - }), t; - } - function or(n) { - return typeof n == "string" || n instanceof String; - } - function vf(n) { - return typeof n == "boolean"; - } - function wf(n) { - return typeof n == "number"; - } - function Ns(n) { - return Array.isArray(n) ? Ns(n[0]) : n instanceof Float32Array ? "float32" : n instanceof Int32Array || n instanceof Uint8Array ? "int32" : wf(n) ? "float32" : or(n) ? "string" : vf(n) ? "bool" : "float32"; - } - function cr(n) { - return !!(n && n.constructor && n.call && n.apply); - } - function _s(n, t) { - for (var e = t; e < n; ++e) - if (n % e === 0) - return e; - return n; - } - function bi(n) { - var t = n.length; - if (t < 2) - return []; - var e = new Array(t - 1); - e[t - 2] = n[t - 1]; - for (var r = t - 3; r >= 0; --r) - e[r] = e[r + 1] * n[r + 1]; - return e; - } - function bf(n, t, e) { - var r = new Array(); - if (t.length === 1) - for (var i = t[0], a = 0; a < i; a++) - r[a] = e[n + a]; - else - for (var i = t[0], s = t.slice(1), o = s.reduce(function(l, u) { - return l * u; - }), a = 0; a < i; a++) - r[a] = bf(n + a * o, s, e); - return r; - } - function xi(n, t) { - if (n.length === 0) - return t[0]; - var e = n.reduce(function(r, i) { - return r * i; + function parseAxisParam(axis, shape) { + var rank = shape.length; + axis = axis == null ? shape.map(function(s, i) { + return i; + }) : [].concat(axis); + assert(axis.every(function(ax) { + return ax >= -rank && ax < rank; + }), function() { + return "All values in axis param must be in range [-" + rank + ", " + rank + ") but " + ("got axis " + axis); }); - if (e === 0) + assert(axis.every(function(ax) { + return isInt(ax); + }), function() { + return "All values in axis param must be integers but " + ("got axis " + axis); + }); + return axis.map(function(a) { + return a < 0 ? rank + a : a; + }); + } + function squeezeShape(shape, axis) { + var newShape = []; + var keptDims = []; + var isEmptyArray = axis != null && Array.isArray(axis) && axis.length === 0; + var axes = axis == null || isEmptyArray ? null : parseAxisParam(axis, shape).sort(); + var j = 0; + for (var i = 0; i < shape.length; ++i) { + if (axes != null) { + if (axes[j] === i && shape[i] !== 1) { + throw new Error("Can't squeeze axis " + i + " since its dim '" + shape[i] + "' is not 1"); + } + if ((axes[j] == null || axes[j] > i) && shape[i] === 1) { + newShape.push(shape[i]); + keptDims.push(i); + } + if (axes[j] <= i) { + j++; + } + } + if (shape[i] !== 1) { + newShape.push(shape[i]); + keptDims.push(i); + } + } + return {newShape, keptDims}; + } + function getTypedArrayFromDType(dtype, size) { + var values = null; + if (dtype == null || dtype === "float32") { + values = new Float32Array(size); + } else if (dtype === "int32") { + values = new Int32Array(size); + } else if (dtype === "bool") { + values = new Uint8Array(size); + } else { + throw new Error("Unknown data type " + dtype); + } + return values; + } + function getArrayFromDType(dtype, size) { + var values = null; + if (dtype == null || dtype === "float32") { + values = new Float32Array(size); + } else if (dtype === "int32") { + values = new Int32Array(size); + } else if (dtype === "bool") { + values = new Uint8Array(size); + } else if (dtype === "string") { + values = new Array(size); + } else { + throw new Error("Unknown data type " + dtype); + } + return values; + } + function checkConversionForErrors(vals, dtype) { + for (var i = 0; i < vals.length; i++) { + var num = vals[i]; + if (isNaN(num) || !isFinite(num)) { + throw Error("A tensor of type " + dtype + " being uploaded contains " + num + "."); + } + } + } + function isValidDtype(dtype) { + return dtype === "bool" || dtype === "complex64" || dtype === "float32" || dtype === "int32" || dtype === "string"; + } + function hasEncodingLoss(oldType, newType) { + if (newType === "complex64") { + return false; + } + if (newType === "float32" && oldType !== "complex64") { + return false; + } + if (newType === "int32" && oldType !== "float32" && oldType !== "complex64") { + return false; + } + if (newType === "bool" && oldType === "bool") { + return false; + } + return true; + } + function isTypedArray(a) { + return a instanceof Float32Array || a instanceof Int32Array || a instanceof Uint8Array; + } + function bytesPerElement(dtype) { + if (dtype === "float32" || dtype === "int32") { + return 4; + } else if (dtype === "complex64") { + return 8; + } else if (dtype === "bool") { + return 1; + } else { + throw new Error("Unknown dtype " + dtype); + } + } + function bytesFromStringArray(arr) { + if (arr == null) { + return 0; + } + var bytes = 0; + arr.forEach(function(x) { + return bytes += x.length; + }); + return bytes; + } + function isString(value) { + return typeof value === "string" || value instanceof String; + } + function isBoolean(value) { + return typeof value === "boolean"; + } + function isNumber(value) { + return typeof value === "number"; + } + function inferDtype(values) { + if (Array.isArray(values)) { + return inferDtype(values[0]); + } + if (values instanceof Float32Array) { + return "float32"; + } else if (values instanceof Int32Array || values instanceof Uint8Array) { + return "int32"; + } else if (isNumber(values)) { + return "float32"; + } else if (isString(values)) { + return "string"; + } else if (isBoolean(values)) { + return "bool"; + } + return "float32"; + } + function isFunction(f) { + return !!(f && f.constructor && f.call && f.apply); + } + function nearestDivisor(size, start) { + for (var i = start; i < size; ++i) { + if (size % i === 0) { + return i; + } + } + return size; + } + function computeStrides(shape) { + var rank = shape.length; + if (rank < 2) { return []; - if (e !== t.length) - throw new Error("[" + n + "] does not match the input size " + t.length + "."); - return bf(0, n, t); + } + var strides = new Array(rank - 1); + strides[rank - 2] = shape[rank - 1]; + for (var i = rank - 3; i >= 0; --i) { + strides[i] = strides[i + 1] * shape[i + 1]; + } + return strides; } - function yc(n, t) { - for (var e = Li(n, t), r = 0; r < e.length; r++) - e[r] = 1; - return e; + function createNestedArray(offset, shape, a) { + var ret = new Array(); + if (shape.length === 1) { + var d = shape[0]; + for (var i = 0; i < d; i++) { + ret[i] = a[offset + i]; + } + } else { + var d = shape[0]; + var rest = shape.slice(1); + var len = rest.reduce(function(acc, c) { + return acc * c; + }); + for (var i = 0; i < d; i++) { + ret[i] = createNestedArray(offset + i * len, rest, a); + } + } + return ret; } - function Li(n, t) { - if (t == null || t === "float32" || t === "complex64") - return new Float32Array(n); - if (t === "int32") - return new Int32Array(n); - if (t === "bool") - return new Uint8Array(n); - throw new Error("Unknown data type " + t); + function toNestedArray(shape, a) { + if (shape.length === 0) { + return a[0]; + } + var size = shape.reduce(function(acc, c) { + return acc * c; + }); + if (size === 0) { + return []; + } + if (size !== a.length) { + throw new Error("[" + shape + "] does not match the input size " + a.length + "."); + } + return createNestedArray(0, shape, a); } - function IL(n, t) { - var e = n.reduce(function(r, i) { - return r * i; + function makeOnesTypedArray(size, dtype) { + var array = makeZerosTypedArray(size, dtype); + for (var i = 0; i < array.length; i++) { + array[i] = 1; + } + return array; + } + function makeZerosTypedArray(size, dtype) { + if (dtype == null || dtype === "float32" || dtype === "complex64") { + return new Float32Array(size); + } else if (dtype === "int32") { + return new Int32Array(size); + } else if (dtype === "bool") { + return new Uint8Array(size); + } else { + throw new Error("Unknown data type " + dtype); + } + } + function makeZerosNestedTypedArray(shape, dtype) { + var size = shape.reduce(function(prev, curr) { + return prev * curr; }, 1); - if (t == null || t === "float32") - return xi(n, new Float32Array(e)); - if (t === "int32") - return xi(n, new Int32Array(e)); - if (t === "bool") - return xi(n, new Uint8Array(e)); - throw new Error("Unknown data type " + t); + if (dtype == null || dtype === "float32") { + return toNestedArray(shape, new Float32Array(size)); + } else if (dtype === "int32") { + return toNestedArray(shape, new Int32Array(size)); + } else if (dtype === "bool") { + return toNestedArray(shape, new Uint8Array(size)); + } else { + throw new Error("Unknown data type " + dtype); + } } - function vc(n) { - n.forEach(function(t) { - E(Number.isInteger(t) && t >= 0, function() { - return "Tensor must have a shape comprised of positive integers but got " + ("shape [" + n + "]."); + function assertNonNegativeIntegerDimensions(shape) { + shape.forEach(function(dimSize) { + assert(Number.isInteger(dimSize) && dimSize >= 0, function() { + return "Tensor must have a shape comprised of positive integers but got " + ("shape [" + shape + "]."); }); }); } - function AL(n, t, e) { - if (t === 0) + function locToIndex(locs, rank, strides) { + if (rank === 0) { return 0; - if (t === 1) - return n[0]; - for (var r = n[n.length - 1], i = 0; i < n.length - 1; ++i) - r += e[i] * n[i]; - return r; - } - function TL(n, t, e) { - if (t === 0) - return []; - if (t === 1) - return [n]; - for (var r = new Array(t), i = 0; i < r.length - 1; ++i) - r[i] = Math.floor(n / e[i]), n -= r[i] * e[i]; - return r[r.length - 1] = n, r; - } - function wc(n) { - return n && n.then && typeof n.then == "function"; - } - var xf = "tfjsflags", Lf = function() { - function n(t) { - this.global = t, this.flags = {}, this.flagRegistry = {}, this.urlFlags = {}, this.populateURLFlags(); + } else if (rank === 1) { + return locs[0]; } - return n.prototype.setPlatform = function(t, e) { - this.platform != null && console.warn("Platform " + this.platformName + " has already been set. " + ("Overwriting the platform with " + e + ".")), this.platformName = t, this.platform = e; - }, n.prototype.registerFlag = function(t, e, r) { - if (this.flagRegistry[t] = {evaluationFn: e, setHook: r}, this.urlFlags[t] != null) { - var i = this.urlFlags[t]; - console.warn("Setting feature override from URL " + t + ": " + i + "."), this.set(t, i); + var index = locs[locs.length - 1]; + for (var i = 0; i < locs.length - 1; ++i) { + index += strides[i] * locs[i]; + } + return index; + } + function indexToLoc(index, rank, strides) { + if (rank === 0) { + return []; + } else if (rank === 1) { + return [index]; + } + var locs = new Array(rank); + for (var i = 0; i < locs.length - 1; ++i) { + locs[i] = Math.floor(index / strides[i]); + index -= locs[i] * strides[i]; + } + locs[locs.length - 1] = index; + return locs; + } + function isPromise(object) { + return object && object.then && typeof object.then === "function"; + } + /** + * @license + * Copyright 2017 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + var TENSORFLOWJS_FLAGS_PREFIX = "tfjsflags"; + var Environment = function() { + function Environment2(global2) { + this.global = global2; + this.flags = {}; + this.flagRegistry = {}; + this.urlFlags = {}; + this.populateURLFlags(); + } + Environment2.prototype.setPlatform = function(platformName, platform) { + if (this.platform != null) { + console.warn("Platform " + this.platformName + " has already been set. " + ("Overwriting the platform with " + platform + ".")); } - }, n.prototype.getAsync = function(t) { - return pe(this, void 0, void 0, function() { - var e, r; - return fe(this, function(i) { - switch (i.label) { + this.platformName = platformName; + this.platform = platform; + }; + Environment2.prototype.registerFlag = function(flagName, evaluationFn, setHook) { + this.flagRegistry[flagName] = {evaluationFn, setHook}; + if (this.urlFlags[flagName] != null) { + var flagValue = this.urlFlags[flagName]; + console.warn("Setting feature override from URL " + flagName + ": " + flagValue + "."); + this.set(flagName, flagValue); + } + }; + Environment2.prototype.getAsync = function(flagName) { + return __awaiter(this, void 0, void 0, function() { + var _a, _b; + return __generator(this, function(_c) { + switch (_c.label) { case 0: - return t in this.flags ? [2, this.flags[t]] : (e = this.flags, r = t, [4, this.evaluateFlag(t)]); + if (flagName in this.flags) { + return [2, this.flags[flagName]]; + } + _a = this.flags; + _b = flagName; + return [4, this.evaluateFlag(flagName)]; case 1: - return e[r] = i.sent(), [2, this.flags[t]]; + _a[_b] = _c.sent(); + return [2, this.flags[flagName]]; } }); }); - }, n.prototype.get = function(t) { - if (t in this.flags) - return this.flags[t]; - var e = this.evaluateFlag(t); - if (wc(e)) - throw new Error("Flag " + t + " cannot be synchronously evaluated. Please use getAsync() instead."); - return this.flags[t] = e, this.flags[t]; - }, n.prototype.getNumber = function(t) { - return this.get(t); - }, n.prototype.getBool = function(t) { - return this.get(t); - }, n.prototype.getFlags = function() { + }; + Environment2.prototype.get = function(flagName) { + if (flagName in this.flags) { + return this.flags[flagName]; + } + var flagValue = this.evaluateFlag(flagName); + if (isPromise(flagValue)) { + throw new Error("Flag " + flagName + " cannot be synchronously evaluated. Please use getAsync() instead."); + } + this.flags[flagName] = flagValue; + return this.flags[flagName]; + }; + Environment2.prototype.getNumber = function(flagName) { + return this.get(flagName); + }; + Environment2.prototype.getBool = function(flagName) { + return this.get(flagName); + }; + Environment2.prototype.getFlags = function() { return this.flags; - }, Object.defineProperty(n.prototype, "features", {get: function() { - return this.flags; - }, enumerable: true, configurable: true}), n.prototype.set = function(t, e) { - if (this.flagRegistry[t] == null) - throw new Error("Cannot set flag " + t + " as it has not been registered."); - this.flags[t] = e, this.flagRegistry[t].setHook != null && this.flagRegistry[t].setHook(e); - }, n.prototype.evaluateFlag = function(t) { - if (this.flagRegistry[t] == null) - throw new Error("Cannot evaluate flag '" + t + "': no evaluation function found."); - return this.flagRegistry[t].evaluationFn(); - }, n.prototype.setFlags = function(t) { - this.flags = Object.assign({}, t); - }, n.prototype.reset = function() { - this.flags = {}, this.urlFlags = {}, this.populateURLFlags(); - }, n.prototype.populateURLFlags = function() { - var t = this; - if (typeof this.global == "undefined" || typeof this.global.location == "undefined" || typeof this.global.location.search == "undefined") + }; + Object.defineProperty(Environment2.prototype, "features", { + get: function() { + return this.flags; + }, + enumerable: true, + configurable: true + }); + Environment2.prototype.set = function(flagName, value) { + if (this.flagRegistry[flagName] == null) { + throw new Error("Cannot set flag " + flagName + " as it has not been registered."); + } + this.flags[flagName] = value; + if (this.flagRegistry[flagName].setHook != null) { + this.flagRegistry[flagName].setHook(value); + } + }; + Environment2.prototype.evaluateFlag = function(flagName) { + if (this.flagRegistry[flagName] == null) { + throw new Error("Cannot evaluate flag '" + flagName + "': no evaluation function found."); + } + return this.flagRegistry[flagName].evaluationFn(); + }; + Environment2.prototype.setFlags = function(flags) { + this.flags = Object.assign({}, flags); + }; + Environment2.prototype.reset = function() { + this.flags = {}; + this.urlFlags = {}; + this.populateURLFlags(); + }; + Environment2.prototype.populateURLFlags = function() { + var _this = this; + if (typeof this.global === "undefined" || typeof this.global.location === "undefined" || typeof this.global.location.search === "undefined") { return; - var e = NL(this.global.location.search); - if (xf in e) { - var r = e[xf].split(","); - r.forEach(function(i) { - var a = i.split(":"), s = a[0], o = a[1]; - t.urlFlags[s] = _L(s, o); + } + var urlParams = getQueryParams(this.global.location.search); + if (TENSORFLOWJS_FLAGS_PREFIX in urlParams) { + var keyValues = urlParams[TENSORFLOWJS_FLAGS_PREFIX].split(","); + keyValues.forEach(function(keyValue) { + var _a = keyValue.split(":"), key = _a[0], value = _a[1]; + _this.urlFlags[key] = parseValue(key, value); }); } - }, n; + }; + return Environment2; }(); - function NL(n) { - var t = {}; - return n.replace(/[?&]([^=?&]+)(?:=([^&]*))?/g, function(e) { - for (var r = [], i = 1; i < arguments.length; i++) - r[i - 1] = arguments[i]; - return CL(t, r[0], r[1]), r.join("="); - }), t; + function getQueryParams(queryString) { + var params = {}; + queryString.replace(/[?&]([^=?&]+)(?:=([^&]*))?/g, function(s) { + var t = []; + for (var _i2 = 1; _i2 < arguments.length; _i2++) { + t[_i2 - 1] = arguments[_i2]; + } + decodeParam(params, t[0], t[1]); + return t.join("="); + }); + return params; } - function CL(n, t, e) { - n[decodeURIComponent(t)] = decodeURIComponent(e || ""); + function decodeParam(params, name, value) { + params[decodeURIComponent(name)] = decodeURIComponent(value || ""); } - function _L(n, t) { - if (t = t.toLowerCase(), t === "true" || t === "false") - return t === "true"; - if ("" + +t === t) - return +t; - throw new Error("Could not parse value flag value " + t + " for flag " + n + "."); - } - function Ge() { - return A.ENV; - } - A.ENV = null; - function RL(n) { - A.ENV = n; - } - var bc; - function Sf() { - if (bc == null) { - var n = void 0; - if (typeof window != "undefined") - n = window; - else if (typeof global != "undefined") - n = global; - else if (typeof process != "undefined") - n = process; - else if (typeof self != "undefined") - n = self; - else - throw new Error("Could not find a global object"); - bc = n; + function parseValue(flagName, value) { + value = value.toLowerCase(); + if (value === "true" || value === "false") { + return value === "true"; + } else if ("" + +value === value) { + return +value; } - return bc; + throw new Error("Could not parse value flag value " + value + " for flag " + flagName + "."); } - function OL() { - var n = Sf(); - return n._tfGlobals == null && (n._tfGlobals = new Map()), n._tfGlobals; + function env() { + return exports.ENV; } - function If(n, t) { - var e = OL(); - if (e.has(n)) - return e.get(n); - var r = t(); - return e.set(n, r), e.get(n); + exports.ENV = null; + function setEnvironmentGlobal(environment) { + exports.ENV = environment; } - var xc = "Abs", Lc = "Acos", Sc = "Acosh", Cs = "Add", Ic = "AddN", Af = "All", Tf = "Any", Ac = "ArgMax", Tc = "ArgMin", Nc = "Asin", _c = "Asinh", Cc = "Atan", Rc = "Atanh", Oc = "Atan2", Ec = "AvgPool", Nf = "AvgPoolBackprop", Dc = "AvgPool3D", _f = "AvgPool3DBackprop", kc = "BatchMatMul", Fc = "BatchToSpaceND", Wc = "BroadcastTo", Rs = "Cast", Uc = "Ceil", Bc = "ClipByValue", Cf = "Complex", zc = "Concat", Pc = "Conv2D", Rf = "Conv2DBackpropFilter", Mc = "Conv2DBackpropInput", Hc = "Conv3D", Of = "Conv3DBackpropFilterV2", Ef = "Conv3DBackpropInputV2", Vc = "Cos", Gc = "Cosh", qc = "Cumsum", Df = "CropAndResize", kf = "DepthToSpace", Yc = "DepthwiseConv2dNative", Ff = "DepthwiseConv2dNativeBackpropFilter", Wf = "DepthwiseConv2dNativeBackpropInput", Uf = "Diag", Kc = "Dilation2D", Bf = "Dilation2DBackpropInput", zf = "Dilation2DBackpropFilter", jc = "Div", $c = "Elu", Pf = "EluGrad", Xc = "Erf", Mf = "Equal", Jc = "Exp", Zc = "Expm1", Hf = "FFT", Vf = "Fill", Gf = "FlipLeftRight", Qc = "Floor", el = "FloorDiv", tl = "FusedBatchNorm", nl = "GatherV2", qf = "GatherNd", Yf = "Greater", rl = "GreaterEqual", il = "Identity", Kf = "IFFT", jf = "Imag", al = "IsFinite", sl = "IsInf", ol = "IsNan", $f = "Less", Xf = "LessEqual", Jf = "LinSpace", cl = "Log", ll = "Log1p", Zf = "LogicalAnd", Qf = "LogicalNot", em = "LogicalOr", ul = "LogSoftmax", hl = "LRN", tm = "LRNBackprop", dl = "Max", pl = "Maximum", fl = "MaxPool", nm = "MaxPoolBackprop", ml = "MaxPool3D", rm = "MaxPool3DBackprop", im = "MaxPoolWithArgmax", am = "Mean", gl = "Min", yl = "Minimum", vl = "MirrorPad", wl = "Mod", bl = "Multiply", xl = "Negate", sm = "NotEqual", om = "NonMaxSuppressionV3", cm = "NonMaxSuppressionV4", lm = "NonMaxSuppressionV5", Ll = "OnesLike", Sl = "OneHot", Il = "PadV2", EL = "Pool", Al = "Pow", Tl = "Prelu", um = "Prod", hm = "Range", dm = "Real", Nl = "Reciprocal", _l = "Relu", Cl = "Reshape", Rl = "ResizeNearestNeighbor", pm = "ResizeNearestNeighborGrad", Ol = "ResizeBilinear", fm = "ResizeBilinearGrad", El = "Relu6", Dl = "Reverse", kl = "Round", Fl = "Rsqrt", mm = "ScatterNd", Wl = "SelectV2", Ul = "Selu", Bl = "Slice", zl = "Sin", Pl = "Sinh", Ml = "Sign", Hl = "Sigmoid", Vl = "Softplus", Gl = "Sqrt", ql = "Sum", Yl = "SpaceToBatchND", Kl = "SplitV", jl = "Softmax", $l = "SquaredDifference", gm = "Square", Xl = "Sub", ym = "SparseToDense", vm = "StridedSlice", Jl = "Tan", Zl = "Tanh", Ql = "Tile", wm = "TopK", eu = "Transpose", bm = "Unique", tu = "Unpack", nu = "UnsortedSegmentSum", ru = "ZerosLike", iu = "Step", au = "FromPixels", xm = "RotateWithOffset", su = "_FusedMatMul", ou = "FusedConv2D", cu = "FusedDepthwiseConv2D"; - var Si = If("kernelRegistry", function() { - return new Map(); - }), va = If("gradRegistry", function() { + /** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + var globalNameSpace; + function getGlobalNamespace() { + if (globalNameSpace == null) { + var ns = void 0; + if (typeof window !== "undefined") { + ns = window; + } else if (typeof global !== "undefined") { + ns = global; + } else if (typeof process !== "undefined") { + ns = process; + } else if (typeof self !== "undefined") { + ns = self; + } else { + throw new Error("Could not find a global object"); + } + globalNameSpace = ns; + } + return globalNameSpace; + } + function getGlobalMap() { + var ns = getGlobalNamespace(); + if (ns._tfGlobals == null) { + ns._tfGlobals = new Map(); + } + return ns._tfGlobals; + } + function getGlobal(key, init) { + var globalMap = getGlobalMap(); + if (globalMap.has(key)) { + return globalMap.get(key); + } else { + var singleton = init(); + globalMap.set(key, singleton); + return globalMap.get(key); + } + } + var Abs = "Abs"; + var Acos = "Acos"; + var Acosh = "Acosh"; + var Add = "Add"; + var AddN = "AddN"; + var All = "All"; + var Any = "Any"; + var ArgMax = "ArgMax"; + var ArgMin = "ArgMin"; + var Asin = "Asin"; + var Asinh = "Asinh"; + var Atan = "Atan"; + var Atanh = "Atanh"; + var Atan2 = "Atan2"; + var AvgPool = "AvgPool"; + var AvgPoolBackprop = "AvgPoolBackprop"; + var AvgPool3D = "AvgPool3D"; + var AvgPool3DBackprop = "AvgPool3DBackprop"; + var BatchMatMul = "BatchMatMul"; + var BatchToSpaceND = "BatchToSpaceND"; + var BroadcastTo = "BroadcastTo"; + var Cast = "Cast"; + var Ceil = "Ceil"; + var ClipByValue = "ClipByValue"; + var Complex = "Complex"; + var Concat = "Concat"; + var Conv2D = "Conv2D"; + var Conv2DBackpropFilter = "Conv2DBackpropFilter"; + var Conv2DBackpropInput = "Conv2DBackpropInput"; + var Conv3D = "Conv3D"; + var Conv3DBackpropFilterV2 = "Conv3DBackpropFilterV2"; + var Conv3DBackpropInputV2 = "Conv3DBackpropInputV2"; + var Cos = "Cos"; + var Cosh = "Cosh"; + var Cumsum = "Cumsum"; + var CropAndResize = "CropAndResize"; + var DepthToSpace = "DepthToSpace"; + var DepthwiseConv2dNative = "DepthwiseConv2dNative"; + var DepthwiseConv2dNativeBackpropFilter = "DepthwiseConv2dNativeBackpropFilter"; + var DepthwiseConv2dNativeBackpropInput = "DepthwiseConv2dNativeBackpropInput"; + var Diag = "Diag"; + var Dilation2D = "Dilation2D"; + var Dilation2DBackpropInput = "Dilation2DBackpropInput"; + var Dilation2DBackpropFilter = "Dilation2DBackpropFilter"; + var Div = "Div"; + var Elu = "Elu"; + var EluGrad = "EluGrad"; + var Erf = "Erf"; + var Equal = "Equal"; + var Exp = "Exp"; + var Expm1 = "Expm1"; + var FFT = "FFT"; + var Fill = "Fill"; + var FlipLeftRight = "FlipLeftRight"; + var Floor = "Floor"; + var FloorDiv = "FloorDiv"; + var FusedBatchNorm = "FusedBatchNorm"; + var GatherV2 = "GatherV2"; + var GatherNd = "GatherNd"; + var Greater = "Greater"; + var GreaterEqual = "GreaterEqual"; + var Identity = "Identity"; + var IFFT = "IFFT"; + var Imag = "Imag"; + var IsFinite = "IsFinite"; + var IsInf = "IsInf"; + var IsNan = "IsNan"; + var Less = "Less"; + var LessEqual = "LessEqual"; + var LinSpace = "LinSpace"; + var Log = "Log"; + var Log1p = "Log1p"; + var LogicalAnd = "LogicalAnd"; + var LogicalNot = "LogicalNot"; + var LogicalOr = "LogicalOr"; + var LogSoftmax = "LogSoftmax"; + var LRN = "LRN"; + var LRNBackprop = "LRNBackprop"; + var Max = "Max"; + var Maximum = "Maximum"; + var MaxPool = "MaxPool"; + var MaxPoolBackprop = "MaxPoolBackprop"; + var MaxPool3D = "MaxPool3D"; + var MaxPool3DBackprop = "MaxPool3DBackprop"; + var MaxPoolWithArgmax = "MaxPoolWithArgmax"; + var Mean = "Mean"; + var Min = "Min"; + var Minimum = "Minimum"; + var MirrorPad = "MirrorPad"; + var Mod = "Mod"; + var Multiply = "Multiply"; + var Negate = "Negate"; + var NotEqual = "NotEqual"; + var NonMaxSuppressionV3 = "NonMaxSuppressionV3"; + var NonMaxSuppressionV4 = "NonMaxSuppressionV4"; + var NonMaxSuppressionV5 = "NonMaxSuppressionV5"; + var OnesLike = "OnesLike"; + var OneHot = "OneHot"; + var PadV2 = "PadV2"; + var Pool = "Pool"; + var Pow = "Pow"; + var Prelu = "Prelu"; + var Prod = "Prod"; + var Range = "Range"; + var Real = "Real"; + var Reciprocal = "Reciprocal"; + var Relu = "Relu"; + var Reshape = "Reshape"; + var ResizeNearestNeighbor = "ResizeNearestNeighbor"; + var ResizeNearestNeighborGrad = "ResizeNearestNeighborGrad"; + var ResizeBilinear = "ResizeBilinear"; + var ResizeBilinearGrad = "ResizeBilinearGrad"; + var Relu6 = "Relu6"; + var Reverse = "Reverse"; + var Round = "Round"; + var Rsqrt = "Rsqrt"; + var ScatterNd = "ScatterNd"; + var SelectV2 = "SelectV2"; + var Selu = "Selu"; + var Slice = "Slice"; + var Sin = "Sin"; + var Sinh = "Sinh"; + var Sign = "Sign"; + var Sigmoid = "Sigmoid"; + var Softplus = "Softplus"; + var Sqrt = "Sqrt"; + var Sum = "Sum"; + var SpaceToBatchND = "SpaceToBatchND"; + var SplitV = "SplitV"; + var Softmax = "Softmax"; + var SquaredDifference = "SquaredDifference"; + var Square = "Square"; + var Sub = "Sub"; + var SparseToDense = "SparseToDense"; + var StridedSlice = "StridedSlice"; + var Tan = "Tan"; + var Tanh = "Tanh"; + var Tile = "Tile"; + var TopK = "TopK"; + var Transpose = "Transpose"; + var Unique = "Unique"; + var Unpack = "Unpack"; + var UnsortedSegmentSum = "UnsortedSegmentSum"; + var ZerosLike = "ZerosLike"; + var Step = "Step"; + var FromPixels = "FromPixels"; + var RotateWithOffset = "RotateWithOffset"; + var _FusedMatMul = "_FusedMatMul"; + var FusedConv2D = "FusedConv2D"; + var FusedDepthwiseConv2D = "FusedDepthwiseConv2D"; + /** + * @license + * Copyright 2019 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + var kernelRegistry = getGlobal("kernelRegistry", function() { return new Map(); }); - function uu(n, t) { - var e = lu(n, t); - return Si.get(e); + var gradRegistry = getGlobal("gradRegistry", function() { + return new Map(); + }); + function getKernel(kernelName, backendName) { + var key = makeKey(kernelName, backendName); + return kernelRegistry.get(key); } - function hu(n) { - return va.get(n); + function getGradient(kernelName) { + return gradRegistry.get(kernelName); } - function Os(n) { - for (var t = Si.entries(), e = []; ; ) { - var r = t.next(), i = r.done, a = r.value; - if (i) + function getKernelsForBackend(backendName) { + var it = kernelRegistry.entries(); + var result = []; + while (true) { + var _a = it.next(), done = _a.done, value = _a.value; + if (done) { break; - var s = a[0], o = a[1], c = s.split("_")[0]; - c === n && e.push(o); + } + var key = value[0], config = value[1]; + var backend2 = key.split("_")[0]; + if (backend2 === backendName) { + result.push(config); + } } - return e; + return result; } - function Lm(n) { - var t = n.kernelName, e = n.backendName, r = lu(t, e); - Si.has(r) && console.warn("The kernel '" + t + "' for backend " + ("'" + e + "' is already registered")), Si.set(r, n); + function registerKernel(config) { + var kernelName = config.kernelName, backendName = config.backendName; + var key = makeKey(kernelName, backendName); + if (kernelRegistry.has(key)) { + console.warn("The kernel '" + kernelName + "' for backend " + ("'" + backendName + "' is already registered")); + } + kernelRegistry.set(key, config); } - function Sm(n) { - var t = n.kernelName; - va.has(t) && (Ge().getBool("DEBUG") && console.warn("Overriding the gradient for '" + t + "'")), va.set(t, n); + function registerGradient(config) { + var kernelName = config.kernelName; + if (gradRegistry.has(kernelName)) { + if (env().getBool("DEBUG")) { + console.warn("Overriding the gradient for '" + kernelName + "'"); + } + } + gradRegistry.set(kernelName, config); } - function DL(n, t) { - var e = lu(n, t); - if (!Si.has(e)) - throw new Error("The kernel '" + n + "' for backend " + ("'" + t + "' is not registered")); - Si.delete(e); + function unregisterKernel(kernelName, backendName) { + var key = makeKey(kernelName, backendName); + if (!kernelRegistry.has(key)) { + throw new Error("The kernel '" + kernelName + "' for backend " + ("'" + backendName + "' is not registered")); + } + kernelRegistry.delete(key); } - function kL(n) { - if (!va.has(n)) - throw new Error("The gradient '" + n + "' for backend is not registered"); - va.delete(n); + function unregisterGradient(kernelName) { + if (!gradRegistry.has(kernelName)) { + throw new Error("The gradient '" + kernelName + "' for backend is not registered"); + } + gradRegistry.delete(kernelName); } - function FL(n, t) { - var e = Os(n); - e.forEach(function(r) { - var i = Object.assign({}, r, {backendName: t}); - Lm(i); + function copyRegisteredKernels(registeredBackendName, newBackendName) { + var kernels = getKernelsForBackend(registeredBackendName); + kernels.forEach(function(kernelConfig) { + var newKernelConfig = Object.assign({}, kernelConfig, {backendName: newBackendName}); + registerKernel(newKernelConfig); }); } - function lu(n, t) { - return t + "_" + n; + function makeKey(kernelName, backendName) { + return backendName + "_" + kernelName; } - function WL(n, t) { - return t === "string" ? du(n) : Es([n], t); - } - function UL(n, t) { - return n instanceof Float32Array && t === "float32" || n instanceof Int32Array && t === "int32" || n instanceof Uint8Array && t === "bool"; - } - function Es(n, t) { - if (t === "string") - throw new Error("Cannot convert a string[] to a TypedArray"); - if (Array.isArray(n) && (n = Br(n)), Ge().getBool("DEBUG") && pf(n, t), UL(n, t)) - return n; - if (t == null || t === "float32" || t === "complex64") - return new Float32Array(n); - if (t === "int32") - return new Int32Array(n); - if (t === "bool") { - for (var e = new Uint8Array(n.length), r = 0; r < e.length; ++r) - Math.round(n[r]) !== 0 && (e[r] = 1); - return e; - } else - throw new Error("Unknown data type " + t); - } - function pu() { - return Ge().platform.now(); - } - function BL(n, t) { - return Ge().platform.fetch(n, t); - } - function du(n, t) { - return t === void 0 && (t = "utf-8"), t = t || "utf-8", Ge().platform.encode(n, t); - } - function fu(n, t) { - return t === void 0 && (t = "utf-8"), t = t || "utf-8", Ge().platform.decode(n, t); - } - var zL = {__proto__: null, createScalarValue: WL, toTypedArray: Es, now: pu, fetch: BL, encodeString: du, decodeString: fu, shuffle: lf, clamp: ga, nearestLargerEven: mL, sum: gL, randUniform: yL, distSquared: vL, assert: E, assertShapesMatch: Pe, assertNonNull: Ur, flatten: Br, sizeFromShape: pt, isScalarShape: wL, arraysEqual: pn, isInt: ot, tanh: bL, sizeToSquarishShape: xL, createShuffledIndices: LL, rightPad: ya, repeatedTry: SL, inferFromImplicitShape: uf, parseAxisParam: rt, squeezeShape: hf, getTypedArrayFromDType: Ts, getArrayFromDType: df, checkConversionForErrors: pf, isValidDtype: ff, hasEncodingLoss: mf, isTypedArray: Ft, bytesPerElement: gf, bytesFromStringArray: yf, isString: or, isBoolean: vf, isNumber: wf, inferDtype: Ns, isFunction: cr, nearestDivisor: _s, computeStrides: bi, toNestedArray: xi, makeOnesTypedArray: yc, makeZerosTypedArray: Li, makeZerosNestedTypedArray: IL, assertNonNegativeIntegerDimensions: vc, locToIndex: AL, indexToLoc: TL, isPromise: wc}; - var HL = function() { - function n(t, e) { - this.backendTimer = t, this.logger = e, e == null && (this.logger = new ML()); + /** + * @license + * Copyright 2017 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + function createScalarValue(value, dtype) { + if (dtype === "string") { + return encodeString(value); } - return n.prototype.profileKernel = function(t, e, r) { - for (var i, a = function() { - i = r(); - }, s = this.backendTimer.time(a), o = function(u) { - var h = i[u]; - h.data().then(function(d) { - PL(d, h.dtype, t); + return toTypedArray([value], dtype); + } + function noConversionNeeded(a, dtype) { + return a instanceof Float32Array && dtype === "float32" || a instanceof Int32Array && dtype === "int32" || a instanceof Uint8Array && dtype === "bool"; + } + function toTypedArray(a, dtype) { + if (dtype === "string") { + throw new Error("Cannot convert a string[] to a TypedArray"); + } + if (Array.isArray(a)) { + a = flatten(a); + } + if (env().getBool("DEBUG")) { + checkConversionForErrors(a, dtype); + } + if (noConversionNeeded(a, dtype)) { + return a; + } + if (dtype == null || dtype === "float32" || dtype === "complex64") { + return new Float32Array(a); + } else if (dtype === "int32") { + return new Int32Array(a); + } else if (dtype === "bool") { + var bool = new Uint8Array(a.length); + for (var i = 0; i < bool.length; ++i) { + if (Math.round(a[i]) !== 0) { + bool[i] = 1; + } + } + return bool; + } else { + throw new Error("Unknown data type " + dtype); + } + } + function now2() { + return env().platform.now(); + } + function fetch$1(path, requestInits) { + return env().platform.fetch(path, requestInits); + } + function encodeString(s, encoding) { + if (encoding === void 0) { + encoding = "utf-8"; + } + encoding = encoding || "utf-8"; + return env().platform.encode(s, encoding); + } + function decodeString(bytes, encoding) { + if (encoding === void 0) { + encoding = "utf-8"; + } + encoding = encoding || "utf-8"; + return env().platform.decode(bytes, encoding); + } + var util = { + __proto__: null, + createScalarValue, + toTypedArray, + now: now2, + fetch: fetch$1, + encodeString, + decodeString, + shuffle, + clamp, + nearestLargerEven, + sum, + randUniform, + distSquared, + assert, + assertShapesMatch, + assertNonNull, + flatten, + sizeFromShape, + isScalarShape, + arraysEqual, + isInt, + tanh, + sizeToSquarishShape, + createShuffledIndices, + rightPad, + repeatedTry, + inferFromImplicitShape, + parseAxisParam, + squeezeShape, + getTypedArrayFromDType, + getArrayFromDType, + checkConversionForErrors, + isValidDtype, + hasEncodingLoss, + isTypedArray, + bytesPerElement, + bytesFromStringArray, + isString, + isBoolean, + isNumber, + inferDtype, + isFunction, + nearestDivisor, + computeStrides, + toNestedArray, + makeOnesTypedArray, + makeZerosTypedArray, + makeZerosNestedTypedArray, + assertNonNegativeIntegerDimensions, + locToIndex, + indexToLoc, + isPromise + }; + /** + * @license + * Copyright 2018 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + var Profiler = function() { + function Profiler2(backendTimer, logger) { + this.backendTimer = backendTimer; + this.logger = logger; + if (logger == null) { + this.logger = new Logger(); + } + } + Profiler2.prototype.profileKernel = function(kernelName, inputs, f) { + var outputs; + var holdResultWrapperFn = function() { + outputs = f(); + }; + var timer = this.backendTimer.time(holdResultWrapperFn); + var _loop_1 = function(i2) { + var output = outputs[i2]; + output.data().then(function(tensorVals) { + checkComputationForErrors(tensorVals, output.dtype, kernelName); }); - }, c = 0; c < i.length; c++) - o(c); - var l = {kernelName: t, outputs: i, inputs: e, timeMs: s.then(function(u) { - return u.kernelMs; - }), extraInfo: s.then(function(u) { - return u.getExtraProfileInfo != null ? u.getExtraProfileInfo() : ""; - })}; - return l; - }, n.prototype.logKernelProfile = function(t) { - var e = this, r = t.kernelName, i = t.outputs, a = t.timeMs, s = t.inputs, o = t.extraInfo; - i.forEach(function(c) { - Promise.all([c.data(), a, o]).then(function(l) { - e.logger.logKernelProfile(r, c, l[0], l[1], s, l[2]); + }; + for (var i = 0; i < outputs.length; i++) { + _loop_1(i); + } + var kernelProfile = { + kernelName, + outputs, + inputs, + timeMs: timer.then(function(timing) { + return timing.kernelMs; + }), + extraInfo: timer.then(function(timing) { + return timing.getExtraProfileInfo != null ? timing.getExtraProfileInfo() : ""; + }) + }; + return kernelProfile; + }; + Profiler2.prototype.logKernelProfile = function(kernelProfile) { + var _this = this; + var kernelName = kernelProfile.kernelName, outputs = kernelProfile.outputs, timeMs = kernelProfile.timeMs, inputs = kernelProfile.inputs, extraInfo = kernelProfile.extraInfo; + outputs.forEach(function(result) { + Promise.all([result.data(), timeMs, extraInfo]).then(function(valueContainer) { + _this.logger.logKernelProfile(kernelName, result, valueContainer[0], valueContainer[1], inputs, valueContainer[2]); }); }); - }, n; + }; + return Profiler2; }(); - function PL(n, t, e) { - if (t !== "float32") + function checkComputationForErrors(vals, dtype, kernelName) { + if (dtype !== "float32") { return false; - for (var r = 0; r < n.length; r++) { - var i = n[r]; - if (isNaN(i) || !isFinite(i)) - return console.warn("Found " + i + " in the result of '" + e + "'"), true; + } + for (var i = 0; i < vals.length; i++) { + var num = vals[i]; + if (isNaN(num) || !isFinite(num)) { + console.warn("Found " + num + " in the result of '" + kernelName + "'"); + return true; + } } return false; } - var ML = function() { - function n() { + var Logger = function() { + function Logger2() { } - return n.prototype.logKernelProfile = function(t, e, r, i, a, s) { - var o = typeof i == "number" ? ya(i + "ms", 9) : i.error, c = ya(t, 25), l = e.rank, u = e.size, h = ya(e.shape.toString(), 14), d = ""; - for (var p in a) { - var f = a[p]; - if (f != null) { - var m = f.shape || e.shape, g = m.length; - d += p + ": " + g + "D " + (g > 0 ? m : "") + " "; + Logger2.prototype.logKernelProfile = function(name, result, vals, timeMs, inputs, extraInfo) { + var time2 = typeof timeMs === "number" ? rightPad(timeMs + "ms", 9) : timeMs["error"]; + var paddedName = rightPad(name, 25); + var rank = result.rank; + var size = result.size; + var shape = rightPad(result.shape.toString(), 14); + var inputShapesDescription = ""; + for (var name_1 in inputs) { + var input = inputs[name_1]; + if (input != null) { + var inputShape = input.shape || result.shape; + var inputRank = inputShape.length; + inputShapesDescription += name_1 + ": " + inputRank + "D " + (inputRank > 0 ? inputShape : "") + " "; } } - console.log("%c" + c + " %c" + o + " %c" + l + "D " + h + " %c" + u + " %c" + d + " %c" + s, "font-weight:bold", "color:red", "color:blue", "color: orange", "color: green", "color: steelblue"); - }, n; + console.log("%c" + paddedName + " %c" + time2 + " %c" + rank + "D " + shape + " %c" + size + " %c" + inputShapesDescription + " %c" + extraInfo, "font-weight:bold", "color:red", "color:blue", "color: orange", "color: green", "color: steelblue"); + }; + return Logger2; }(); - function VL(n, t, e) { - for (var r = {}, i = {}, a = 0; a < t.length; a++) - r[t[a].id] = true; - for (var a = 0; a < n.length; a++) { - var s = n[a], o = s.inputs; - for (var c in o) { - for (var l = o[c], u = false, h = 0; h < t.length; h++) - if (r[l.id]) { - s.outputs.forEach(function(b) { - return r[b.id] = true; - }), u = true, i[s.id] = true; + /** + * @license + * Copyright 2017 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + function getFilteredNodesXToY(tape, xs, y) { + var tensorsFromX = {}; + var nodesFromX = {}; + for (var i = 0; i < xs.length; i++) { + tensorsFromX[xs[i].id] = true; + } + for (var i = 0; i < tape.length; i++) { + var node = tape[i]; + var nodeInputs = node.inputs; + for (var inputName in nodeInputs) { + var input = nodeInputs[inputName]; + var anyInputFromX = false; + for (var j = 0; j < xs.length; j++) { + if (tensorsFromX[input.id]) { + node.outputs.forEach(function(output) { + return tensorsFromX[output.id] = true; + }); + anyInputFromX = true; + nodesFromX[node.id] = true; break; } - if (u) - break; - } - } - var d = {}; - d[e.id] = true; - for (var p = {}, a = n.length - 1; a >= 0; a--) - for (var s = n[a], o = s.inputs, h = 0; h < s.outputs.length; h++) - if (d[s.outputs[h].id]) { - for (var c in o) - d[o[c].id] = true, p[s.id] = true; + } + if (anyInputFromX) { break; } - for (var f = [], a = 0; a < n.length; a++) { - var s = n[a]; - if (i[s.id] && p[s.id]) { - var m = {}; - for (var c in s.inputs) { - var g = s.inputs[c]; - r[g.id] && (m[c] = g); - } - var y = Object.assign({}, s); - y.inputs = m, y.outputs = s.outputs, f.push(y); } } - return f; + var tensorsLeadToY = {}; + tensorsLeadToY[y.id] = true; + var nodesToY = {}; + for (var i = tape.length - 1; i >= 0; i--) { + var node = tape[i]; + var nodeInputs = node.inputs; + for (var j = 0; j < node.outputs.length; j++) { + if (tensorsLeadToY[node.outputs[j].id]) { + for (var inputName in nodeInputs) { + tensorsLeadToY[nodeInputs[inputName].id] = true; + nodesToY[node.id] = true; + } + break; + } + } + } + var filteredTape = []; + for (var i = 0; i < tape.length; i++) { + var node = tape[i]; + if (nodesFromX[node.id] && nodesToY[node.id]) { + var prunedInputs = {}; + for (var inputName in node.inputs) { + var nodeInput = node.inputs[inputName]; + if (tensorsFromX[nodeInput.id]) { + prunedInputs[inputName] = nodeInput; + } + } + var prunedNode = Object.assign({}, node); + prunedNode.inputs = prunedInputs; + prunedNode.outputs = node.outputs; + filteredTape.push(prunedNode); + } + } + return filteredTape; } - function GL(n, t, e, r) { - for (var i = function(s) { - var o = t[s], c = []; - if (o.outputs.forEach(function(d) { - var p = n[d.id]; - p != null ? c.push(p) : c.push(null); - }), o.gradient == null) - throw new Error("Cannot compute gradient: gradient function not found " + ("for " + o.kernelName + ".")); - var l = o.gradient(c), u = function(d) { - if (!(d in l)) - throw new Error("Cannot backprop through input " + d + ". " + ("Available gradients found: " + Object.keys(l) + ".")); - var p = e(function() { - return l[d](); + function backpropagateGradients(tensorAccumulatedGradientMap, filteredTape, tidy2, add2) { + var _loop_1 = function(i2) { + var node = filteredTape[i2]; + var dys = []; + node.outputs.forEach(function(o) { + var gradTensor = tensorAccumulatedGradientMap[o.id]; + if (gradTensor != null) { + dys.push(gradTensor); + } else { + dys.push(null); + } + }); + if (node.gradient == null) { + throw new Error("Cannot compute gradient: gradient function not found " + ("for " + node.kernelName + ".")); + } + var inputGradients = node.gradient(dys); + var _loop_2 = function(inputName2) { + if (!(inputName2 in inputGradients)) { + throw new Error("Cannot backprop through input " + inputName2 + ". " + ("Available gradients found: " + Object.keys(inputGradients) + ".")); + } + var dx = tidy2(function() { + return inputGradients[inputName2](); }); - if (p.dtype !== "float32") - throw new Error("Error in gradient for op " + o.kernelName + ". The gradient of input " + (d + " must have 'float32' dtype, but has '" + p.dtype + "'")); - var f = o.inputs[d]; - if (!pn(p.shape, f.shape)) - throw new Error("Error in gradient for op " + o.kernelName + ". The gradient of input " + ("'" + d + "' has shape '" + p.shape + "', which does not match ") + ("the shape of the input '" + f.shape + "'")); - if (n[f.id] == null) - n[f.id] = p; - else { - var m = n[f.id]; - n[f.id] = r(m, p), m.dispose(); + if (dx.dtype !== "float32") { + throw new Error("Error in gradient for op " + node.kernelName + ". The gradient of input " + (inputName2 + " must have 'float32' dtype, but has '" + dx.dtype + "'")); + } + var x = node.inputs[inputName2]; + if (!arraysEqual(dx.shape, x.shape)) { + throw new Error("Error in gradient for op " + node.kernelName + ". The gradient of input " + ("'" + inputName2 + "' has shape '" + dx.shape + "', which does not match ") + ("the shape of the input '" + x.shape + "'")); + } + if (tensorAccumulatedGradientMap[x.id] == null) { + tensorAccumulatedGradientMap[x.id] = dx; + } else { + var curGradient = tensorAccumulatedGradientMap[x.id]; + tensorAccumulatedGradientMap[x.id] = add2(curGradient, dx); + curGradient.dispose(); } }; - for (var h in o.inputs) - u(h); - }, a = t.length - 1; a >= 0; a--) - i(a); + for (var inputName in node.inputs) { + _loop_2(inputName); + } + }; + for (var i = filteredTape.length - 1; i >= 0; i--) { + _loop_1(i); + } } - var Im = 20, wa = 3, mu = 7; - function YL(n, t, e, r) { - var i = bi(t), a = qL(n, t, e, i), s = t.length, o = Ds(n, t, e, i, a), c = ["Tensor"]; - return r && (c.push(" dtype: " + e), c.push(" rank: " + s), c.push(" shape: [" + t + "]"), c.push(" values:")), c.push(o.map(function(l) { + /** + * @license + * Copyright 2018 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + var FORMAT_LIMIT_NUM_VALS = 20; + var FORMAT_NUM_FIRST_LAST_VALS = 3; + var FORMAT_NUM_SIG_DIGITS = 7; + function tensorToString(vals, shape, dtype, verbose) { + var strides = computeStrides(shape); + var padPerCol = computeMaxSizePerColumn(vals, shape, dtype, strides); + var rank = shape.length; + var valsLines = subTensorToString(vals, shape, dtype, strides, padPerCol); + var lines = ["Tensor"]; + if (verbose) { + lines.push(" dtype: " + dtype); + lines.push(" rank: " + rank); + lines.push(" shape: [" + shape + "]"); + lines.push(" values:"); + } + lines.push(valsLines.map(function(l) { return " " + l; - }).join(` -`)), c.join(` -`); + }).join("\n")); + return lines.join("\n"); } - function qL(n, t, e, r) { - var i = pt(t), a = r[r.length - 1], s = new Array(a).fill(0), o = t.length, c = e === "complex64" ? xa(n) : n; - if (o > 1) - for (var l = 0; l < i / a; l++) - for (var u = l * a, h = 0; h < a; h++) - s[h] = Math.max(s[h], ba(c[u + h], 0, e).length); - return s; - } - function ba(n, t, e) { - var r; - return Array.isArray(n) ? r = parseFloat(n[0].toFixed(mu)) + " + " + (parseFloat(n[1].toFixed(mu)) + "j") : or(n) ? r = "'" + n + "'" : e === "bool" ? r = Am(n) : r = parseFloat(n.toFixed(mu)).toString(), ya(r, t); - } - function Am(n) { - return n === 0 ? "false" : "true"; - } - function Ds(n, t, e, r, i, a) { - a === void 0 && (a = true); - var s = e === "complex64" ? 2 : 1, o = t[0], c = t.length; - if (c === 0) { - if (e === "complex64") { - var l = xa(n); - return [ba(l[0], 0, e)]; - } - return e === "bool" ? [Am(n[0])] : [n[0].toString()]; - } - if (c === 1) { - if (o > Im) { - var u = wa * s, h = Array.from(n.slice(0, u)), d = Array.from(n.slice((o - wa) * s, o * s)); - return e === "complex64" && (h = xa(h), d = xa(d)), ["[" + h.map(function(I, C) { - return ba(I, i[C], e); - }).join(", ") + ", ..., " + d.map(function(I, C) { - return ba(I, i[o - wa + C], e); - }).join(", ") + "]"]; - } - var p = e === "complex64" ? xa(n) : Array.from(n); - return ["[" + p.map(function(I, C) { - return ba(I, i[C], e); - }).join(", ") + "]"]; - } - var f = t.slice(1), m = r.slice(1), g = r[0] * s, y = []; - if (o > Im) { - for (var w = 0; w < wa; w++) { - var b = w * g, L = b + g; - y.push.apply(y, Ds(n.slice(b, L), f, e, m, i, false)); - } - y.push("..."); - for (var w = o - wa; w < o; w++) { - var b = w * g, L = b + g; - y.push.apply(y, Ds(n.slice(b, L), f, e, m, i, w === o - 1)); - } - } else - for (var w = 0; w < o; w++) { - var b = w * g, L = b + g; - y.push.apply(y, Ds(n.slice(b, L), f, e, m, i, w === o - 1)); - } - var x = c === 2 ? "," : ""; - y[0] = "[" + y[0] + x; - for (var w = 1; w < y.length - 1; w++) - y[w] = " " + y[w] + x; - for (var N = `, -`, w = 2; w < c; w++) - N += ` -`; - return y[y.length - 1] = " " + y[y.length - 1] + "]" + (a ? "" : N), y; - } - function xa(n) { - for (var t = [], e = 0; e < n.length; e += 2) - t.push([n[e], n[e + 1]]); - return t; - } - var ks = function() { - function n(t, e, r) { - var i = this; - if (this.dtype = e, this.shape = t.slice(), this.size = pt(t), r != null) { - var a = r.length; - E(a === this.size, function() { - return "Length of values '" + a + "' does not match the size " + ("inferred by the shape '" + i.size + "'."); - }); - } - if (e === "complex64") - throw new Error("complex64 dtype TensorBuffers are not supported. Please create a TensorBuffer for the real and imaginary parts separately and call tf.complex(real, imag)."); - this.values = r || df(e, this.size), this.strides = bi(t); - } - return n.prototype.set = function(t) { - for (var e = this, r = [], i = 1; i < arguments.length; i++) - r[i - 1] = arguments[i]; - r.length === 0 && (r = [0]), E(r.length === this.rank, function() { - return "The number of provided coordinates (" + r.length + ") must " + ("match the rank (" + e.rank + ")"); - }); - var a = this.locToIndex(r); - this.values[a] = t; - }, n.prototype.get = function() { - for (var t = [], e = 0; e < arguments.length; e++) - t[e] = arguments[e]; - t.length === 0 && (t = [0]); - for (var r = 0, i = 0, a = t; i < a.length; i++) { - var s = a[i]; - if (s < 0 || s >= this.shape[r]) { - var o = "Requested out of range element at " + t + ". " + (" Buffer shape=" + this.shape); - throw new Error(o); + function computeMaxSizePerColumn(vals, shape, dtype, strides) { + var n = sizeFromShape(shape); + var numCols = strides[strides.length - 1]; + var padPerCol = new Array(numCols).fill(0); + var rank = shape.length; + var valuesOrTuples = dtype === "complex64" ? createComplexTuples(vals) : vals; + if (rank > 1) { + for (var row = 0; row < n / numCols; row++) { + var offset = row * numCols; + for (var j = 0; j < numCols; j++) { + padPerCol[j] = Math.max(padPerCol[j], valToString(valuesOrTuples[offset + j], 0, dtype).length); } - r++; } - for (var c = t[t.length - 1], l = 0; l < t.length - 1; ++l) - c += this.strides[l] * t[l]; - return this.values[c]; - }, n.prototype.locToIndex = function(t) { - if (this.rank === 0) - return 0; - if (this.rank === 1) - return t[0]; - for (var e = t[t.length - 1], r = 0; r < t.length - 1; ++r) - e += this.strides[r] * t[r]; - return e; - }, n.prototype.indexToLoc = function(t) { - if (this.rank === 0) - return []; - if (this.rank === 1) - return [t]; - for (var e = new Array(this.shape.length), r = 0; r < e.length - 1; ++r) - e[r] = Math.floor(t / this.strides[r]), t -= e[r] * this.strides[r]; - return e[e.length - 1] = t, e; - }, Object.defineProperty(n.prototype, "rank", {get: function() { - return this.shape.length; - }, enumerable: true, configurable: true}), n.prototype.toTensor = function() { - return Cn().makeTensor(this.values, this.shape, this.dtype); - }, n; - }(), Cn = null, Ii = null; - function KL(n) { - Cn = n; - } - function jL(n) { - Ii = n; - } - var K = function() { - function n(t, e, r, i) { - this.kept = false, this.isDisposedInternal = false, this.shape = t.slice(), this.dtype = e || "float32", this.size = pt(t), this.strides = bi(t), this.dataId = r, this.id = i, this.rankType = this.rank < 5 ? this.rank.toString() : "higher"; } - return Object.defineProperty(n.prototype, "rank", {get: function() { - return this.shape.length; - }, enumerable: true, configurable: true}), n.prototype.buffer = function() { - return pe(this, void 0, void 0, function() { - var t; - return fe(this, function(e) { - switch (e.label) { + return padPerCol; + } + function valToString(val, pad2, dtype) { + var valStr; + if (Array.isArray(val)) { + valStr = parseFloat(val[0].toFixed(FORMAT_NUM_SIG_DIGITS)) + " + " + (parseFloat(val[1].toFixed(FORMAT_NUM_SIG_DIGITS)) + "j"); + } else if (isString(val)) { + valStr = "'" + val + "'"; + } else if (dtype === "bool") { + valStr = boolNumToString(val); + } else { + valStr = parseFloat(val.toFixed(FORMAT_NUM_SIG_DIGITS)).toString(); + } + return rightPad(valStr, pad2); + } + function boolNumToString(v) { + return v === 0 ? "false" : "true"; + } + function subTensorToString(vals, shape, dtype, strides, padPerCol, isLast) { + if (isLast === void 0) { + isLast = true; + } + var storagePerElement = dtype === "complex64" ? 2 : 1; + var size = shape[0]; + var rank = shape.length; + if (rank === 0) { + if (dtype === "complex64") { + var complexTuple = createComplexTuples(vals); + return [valToString(complexTuple[0], 0, dtype)]; + } + if (dtype === "bool") { + return [boolNumToString(vals[0])]; + } + return [vals[0].toString()]; + } + if (rank === 1) { + if (size > FORMAT_LIMIT_NUM_VALS) { + var firstValsSize = FORMAT_NUM_FIRST_LAST_VALS * storagePerElement; + var firstVals = Array.from(vals.slice(0, firstValsSize)); + var lastVals = Array.from(vals.slice((size - FORMAT_NUM_FIRST_LAST_VALS) * storagePerElement, size * storagePerElement)); + if (dtype === "complex64") { + firstVals = createComplexTuples(firstVals); + lastVals = createComplexTuples(lastVals); + } + return [ + "[" + firstVals.map(function(x, i2) { + return valToString(x, padPerCol[i2], dtype); + }).join(", ") + ", ..., " + lastVals.map(function(x, i2) { + return valToString(x, padPerCol[size - FORMAT_NUM_FIRST_LAST_VALS + i2], dtype); + }).join(", ") + "]" + ]; + } + var displayVals = dtype === "complex64" ? createComplexTuples(vals) : Array.from(vals); + return [ + "[" + displayVals.map(function(x, i2) { + return valToString(x, padPerCol[i2], dtype); + }).join(", ") + "]" + ]; + } + var subshape = shape.slice(1); + var substrides = strides.slice(1); + var stride = strides[0] * storagePerElement; + var lines = []; + if (size > FORMAT_LIMIT_NUM_VALS) { + for (var i = 0; i < FORMAT_NUM_FIRST_LAST_VALS; i++) { + var start = i * stride; + var end = start + stride; + lines.push.apply(lines, subTensorToString(vals.slice(start, end), subshape, dtype, substrides, padPerCol, false)); + } + lines.push("..."); + for (var i = size - FORMAT_NUM_FIRST_LAST_VALS; i < size; i++) { + var start = i * stride; + var end = start + stride; + lines.push.apply(lines, subTensorToString(vals.slice(start, end), subshape, dtype, substrides, padPerCol, i === size - 1)); + } + } else { + for (var i = 0; i < size; i++) { + var start = i * stride; + var end = start + stride; + lines.push.apply(lines, subTensorToString(vals.slice(start, end), subshape, dtype, substrides, padPerCol, i === size - 1)); + } + } + var sep = rank === 2 ? "," : ""; + lines[0] = "[" + lines[0] + sep; + for (var i = 1; i < lines.length - 1; i++) { + lines[i] = " " + lines[i] + sep; + } + var newLineSep = ",\n"; + for (var i = 2; i < rank; i++) { + newLineSep += "\n"; + } + lines[lines.length - 1] = " " + lines[lines.length - 1] + "]" + (isLast ? "" : newLineSep); + return lines; + } + function createComplexTuples(vals) { + var complexTuples = []; + for (var i = 0; i < vals.length; i += 2) { + complexTuples.push([vals[i], vals[i + 1]]); + } + return complexTuples; + } + /** + * @license + * Copyright 2017 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + var TensorBuffer = function() { + function TensorBuffer2(shape, dtype, values) { + var _this = this; + this.dtype = dtype; + this.shape = shape.slice(); + this.size = sizeFromShape(shape); + if (values != null) { + var n_1 = values.length; + assert(n_1 === this.size, function() { + return "Length of values '" + n_1 + "' does not match the size " + ("inferred by the shape '" + _this.size + "'."); + }); + } + if (dtype === "complex64") { + throw new Error("complex64 dtype TensorBuffers are not supported. Please create a TensorBuffer for the real and imaginary parts separately and call tf.complex(real, imag)."); + } + this.values = values || getArrayFromDType(dtype, this.size); + this.strides = computeStrides(shape); + } + TensorBuffer2.prototype.set = function(value) { + var _this = this; + var locs = []; + for (var _i2 = 1; _i2 < arguments.length; _i2++) { + locs[_i2 - 1] = arguments[_i2]; + } + if (locs.length === 0) { + locs = [0]; + } + assert(locs.length === this.rank, function() { + return "The number of provided coordinates (" + locs.length + ") must " + ("match the rank (" + _this.rank + ")"); + }); + var index = this.locToIndex(locs); + this.values[index] = value; + }; + TensorBuffer2.prototype.get = function() { + var locs = []; + for (var _i2 = 0; _i2 < arguments.length; _i2++) { + locs[_i2] = arguments[_i2]; + } + if (locs.length === 0) { + locs = [0]; + } + var i = 0; + for (var _a = 0, locs_1 = locs; _a < locs_1.length; _a++) { + var loc = locs_1[_a]; + if (loc < 0 || loc >= this.shape[i]) { + var msg = "Requested out of range element at " + locs + ". " + (" Buffer shape=" + this.shape); + throw new Error(msg); + } + i++; + } + var index = locs[locs.length - 1]; + for (var i_1 = 0; i_1 < locs.length - 1; ++i_1) { + index += this.strides[i_1] * locs[i_1]; + } + return this.values[index]; + }; + TensorBuffer2.prototype.locToIndex = function(locs) { + if (this.rank === 0) { + return 0; + } else if (this.rank === 1) { + return locs[0]; + } + var index = locs[locs.length - 1]; + for (var i = 0; i < locs.length - 1; ++i) { + index += this.strides[i] * locs[i]; + } + return index; + }; + TensorBuffer2.prototype.indexToLoc = function(index) { + if (this.rank === 0) { + return []; + } else if (this.rank === 1) { + return [index]; + } + var locs = new Array(this.shape.length); + for (var i = 0; i < locs.length - 1; ++i) { + locs[i] = Math.floor(index / this.strides[i]); + index -= locs[i] * this.strides[i]; + } + locs[locs.length - 1] = index; + return locs; + }; + Object.defineProperty(TensorBuffer2.prototype, "rank", { + get: function() { + return this.shape.length; + }, + enumerable: true, + configurable: true + }); + TensorBuffer2.prototype.toTensor = function() { + return trackerFn().makeTensor(this.values, this.shape, this.dtype); + }; + return TensorBuffer2; + }(); + var trackerFn = null; + var opHandler = null; + function setTensorTracker(fn) { + trackerFn = fn; + } + function setOpHandler(handler) { + opHandler = handler; + } + var Tensor = function() { + function Tensor2(shape, dtype, dataId, id) { + this.kept = false; + this.isDisposedInternal = false; + this.shape = shape.slice(); + this.dtype = dtype || "float32"; + this.size = sizeFromShape(shape); + this.strides = computeStrides(shape); + this.dataId = dataId; + this.id = id; + this.rankType = this.rank < 5 ? this.rank.toString() : "higher"; + } + Object.defineProperty(Tensor2.prototype, "rank", { + get: function() { + return this.shape.length; + }, + enumerable: true, + configurable: true + }); + Tensor2.prototype.buffer = function() { + return __awaiter(this, void 0, void 0, function() { + var vals; + return __generator(this, function(_a) { + switch (_a.label) { case 0: return [4, this.data()]; case 1: - return t = e.sent(), [2, Ii.buffer(this.shape, this.dtype, t)]; + vals = _a.sent(); + return [2, opHandler.buffer(this.shape, this.dtype, vals)]; } }); }); - }, n.prototype.bufferSync = function() { - return Ii.buffer(this.shape, this.dtype, this.dataSync()); - }, n.prototype.array = function() { - return pe(this, void 0, void 0, function() { - var t; - return fe(this, function(e) { - switch (e.label) { + }; + Tensor2.prototype.bufferSync = function() { + return opHandler.buffer(this.shape, this.dtype, this.dataSync()); + }; + Tensor2.prototype.array = function() { + return __awaiter(this, void 0, void 0, function() { + var vals; + return __generator(this, function(_a) { + switch (_a.label) { case 0: return [4, this.data()]; case 1: - return t = e.sent(), [2, xi(this.shape, t)]; + vals = _a.sent(); + return [2, toNestedArray(this.shape, vals)]; } }); }); - }, n.prototype.arraySync = function() { - return xi(this.shape, this.dataSync()); - }, n.prototype.data = function() { - return pe(this, void 0, void 0, function() { - var t, e; - return fe(this, function(r) { - switch (r.label) { + }; + Tensor2.prototype.arraySync = function() { + return toNestedArray(this.shape, this.dataSync()); + }; + Tensor2.prototype.data = function() { + return __awaiter(this, void 0, void 0, function() { + var data, bytes; + return __generator(this, function(_a) { + switch (_a.label) { case 0: - return this.throwIfDisposed(), t = Cn().read(this.dataId), this.dtype === "string" ? [4, t] : [3, 2]; + this.throwIfDisposed(); + data = trackerFn().read(this.dataId); + if (!(this.dtype === "string")) + return [3, 2]; + return [4, data]; case 1: - e = r.sent(); + bytes = _a.sent(); try { - return [2, e.map(function(i) { - return fu(i); + return [2, bytes.map(function(b) { + return decodeString(b); })]; - } catch (i) { + } catch (_b) { throw new Error("Failed to decode the string bytes into utf-8. To get the original bytes, call tensor.bytes()."); } - r.label = 2; + _a.label = 2; case 2: - return [2, t]; + return [2, data]; } }); }); - }, n.prototype.dataSync = function() { + }; + Tensor2.prototype.dataSync = function() { this.throwIfDisposed(); - var t = Cn().readSync(this.dataId); - if (this.dtype === "string") + var data = trackerFn().readSync(this.dataId); + if (this.dtype === "string") { try { - return t.map(function(e) { - return fu(e); + return data.map(function(b) { + return decodeString(b); }); - } catch (e) { + } catch (_a) { throw new Error("Failed to decode the string bytes into utf-8. To get the original bytes, call tensor.bytes()."); } - return t; - }, n.prototype.bytes = function() { - return pe(this, void 0, void 0, function() { - var t; - return fe(this, function(e) { - switch (e.label) { + } + return data; + }; + Tensor2.prototype.bytes = function() { + return __awaiter(this, void 0, void 0, function() { + var data; + return __generator(this, function(_a) { + switch (_a.label) { case 0: - return this.throwIfDisposed(), [4, Cn().read(this.dataId)]; + this.throwIfDisposed(); + return [4, trackerFn().read(this.dataId)]; case 1: - return t = e.sent(), this.dtype === "string" ? [2, t] : [2, new Uint8Array(t.buffer)]; + data = _a.sent(); + if (this.dtype === "string") { + return [2, data]; + } else { + return [2, new Uint8Array(data.buffer)]; + } } }); }); - }, n.prototype.dispose = function() { - if (this.isDisposed) + }; + Tensor2.prototype.dispose = function() { + if (this.isDisposed) { return; - Cn().disposeTensor(this), this.isDisposedInternal = true; - }, Object.defineProperty(n.prototype, "isDisposed", {get: function() { - return this.isDisposedInternal; - }, enumerable: true, configurable: true}), n.prototype.throwIfDisposed = function() { - if (this.isDisposed) + } + trackerFn().disposeTensor(this); + this.isDisposedInternal = true; + }; + Object.defineProperty(Tensor2.prototype, "isDisposed", { + get: function() { + return this.isDisposedInternal; + }, + enumerable: true, + configurable: true + }); + Tensor2.prototype.throwIfDisposed = function() { + if (this.isDisposed) { throw new Error("Tensor is disposed."); - }, n.prototype.print = function(t) { - return t === void 0 && (t = false), Ii.print(this, t); - }, n.prototype.clone = function() { - return this.throwIfDisposed(), Ii.clone(this); - }, n.prototype.toString = function(t) { - t === void 0 && (t = false); - var e = this.dataSync(); - return YL(e, this.shape, this.dtype, t); - }, n.prototype.cast = function(t) { - return this.throwIfDisposed(), Ii.cast(this, t); - }, n.prototype.variable = function(t, e, r) { - return t === void 0 && (t = true), this.throwIfDisposed(), Cn().makeVariable(this, t, e, r); - }, n; + } + }; + Tensor2.prototype.print = function(verbose) { + if (verbose === void 0) { + verbose = false; + } + return opHandler.print(this, verbose); + }; + Tensor2.prototype.clone = function() { + this.throwIfDisposed(); + return opHandler.clone(this); + }; + Tensor2.prototype.toString = function(verbose) { + if (verbose === void 0) { + verbose = false; + } + var vals = this.dataSync(); + return tensorToString(vals, this.shape, this.dtype, verbose); + }; + Tensor2.prototype.cast = function(dtype) { + this.throwIfDisposed(); + return opHandler.cast(this, dtype); + }; + Tensor2.prototype.variable = function(trainable, name, dtype) { + if (trainable === void 0) { + trainable = true; + } + this.throwIfDisposed(); + return trackerFn().makeVariable(this, trainable, name, dtype); + }; + return Tensor2; }(); - Object.defineProperty(K, Symbol.hasInstance, {value: function(n) { - return !!n && n.data != null && n.dataSync != null && n.throwIfDisposed != null; - }}); - var La = function(n) { - qn(t, n); - function t(e, r, i, a) { - var s = n.call(this, e.shape, e.dtype, e.dataId, a) || this; - return s.trainable = r, s.name = i, s; + Object.defineProperty(Tensor, Symbol.hasInstance, { + value: function(instance2) { + return !!instance2 && instance2.data != null && instance2.dataSync != null && instance2.throwIfDisposed != null; } - return t.prototype.assign = function(e) { - if (e.dtype !== this.dtype) - throw new Error("dtype of the new value (" + e.dtype + ") and " + ("previous value (" + this.dtype + ") must match")); - if (!pn(e.shape, this.shape)) - throw new Error("shape of the new value (" + e.shape + ") and " + ("previous value (" + this.shape + ") must match")); - Cn().disposeTensor(this), this.dataId = e.dataId, Cn().incRef(this, null); - }, t.prototype.dispose = function() { - Cn().disposeVariable(this), this.isDisposedInternal = true; - }, t; - }(K); - Object.defineProperty(La, Symbol.hasInstance, {value: function(n) { - return n instanceof K && n.assign != null && n.assign instanceof Function; - }}); - (function(n) { - n.R0 = "R0", n.R1 = "R1", n.R2 = "R2", n.R3 = "R3", n.R4 = "R4", n.R5 = "R5", n.R6 = "R6"; - })(A.Rank || (A.Rank = {})); - var gu; - (function(n) { - n.float32 = "float32", n.int32 = "int32", n.bool = "int32", n.complex64 = "complex64"; - })(gu || (gu = {})); - var yu; - (function(n) { - n.float32 = "float32", n.int32 = "int32", n.bool = "bool", n.complex64 = "complex64"; - })(yu || (yu = {})); - var vu; - (function(n) { - n.float32 = "float32", n.int32 = "float32", n.bool = "float32", n.complex64 = "complex64"; - })(vu || (vu = {})); - var wu; - (function(n) { - n.float32 = "complex64", n.int32 = "complex64", n.bool = "complex64", n.complex64 = "complex64"; - })(wu || (wu = {})); - var $L = {float32: vu, int32: gu, bool: yu, complex64: wu}; - function Fs(n, t) { - if (n === "string" || t === "string") { - if (n === "string" && t === "string") + }); + var Variable = function(_super) { + __extends(Variable2, _super); + function Variable2(initialValue, trainable, name, tensorId) { + var _this = _super.call(this, initialValue.shape, initialValue.dtype, initialValue.dataId, tensorId) || this; + _this.trainable = trainable; + _this.name = name; + return _this; + } + Variable2.prototype.assign = function(newValue) { + if (newValue.dtype !== this.dtype) { + throw new Error("dtype of the new value (" + newValue.dtype + ") and " + ("previous value (" + this.dtype + ") must match")); + } + if (!arraysEqual(newValue.shape, this.shape)) { + throw new Error("shape of the new value (" + newValue.shape + ") and " + ("previous value (" + this.shape + ") must match")); + } + trackerFn().disposeTensor(this); + this.dataId = newValue.dataId; + trackerFn().incRef(this, null); + }; + Variable2.prototype.dispose = function() { + trackerFn().disposeVariable(this); + this.isDisposedInternal = true; + }; + return Variable2; + }(Tensor); + Object.defineProperty(Variable, Symbol.hasInstance, { + value: function(instance2) { + return instance2 instanceof Tensor && instance2.assign != null && instance2.assign instanceof Function; + } + }); + /** + * @license + * Copyright 2017 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + (function(Rank) { + Rank["R0"] = "R0"; + Rank["R1"] = "R1"; + Rank["R2"] = "R2"; + Rank["R3"] = "R3"; + Rank["R4"] = "R4"; + Rank["R5"] = "R5"; + Rank["R6"] = "R6"; + })(exports.Rank || (exports.Rank = {})); + var UpcastInt32AndMap; + (function(UpcastInt32AndMap2) { + UpcastInt32AndMap2["float32"] = "float32"; + UpcastInt32AndMap2["int32"] = "int32"; + UpcastInt32AndMap2["bool"] = "int32"; + UpcastInt32AndMap2["complex64"] = "complex64"; + })(UpcastInt32AndMap || (UpcastInt32AndMap = {})); + var UpcastBoolAndMap; + (function(UpcastBoolAndMap2) { + UpcastBoolAndMap2["float32"] = "float32"; + UpcastBoolAndMap2["int32"] = "int32"; + UpcastBoolAndMap2["bool"] = "bool"; + UpcastBoolAndMap2["complex64"] = "complex64"; + })(UpcastBoolAndMap || (UpcastBoolAndMap = {})); + var UpcastFloat32AndMap; + (function(UpcastFloat32AndMap2) { + UpcastFloat32AndMap2["float32"] = "float32"; + UpcastFloat32AndMap2["int32"] = "float32"; + UpcastFloat32AndMap2["bool"] = "float32"; + UpcastFloat32AndMap2["complex64"] = "complex64"; + })(UpcastFloat32AndMap || (UpcastFloat32AndMap = {})); + var UpcastComplex64AndMap; + (function(UpcastComplex64AndMap2) { + UpcastComplex64AndMap2["float32"] = "complex64"; + UpcastComplex64AndMap2["int32"] = "complex64"; + UpcastComplex64AndMap2["bool"] = "complex64"; + UpcastComplex64AndMap2["complex64"] = "complex64"; + })(UpcastComplex64AndMap || (UpcastComplex64AndMap = {})); + var upcastTypeMap = { + float32: UpcastFloat32AndMap, + int32: UpcastInt32AndMap, + bool: UpcastBoolAndMap, + complex64: UpcastComplex64AndMap + }; + function upcastType(typeA, typeB) { + if (typeA === "string" || typeB === "string") { + if (typeA === "string" && typeB === "string") { return "string"; - throw new Error("Can not upcast " + n + " with " + t); + } + throw new Error("Can not upcast " + typeA + " with " + typeB); } - return $L[n][t]; + return upcastTypeMap[typeA][typeB]; } - function XL(n) { - return Fs(n, "int32"); + function sumOutType(type) { + return upcastType(type, "int32"); } - function ct(n, t) { - if (n.dtype === t.dtype) - return [n, t]; - var e = Fs(n.dtype, t.dtype); - return [n.cast(e), t.cast(e)]; + /** + * @license + * Copyright 2018 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + function makeTypesMatch(a, b) { + if (a.dtype === b.dtype) { + return [a, b]; + } + var dtype = upcastType(a.dtype, b.dtype); + return [a.cast(dtype), b.cast(dtype)]; } - function Tm(n, t) { - E(n.dtype === t.dtype, function() { - return "The dtypes of the first(" + n.dtype + ") and" + (" second(" + t.dtype + ") input must match"); + function assertTypesMatch(a, b) { + assert(a.dtype === b.dtype, function() { + return "The dtypes of the first(" + a.dtype + ") and" + (" second(" + b.dtype + ") input must match"); }); } - function JL(n, t) { - return t.some(function(e) { - return e.id === n.id; + function isTensorInList(tensor2, tensorList) { + return tensorList.some(function(x) { + return x.id === tensor2.id; }); } - function bu(n) { - var t = [], e = new Set(); - return Nm(n, t, e), t; + function getTensorsInContainer(result) { + var list = []; + var seen = new Set(); + walkTensorContainer(result, list, seen); + return list; } - function Nm(n, t, e) { - if (n == null) - return; - if (n instanceof K) { - t.push(n); + function walkTensorContainer(container, list, seen) { + if (container == null) { return; } - if (!ZL(n)) + if (container instanceof Tensor) { + list.push(container); return; - var r = n; - for (var i in r) { - var a = r[i]; - e.has(a) || (e.add(a), Nm(a, t, e)); + } + if (!isIterable(container)) { + return; + } + var iterable = container; + for (var k in iterable) { + var val = iterable[k]; + if (!seen.has(val)) { + seen.add(val); + walkTensorContainer(val, list, seen); + } } } - function ZL(n) { - return Array.isArray(n) || typeof n == "object"; + function isIterable(obj) { + return Array.isArray(obj) || typeof obj === "object"; } - var QL = {__proto__: null, makeTypesMatch: ct, assertTypesMatch: Tm, isTensorInList: JL, getTensorsInContainer: bu}; - var _m = function() { - function n() { - this.registeredVariables = {}, this.nextTapeNodeId = 0, this.numBytes = 0, this.numTensors = 0, this.numStringTensors = 0, this.numDataBuffers = 0, this.gradientDepth = 0, this.kernelDepth = 0, this.scopeStack = [], this.numDataMovesStack = [], this.nextScopeId = 0, this.tensorInfo = new WeakMap(), this.profiling = false, this.activeProfile = {newBytes: 0, newTensors: 0, peakBytes: 0, kernels: [], result: null}; + var tensor_util = { + __proto__: null, + makeTypesMatch, + assertTypesMatch, + isTensorInList, + getTensorsInContainer + }; + /** + * @license + * Copyright 2018 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + var EngineState = function() { + function EngineState2() { + this.registeredVariables = {}; + this.nextTapeNodeId = 0; + this.numBytes = 0; + this.numTensors = 0; + this.numStringTensors = 0; + this.numDataBuffers = 0; + this.gradientDepth = 0; + this.kernelDepth = 0; + this.scopeStack = []; + this.numDataMovesStack = []; + this.nextScopeId = 0; + this.tensorInfo = new WeakMap(); + this.profiling = false; + this.activeProfile = {newBytes: 0, newTensors: 0, peakBytes: 0, kernels: [], result: null}; } - return n.prototype.dispose = function() { - for (var t in this.registeredVariables) - this.registeredVariables[t].dispose(); - }, n; - }(), nS = function() { - function n(t) { - this.ENV = t, this.registry = {}, this.registryFactory = {}, this.pendingBackendInitId = 0, this.state = new _m(); + EngineState2.prototype.dispose = function() { + for (var variableName in this.registeredVariables) { + this.registeredVariables[variableName].dispose(); + } + }; + return EngineState2; + }(); + var Engine = function() { + function Engine2(ENV2) { + this.ENV = ENV2; + this.registry = {}; + this.registryFactory = {}; + this.pendingBackendInitId = 0; + this.state = new EngineState(); } - return n.prototype.ready = function() { - return pe(this, void 0, void 0, function() { - var t, e, r, i; - return fe(this, function(a) { - switch (a.label) { + Engine2.prototype.ready = function() { + return __awaiter(this, void 0, void 0, function() { + var sortedBackends, i, backendName, success; + return __generator(this, function(_a) { + switch (_a.label) { case 0: - if (this.pendingBackendInit != null) + if (this.pendingBackendInit != null) { return [2, this.pendingBackendInit.then(function() { })]; - if (this.backendInstance != null) + } + if (this.backendInstance != null) { return [2]; - t = this.getSortedBackends(), e = 0, a.label = 1; + } + sortedBackends = this.getSortedBackends(); + i = 0; + _a.label = 1; case 1: - return e < t.length ? (r = t[e], [4, this.initializeBackend(r).success]) : [3, 5]; + if (!(i < sortedBackends.length)) + return [3, 5]; + backendName = sortedBackends[i]; + return [4, this.initializeBackend(backendName).success]; case 2: - return i = a.sent(), i ? [4, this.setBackend(r)] : [3, 4]; + success = _a.sent(); + if (!success) + return [3, 4]; + return [4, this.setBackend(backendName)]; case 3: - return a.sent(), [2]; + _a.sent(); + return [2]; case 4: - return e++, [3, 1]; + i++; + return [3, 1]; case 5: throw new Error("Could not initialize any backends, all backend initializations failed."); } }); }); - }, Object.defineProperty(n.prototype, "backend", {get: function() { - if (this.pendingBackendInit != null) - throw new Error("Backend '" + this.backendName + "' has not yet been initialized. Make sure to await tf.ready() or await tf.setBackend() before calling other methods"); - if (this.backendInstance == null) { - var t = this.initializeBackendsAndReturnBest(), e = t.name, r = t.asyncInit; - if (r) - throw new Error("The highest priority backend '" + e + "' has not yet been initialized. Make sure to await tf.ready() or await tf.setBackend() before calling other methods"); - this.setBackend(e); - } - return this.backendInstance; - }, enumerable: true, configurable: true}), n.prototype.backendNames = function() { + }; + Object.defineProperty(Engine2.prototype, "backend", { + get: function() { + if (this.pendingBackendInit != null) { + throw new Error("Backend '" + this.backendName + "' has not yet been initialized. Make sure to await tf.ready() or await tf.setBackend() before calling other methods"); + } + if (this.backendInstance == null) { + var _a = this.initializeBackendsAndReturnBest(), name_1 = _a.name, asyncInit = _a.asyncInit; + if (asyncInit) { + throw new Error("The highest priority backend '" + name_1 + "' has not yet been initialized. Make sure to await tf.ready() or await tf.setBackend() before calling other methods"); + } + this.setBackend(name_1); + } + return this.backendInstance; + }, + enumerable: true, + configurable: true + }); + Engine2.prototype.backendNames = function() { return Object.keys(this.registryFactory); - }, n.prototype.findBackend = function(t) { - if (!(t in this.registry)) - if (t in this.registryFactory) { - var e = this.initializeBackend(t).asyncInit; - if (e) + }; + Engine2.prototype.findBackend = function(backendName) { + if (!(backendName in this.registry)) { + if (backendName in this.registryFactory) { + var asyncInit = this.initializeBackend(backendName).asyncInit; + if (asyncInit) { return null; - } else + } + } else { return null; - return this.registry[t]; - }, n.prototype.findBackendFactory = function(t) { - return t in this.registryFactory ? this.registryFactory[t].factory : null; - }, n.prototype.registerBackend = function(t, e, r) { - return r === void 0 && (r = 1), t in this.registryFactory ? (console.warn(t + " backend was already registered. Reusing existing backend factory."), false) : (this.registryFactory[t] = {factory: e, priority: r}, true); - }, n.prototype.setBackend = function(t) { - return pe(this, void 0, void 0, function() { - var e, r, i, a, s; - return fe(this, function(o) { - switch (o.label) { + } + } + return this.registry[backendName]; + }; + Engine2.prototype.findBackendFactory = function(backendName) { + if (!(backendName in this.registryFactory)) { + return null; + } + return this.registryFactory[backendName].factory; + }; + Engine2.prototype.registerBackend = function(backendName, factory, priority) { + if (priority === void 0) { + priority = 1; + } + if (backendName in this.registryFactory) { + console.warn(backendName + " backend was already registered. Reusing existing backend factory."); + return false; + } + this.registryFactory[backendName] = {factory, priority}; + return true; + }; + Engine2.prototype.setBackend = function(backendName) { + return __awaiter(this, void 0, void 0, function() { + var _a, success, asyncInit, result, _b; + return __generator(this, function(_c) { + switch (_c.label) { case 0: - if (this.registryFactory[t] == null) - throw new Error("Backend name '" + t + "' not found in registry"); - return this.backendName = t, this.registry[t] == null ? (this.backendInstance = null, e = this.initializeBackend(t), r = e.success, i = e.asyncInit, i ? [4, r] : [3, 2]) : [3, 4]; + if (this.registryFactory[backendName] == null) { + throw new Error("Backend name '" + backendName + "' not found in registry"); + } + this.backendName = backendName; + if (!(this.registry[backendName] == null)) + return [3, 4]; + this.backendInstance = null; + _a = this.initializeBackend(backendName), success = _a.success, asyncInit = _a.asyncInit; + if (!asyncInit) + return [3, 2]; + return [4, success]; case 1: - return s = o.sent(), [3, 3]; + _b = _c.sent(); + return [3, 3]; case 2: - s = r, o.label = 3; + _b = success; + _c.label = 3; case 3: - if (a = s, !a) + result = _b; + if (!result) { return [2, false]; - o.label = 4; + } + _c.label = 4; case 4: - return this.backendInstance = this.registry[t], this.setupRegisteredKernels(), this.profiler = new HL(this.backendInstance), [2, true]; + this.backendInstance = this.registry[backendName]; + this.setupRegisteredKernels(); + this.profiler = new Profiler(this.backendInstance); + return [2, true]; } }); }); - }, n.prototype.setupRegisteredKernels = function() { - var t = this, e = Os(this.backendName); - e.forEach(function(r) { - r.setupFunc != null && r.setupFunc(t.backendInstance); + }; + Engine2.prototype.setupRegisteredKernels = function() { + var _this = this; + var kernels = getKernelsForBackend(this.backendName); + kernels.forEach(function(kernel) { + if (kernel.setupFunc != null) { + kernel.setupFunc(_this.backendInstance); + } }); - }, n.prototype.disposeRegisteredKernels = function(t) { - var e = this, r = Os(t); - r.forEach(function(i) { - i.disposeFunc != null && i.disposeFunc(e.registry[t]); + }; + Engine2.prototype.disposeRegisteredKernels = function(backendName) { + var _this = this; + var kernels = getKernelsForBackend(backendName); + kernels.forEach(function(kernel) { + if (kernel.disposeFunc != null) { + kernel.disposeFunc(_this.registry[backendName]); + } }); - }, n.prototype.initializeBackend = function(t) { - var e = this, r = this.registryFactory[t]; - if (r == null) - throw new Error("Cannot initialize backend " + t + ", no registration found."); - try { - var i = r.factory(); - if (i && !(i instanceof cf) && typeof i.then == "function") { - var a = ++this.pendingBackendInitId, s = i.then(function(o) { - return a < e.pendingBackendInitId ? false : (e.registry[t] = o, e.pendingBackendInit = null, true); - }).catch(function(o) { - return a < e.pendingBackendInitId || (e.pendingBackendInit = null, console.warn("Initialization of backend " + t + " failed"), console.warn(o.stack || o.message)), false; - }); - return this.pendingBackendInit = s, {success: s, asyncInit: true}; - } else - return this.registry[t] = i, {success: true, asyncInit: false}; - } catch (o) { - return console.warn("Initialization of backend " + t + " failed"), console.warn(o.stack || o.message), {success: false, asyncInit: false}; + }; + Engine2.prototype.initializeBackend = function(backendName) { + var _this = this; + var registryFactoryEntry = this.registryFactory[backendName]; + if (registryFactoryEntry == null) { + throw new Error("Cannot initialize backend " + backendName + ", no registration found."); } - }, n.prototype.removeBackend = function(t) { - if (!(t in this.registryFactory)) - throw new Error(t + " backend not found in registry"); - this.backendName === t && this.pendingBackendInit != null && this.pendingBackendInitId++, t in this.registry && (this.disposeRegisteredKernels(t), this.registry[t].dispose(), delete this.registry[t]), delete this.registryFactory[t], this.backendName === t && (this.pendingBackendInit = null, this.backendName = null, this.backendInstance = null); - }, n.prototype.getSortedBackends = function() { - var t = this; - if (Object.keys(this.registryFactory).length === 0) + try { + var backend2 = registryFactoryEntry.factory(); + if (backend2 && !(backend2 instanceof KernelBackend) && typeof backend2.then === "function") { + var promiseId_1 = ++this.pendingBackendInitId; + var success = backend2.then(function(backendInstance) { + if (promiseId_1 < _this.pendingBackendInitId) { + return false; + } + _this.registry[backendName] = backendInstance; + _this.pendingBackendInit = null; + return true; + }).catch(function(err) { + if (promiseId_1 < _this.pendingBackendInitId) { + return false; + } + _this.pendingBackendInit = null; + console.warn("Initialization of backend " + backendName + " failed"); + console.warn(err.stack || err.message); + return false; + }); + this.pendingBackendInit = success; + return {success, asyncInit: true}; + } else { + this.registry[backendName] = backend2; + return {success: true, asyncInit: false}; + } + } catch (err) { + console.warn("Initialization of backend " + backendName + " failed"); + console.warn(err.stack || err.message); + return {success: false, asyncInit: false}; + } + }; + Engine2.prototype.removeBackend = function(backendName) { + if (!(backendName in this.registryFactory)) { + throw new Error(backendName + " backend not found in registry"); + } + if (this.backendName === backendName && this.pendingBackendInit != null) { + this.pendingBackendInitId++; + } + if (backendName in this.registry) { + this.disposeRegisteredKernels(backendName); + this.registry[backendName].dispose(); + delete this.registry[backendName]; + } + delete this.registryFactory[backendName]; + if (this.backendName === backendName) { + this.pendingBackendInit = null; + this.backendName = null; + this.backendInstance = null; + } + }; + Engine2.prototype.getSortedBackends = function() { + var _this = this; + if (Object.keys(this.registryFactory).length === 0) { throw new Error("No backend found in registry."); - return Object.keys(this.registryFactory).sort(function(e, r) { - return t.registryFactory[r].priority - t.registryFactory[e].priority; + } + return Object.keys(this.registryFactory).sort(function(a, b) { + return _this.registryFactory[b].priority - _this.registryFactory[a].priority; }); - }, n.prototype.initializeBackendsAndReturnBest = function() { - for (var t = this.getSortedBackends(), e = 0; e < t.length; e++) { - var r = t[e], i = this.initializeBackend(r), a = i.success, s = i.asyncInit; - if (s || a) - return {name: r, asyncInit: s}; + }; + Engine2.prototype.initializeBackendsAndReturnBest = function() { + var sortedBackends = this.getSortedBackends(); + for (var i = 0; i < sortedBackends.length; i++) { + var backendName = sortedBackends[i]; + var _a = this.initializeBackend(backendName), success = _a.success, asyncInit = _a.asyncInit; + if (asyncInit || success) { + return {name: backendName, asyncInit}; + } } throw new Error("Could not initialize any backends, all backend initializations failed."); - }, n.prototype.moveData = function(t, e) { - var r = this.state.tensorInfo.get(e), i = r.backend, a = this.readSync(e); - i.disposeData(e), r.backend = t, t.move(e, a, r.shape, r.dtype), this.shouldCheckForMemLeaks() && this.state.numDataMovesStack[this.state.numDataMovesStack.length - 1]++; - }, n.prototype.tidy = function(t, e) { - var r = this, i = null; - if (e == null) { - if (typeof t != "function") + }; + Engine2.prototype.moveData = function(backend2, dataId) { + var info = this.state.tensorInfo.get(dataId); + var srcBackend = info.backend; + var values = this.readSync(dataId); + srcBackend.disposeData(dataId); + info.backend = backend2; + backend2.move(dataId, values, info.shape, info.dtype); + if (this.shouldCheckForMemLeaks()) { + this.state.numDataMovesStack[this.state.numDataMovesStack.length - 1]++; + } + }; + Engine2.prototype.tidy = function(nameOrFn, fn) { + var _this = this; + var name = null; + if (fn == null) { + if (typeof nameOrFn !== "function") { throw new Error("Please provide a function to tidy()"); - e = t; + } + fn = nameOrFn; } else { - if (typeof t != "string" && !(t instanceof String)) + if (typeof nameOrFn !== "string" && !(nameOrFn instanceof String)) { throw new Error("When calling with two arguments, the first argument to tidy() must be a string"); - if (typeof e != "function") + } + if (typeof fn !== "function") { throw new Error("When calling with two arguments, the 2nd argument to tidy() must be a function"); - i = t; + } + name = nameOrFn; } - var a; + var result; return this.scopedRun(function() { - return r.startScope(i); + return _this.startScope(name); }, function() { - return r.endScope(a); + return _this.endScope(result); }, function() { - return a = e(), a instanceof Promise && console.error("Cannot return a Promise inside of tidy."), a; + result = fn(); + if (result instanceof Promise) { + console.error("Cannot return a Promise inside of tidy."); + } + return result; }); - }, n.prototype.scopedRun = function(t, e, r) { - t(); + }; + Engine2.prototype.scopedRun = function(start, end, f) { + start(); try { - var i = r(); - return e(), i; - } catch (a) { - throw e(), a; + var res = f(); + end(); + return res; + } catch (ex) { + end(); + throw ex; } - }, n.prototype.nextTensorId = function() { - return n.nextTensorId++; - }, n.prototype.nextVariableId = function() { - return n.nextVariableId++; - }, n.prototype.clone = function(t) { - var e = this.makeTensorFromDataId(t.dataId, t.shape, t.dtype), r = {x: t}, i = function(s) { - return {x: function() { - var o = "float32", c = {x: s}, l = {dtype: o}; - return z.runKernelFunc(function(u) { - return u.cast(s, o); - }, c, null, Rs, l); - }}; - }, a = []; - return this.addTapeNode(this.state.activeScope.name, r, [e], i, a, {}), e; - }, n.prototype.runKernel = function(t, e, r, i, a) { - var s = null, o = null; - return this.runKernelFunc(s, e, o, t, r, i, a); - }, n.prototype.shouldCheckForMemLeaks = function() { - return this.ENV.getBool("IS_TEST"); - }, n.prototype.checkKernelForMemLeak = function(t, e, r) { - var i = this.backend.numDataIds(), a = 0; - r.forEach(function(c) { - a += c.dtype === "complex64" ? 3 : 1; - }); - var s = this.state.numDataMovesStack[this.state.numDataMovesStack.length - 1], o = i - e - a - s; - if (o > 0) - throw new Error("Backend '" + this.backendName + "' has an internal memory leak " + ("(" + o + " data ids) after running '" + t + "'")); - }, n.prototype.runKernelFunc = function(t, e, r, i, a, s, o) { - var c = this, l, u = [], h = this.isTapeOn(); - i == null && (i = this.state.activeScope != null ? this.state.activeScope.name : ""); - var d = this.state.numBytes, p = this.state.numTensors; - this.shouldCheckForMemLeaks() && this.state.numDataMovesStack.push(0); - var f, m = uu(i, this.backendName), g; - if (m != null) - f = function() { - var b = c.backend.numDataIds(); - g = m.kernelFunc({inputs: e, attrs: a, backend: c.backend}); - var L = Array.isArray(g) ? g : [g]; - c.shouldCheckForMemLeaks() && c.checkKernelForMemLeak(i, b, L); - var x = L.map(function(C) { - var O = C.dataId, D = C.shape, F = C.dtype; - return c.makeTensorFromDataId(O, D, F); - }); - if (h) { - var N = c.getTensorsForGradient(i, e, x); - if (N == null) { - o == null && (o = []); - var I = x.filter(function(C, O) { - return o[O]; - }); - N = (s || []).slice().concat(I); - } - u = c.saveTensorsForBackwardMode(N); + }; + Engine2.prototype.nextTensorId = function() { + return Engine2.nextTensorId++; + }; + Engine2.prototype.nextVariableId = function() { + return Engine2.nextVariableId++; + }; + Engine2.prototype.clone = function(x) { + var y = this.makeTensorFromDataId(x.dataId, x.shape, x.dtype); + var inputs = {x}; + var grad2 = function(dy) { + return { + x: function() { + var dtype = "float32"; + var gradInputs = {x: dy}; + var attrs = {dtype}; + return ENGINE.runKernelFunc(function(backend2) { + return backend2.cast(dy, dtype); + }, gradInputs, null, Cast, attrs); } - return x; }; - else { - var y = function(b) { - if (!h) + }; + var saved = []; + this.addTapeNode(this.state.activeScope.name, inputs, [y], grad2, saved, {}); + return y; + }; + Engine2.prototype.runKernel = function(kernelName, inputs, attrs, inputsToSave, outputsToSave) { + var forwardFunc = null; + var backwardsFunc = null; + return this.runKernelFunc(forwardFunc, inputs, backwardsFunc, kernelName, attrs, inputsToSave, outputsToSave); + }; + Engine2.prototype.shouldCheckForMemLeaks = function() { + return this.ENV.getBool("IS_TEST"); + }; + Engine2.prototype.checkKernelForMemLeak = function(kernelName, numDataIdsBefore, outInfos) { + var numDataIdsAfter = this.backend.numDataIds(); + var numOutputDataIds = 0; + outInfos.forEach(function(info) { + numOutputDataIds += info.dtype === "complex64" ? 3 : 1; + }); + var numMoves = this.state.numDataMovesStack[this.state.numDataMovesStack.length - 1]; + var dataIdsLeaked = numDataIdsAfter - numDataIdsBefore - numOutputDataIds - numMoves; + if (dataIdsLeaked > 0) { + throw new Error("Backend '" + this.backendName + "' has an internal memory leak " + ("(" + dataIdsLeaked + " data ids) after running '" + kernelName + "'")); + } + }; + Engine2.prototype.runKernelFunc = function(forwardFunc, inputs, backwardsFunc, kernelName, attrs, inputsToSave, outputsToSave) { + var _this = this; + var outputs; + var saved = []; + var isTapeOn = this.isTapeOn(); + if (kernelName == null) { + kernelName = this.state.activeScope != null ? this.state.activeScope.name : ""; + } + var startingBytecount = this.state.numBytes; + var startingNumTensors = this.state.numTensors; + if (this.shouldCheckForMemLeaks()) { + this.state.numDataMovesStack.push(0); + } + var kernelFunc; + var kernel = getKernel(kernelName, this.backendName); + var out; + if (kernel != null) { + kernelFunc = function() { + var numDataIdsBefore = _this.backend.numDataIds(); + out = kernel.kernelFunc({inputs, attrs, backend: _this.backend}); + var outInfos = Array.isArray(out) ? out : [out]; + if (_this.shouldCheckForMemLeaks()) { + _this.checkKernelForMemLeak(kernelName, numDataIdsBefore, outInfos); + } + var outTensors = outInfos.map(function(_a) { + var dataId = _a.dataId, shape = _a.shape, dtype = _a.dtype; + return _this.makeTensorFromDataId(dataId, shape, dtype); + }); + if (isTapeOn) { + var tensorsToSave = _this.getTensorsForGradient(kernelName, inputs, outTensors); + if (tensorsToSave == null) { + if (outputsToSave == null) { + outputsToSave = []; + } + var outsToSave = outTensors.filter(function(_, i) { + return outputsToSave[i]; + }); + tensorsToSave = (inputsToSave || []).slice().concat(outsToSave); + } + saved = _this.saveTensorsForBackwardMode(tensorsToSave); + } + return outTensors; + }; + } else { + var saveFunc_1 = function(tensors) { + if (!isTapeOn) { return; - u = b.map(function(L) { - return c.keep(c.clone(L)); + } + saved = tensors.map(function(tensor2) { + return _this.keep(_this.clone(tensor2)); }); }; - f = function() { - var b = c.backend.numDataIds(); - g = c.tidy(function() { - return t(c.backend, y); + kernelFunc = function() { + var numDataIdsBefore = _this.backend.numDataIds(); + out = _this.tidy(function() { + return forwardFunc(_this.backend, saveFunc_1); }); - var L = Array.isArray(g) ? g : [g]; - return c.shouldCheckForMemLeaks() && c.checkKernelForMemLeak(i, b, L), L; + var outs = Array.isArray(out) ? out : [out]; + if (_this.shouldCheckForMemLeaks()) { + _this.checkKernelForMemLeak(kernelName, numDataIdsBefore, outs); + } + return outs; }; } - var w; - return this.scopedRun(function() { - return c.state.kernelDepth++; + var kernelProfile; + this.scopedRun(function() { + return _this.state.kernelDepth++; }, function() { - return c.state.kernelDepth--; + return _this.state.kernelDepth--; }, function() { - !c.ENV.getBool("DEBUG") && !c.state.profiling ? l = f() : (w = c.profiler.profileKernel(i, e, function() { - return f(); - }), c.ENV.getBool("DEBUG") && c.profiler.logKernelProfile(w), l = w.outputs); - }), h && this.addTapeNode(i, e, l, r, u, a), this.state.profiling && this.state.activeProfile.kernels.push({name: i, bytesAdded: this.state.numBytes - d, totalBytesSnapshot: this.state.numBytes, tensorsAdded: this.state.numTensors - p, totalTensorsSnapshot: this.state.numTensors, inputShapes: Object.keys(e).map(function(b) { - return e[b] != null ? e[b].shape : null; - }), outputShapes: l.map(function(b) { - return b.shape; - }), kernelTimeMs: w.timeMs, extraInfo: w.extraInfo}), Array.isArray(g) ? l : l[0]; - }, n.prototype.saveTensorsForBackwardMode = function(t) { - var e = this, r = t.map(function(i) { - return e.keep(e.clone(i)); + if (!_this.ENV.getBool("DEBUG") && !_this.state.profiling) { + outputs = kernelFunc(); + } else { + kernelProfile = _this.profiler.profileKernel(kernelName, inputs, function() { + return kernelFunc(); + }); + if (_this.ENV.getBool("DEBUG")) { + _this.profiler.logKernelProfile(kernelProfile); + } + outputs = kernelProfile.outputs; + } }); - return r; - }, n.prototype.getTensorsForGradient = function(t, e, r) { - var i = hu(t); - if (i != null) { - var a = i.inputsToSave || [], s = i.outputsToSave || [], o = void 0; - i.saveAllInputs ? (E(Array.isArray(e), function() { - return "saveAllInputs is true, expected inputs to be an array."; - }), o = Object.keys(e).map(function(l) { - return e[l]; - })) : o = a.map(function(l) { - return e[l]; + if (isTapeOn) { + this.addTapeNode(kernelName, inputs, outputs, backwardsFunc, saved, attrs); + } + if (this.state.profiling) { + this.state.activeProfile.kernels.push({ + name: kernelName, + bytesAdded: this.state.numBytes - startingBytecount, + totalBytesSnapshot: this.state.numBytes, + tensorsAdded: this.state.numTensors - startingNumTensors, + totalTensorsSnapshot: this.state.numTensors, + inputShapes: Object.keys(inputs).map(function(key) { + return inputs[key] != null ? inputs[key].shape : null; + }), + outputShapes: outputs.map(function(item) { + return item.shape; + }), + kernelTimeMs: kernelProfile.timeMs, + extraInfo: kernelProfile.extraInfo }); - var c = r.filter(function(l, u) { - return s[u]; + } + return Array.isArray(out) ? outputs : outputs[0]; + }; + Engine2.prototype.saveTensorsForBackwardMode = function(tensors) { + var _this = this; + var saved = tensors.map(function(tensor2) { + return _this.keep(_this.clone(tensor2)); + }); + return saved; + }; + Engine2.prototype.getTensorsForGradient = function(kernelName, inputs, outputs) { + var gradConfig = getGradient(kernelName); + if (gradConfig != null) { + var inputsToSave = gradConfig.inputsToSave || []; + var outputsToSave_1 = gradConfig.outputsToSave || []; + var inputTensorsToSave = void 0; + if (gradConfig.saveAllInputs) { + assert(Array.isArray(inputs), function() { + return "saveAllInputs is true, expected inputs to be an array."; + }); + inputTensorsToSave = Object.keys(inputs).map(function(key) { + return inputs[key]; + }); + } else { + inputTensorsToSave = inputsToSave.map(function(inputName) { + return inputs[inputName]; + }); + } + var outputTensorsToSave = outputs.filter(function(_, i) { + return outputsToSave_1[i]; }); - return o.concat(c); + return inputTensorsToSave.concat(outputTensorsToSave); } return null; - }, n.prototype.makeTensor = function(t, e, r, i) { - if (t == null) + }; + Engine2.prototype.makeTensor = function(values, shape, dtype, backend2) { + if (values == null) { throw new Error("Values passed to engine.makeTensor() are null"); - r = r || "float32", i = i || this.backend; - var a = t; - r === "string" && or(t[0]) && (a = t.map(function(u) { - return du(u); - })); - var s = i.write(a, e, r), o = new K(e, r, s, this.nextTensorId()); - if (this.incRef(o, i), r === "string") { - var c = this.state.tensorInfo.get(s), l = yf(a); - this.state.numBytes += l - c.bytes, c.bytes = l; } - return o; - }, n.prototype.makeTensorFromDataId = function(t, e, r, i) { - r = r || "float32"; - var a = new K(e, r, t, this.nextTensorId()); - return this.incRef(a, i), a; - }, n.prototype.makeVariable = function(t, e, r, i) { - e === void 0 && (e = true), r = r || this.nextVariableId().toString(), i != null && i !== t.dtype && (t = t.cast(i)); - var a = new La(t, e, r, this.nextTensorId()); - if (this.state.registeredVariables[a.name] != null) - throw new Error("Variable with name " + a.name + " was already registered"); - return this.state.registeredVariables[a.name] = a, this.incRef(a, this.backend), a; - }, n.prototype.incRef = function(t, e) { - var r = this.state.tensorInfo.has(t.dataId) ? this.state.tensorInfo.get(t.dataId).refCount : 0; - if (this.state.numTensors++, t.dtype === "string" && this.state.numStringTensors++, r === 0) { + dtype = dtype || "float32"; + backend2 = backend2 || this.backend; + var backendVals = values; + if (dtype === "string" && isString(values[0])) { + backendVals = values.map(function(d) { + return encodeString(d); + }); + } + var dataId = backend2.write(backendVals, shape, dtype); + var t = new Tensor(shape, dtype, dataId, this.nextTensorId()); + this.incRef(t, backend2); + if (dtype === "string") { + var info = this.state.tensorInfo.get(dataId); + var newBytes = bytesFromStringArray(backendVals); + this.state.numBytes += newBytes - info.bytes; + info.bytes = newBytes; + } + return t; + }; + Engine2.prototype.makeTensorFromDataId = function(dataId, shape, dtype, backend2) { + dtype = dtype || "float32"; + var t = new Tensor(shape, dtype, dataId, this.nextTensorId()); + this.incRef(t, backend2); + return t; + }; + Engine2.prototype.makeVariable = function(initialValue, trainable, name, dtype) { + if (trainable === void 0) { + trainable = true; + } + name = name || this.nextVariableId().toString(); + if (dtype != null && dtype !== initialValue.dtype) { + initialValue = initialValue.cast(dtype); + } + var v = new Variable(initialValue, trainable, name, this.nextTensorId()); + if (this.state.registeredVariables[v.name] != null) { + throw new Error("Variable with name " + v.name + " was already registered"); + } + this.state.registeredVariables[v.name] = v; + this.incRef(v, this.backend); + return v; + }; + Engine2.prototype.incRef = function(a, backend2) { + var refCount = this.state.tensorInfo.has(a.dataId) ? this.state.tensorInfo.get(a.dataId).refCount : 0; + this.state.numTensors++; + if (a.dtype === "string") { + this.state.numStringTensors++; + } + if (refCount === 0) { this.state.numDataBuffers++; - var i = 0; - t.dtype !== "complex64" && t.dtype !== "string" && (i = t.size * gf(t.dtype)), this.state.tensorInfo.set(t.dataId, {backend: e || this.backend, dtype: t.dtype, shape: t.shape, bytes: i, refCount: 0}), this.state.numBytes += i; + var bytes = 0; + if (a.dtype !== "complex64" && a.dtype !== "string") { + bytes = a.size * bytesPerElement(a.dtype); + } + this.state.tensorInfo.set(a.dataId, { + backend: backend2 || this.backend, + dtype: a.dtype, + shape: a.shape, + bytes, + refCount: 0 + }); + this.state.numBytes += bytes; } - this.state.tensorInfo.get(t.dataId).refCount++, t instanceof La || this.track(t); - }, n.prototype.disposeTensor = function(t) { - if (!this.state.tensorInfo.has(t.dataId)) + this.state.tensorInfo.get(a.dataId).refCount++; + if (!(a instanceof Variable)) { + this.track(a); + } + }; + Engine2.prototype.disposeTensor = function(a) { + if (!this.state.tensorInfo.has(a.dataId)) { return; - this.state.numTensors--, t.dtype === "string" && this.state.numStringTensors--; - var e = this.state.tensorInfo.get(t.dataId), r = e.refCount; - r <= 1 ? (t.dtype !== "complex64" && (this.state.numBytes -= e.bytes), this.state.numDataBuffers--, e.backend.disposeData(t.dataId), this.state.tensorInfo.delete(t.dataId)) : this.state.tensorInfo.get(t.dataId).refCount--; - }, n.prototype.disposeVariables = function() { - for (var t in this.state.registeredVariables) { - var e = this.state.registeredVariables[t]; - this.disposeVariable(e); } - }, n.prototype.disposeVariable = function(t) { - this.disposeTensor(t), this.state.registeredVariables[t.name] != null && delete this.state.registeredVariables[t.name]; - }, n.prototype.memory = function() { - var t = this.backend.memory(); - return t.numTensors = this.state.numTensors, t.numDataBuffers = this.state.numDataBuffers, t.numBytes = this.state.numBytes, this.state.numStringTensors > 0 && (t.unreliable = true, t.reasons == null && (t.reasons = []), t.reasons.push("Memory usage by string tensors is approximate (2 bytes per character)")), t; - }, n.prototype.profile = function(t) { - return pe(this, void 0, void 0, function() { - var e, r, i, a, s, o, c, l; - return fe(this, function(u) { - switch (u.label) { + this.state.numTensors--; + if (a.dtype === "string") { + this.state.numStringTensors--; + } + var info = this.state.tensorInfo.get(a.dataId); + var refCount = info.refCount; + if (refCount <= 1) { + if (a.dtype !== "complex64") { + this.state.numBytes -= info.bytes; + } + this.state.numDataBuffers--; + info.backend.disposeData(a.dataId); + this.state.tensorInfo.delete(a.dataId); + } else { + this.state.tensorInfo.get(a.dataId).refCount--; + } + }; + Engine2.prototype.disposeVariables = function() { + for (var varName in this.state.registeredVariables) { + var v = this.state.registeredVariables[varName]; + this.disposeVariable(v); + } + }; + Engine2.prototype.disposeVariable = function(v) { + this.disposeTensor(v); + if (this.state.registeredVariables[v.name] != null) { + delete this.state.registeredVariables[v.name]; + } + }; + Engine2.prototype.memory = function() { + var info = this.backend.memory(); + info.numTensors = this.state.numTensors; + info.numDataBuffers = this.state.numDataBuffers; + info.numBytes = this.state.numBytes; + if (this.state.numStringTensors > 0) { + info.unreliable = true; + if (info.reasons == null) { + info.reasons = []; + } + info.reasons.push("Memory usage by string tensors is approximate (2 bytes per character)"); + } + return info; + }; + Engine2.prototype.profile = function(query) { + return __awaiter(this, void 0, void 0, function() { + var startBytes, startNumTensors, _a, _i2, _b, kernel, _c, _d; + return __generator(this, function(_e) { + switch (_e.label) { case 0: - return this.state.profiling = true, e = this.state.numBytes, r = this.state.numTensors, this.state.activeProfile.kernels = [], i = this.state.activeProfile, [4, t()]; + this.state.profiling = true; + startBytes = this.state.numBytes; + startNumTensors = this.state.numTensors; + this.state.activeProfile.kernels = []; + _a = this.state.activeProfile; + return [4, query()]; case 1: - i.result = u.sent(), this.state.profiling = false, this.state.activeProfile.peakBytes = Math.max.apply(Math, this.state.activeProfile.kernels.map(function(h) { - return h.totalBytesSnapshot; - })), this.state.activeProfile.newBytes = this.state.numBytes - e, this.state.activeProfile.newTensors = this.state.numTensors - r, a = 0, s = this.state.activeProfile.kernels, u.label = 2; + _a.result = _e.sent(); + this.state.profiling = false; + this.state.activeProfile.peakBytes = Math.max.apply(Math, this.state.activeProfile.kernels.map(function(d) { + return d.totalBytesSnapshot; + })); + this.state.activeProfile.newBytes = this.state.numBytes - startBytes; + this.state.activeProfile.newTensors = this.state.numTensors - startNumTensors; + _i2 = 0, _b = this.state.activeProfile.kernels; + _e.label = 2; case 2: - return a < s.length ? (o = s[a], c = o, [4, o.kernelTimeMs]) : [3, 6]; + if (!(_i2 < _b.length)) + return [3, 6]; + kernel = _b[_i2]; + _c = kernel; + return [4, kernel.kernelTimeMs]; case 3: - return c.kernelTimeMs = u.sent(), l = o, [4, o.extraInfo]; + _c.kernelTimeMs = _e.sent(); + _d = kernel; + return [4, kernel.extraInfo]; case 4: - l.extraInfo = u.sent(), u.label = 5; + _d.extraInfo = _e.sent(); + _e.label = 5; case 5: - return a++, [3, 2]; + _i2++; + return [3, 2]; case 6: return [2, this.state.activeProfile]; } }); }); - }, n.prototype.isTapeOn = function() { + }; + Engine2.prototype.isTapeOn = function() { return this.state.gradientDepth > 0 && this.state.kernelDepth === 0; - }, n.prototype.addTapeNode = function(t, e, r, i, a, s) { - var o = this, c = {id: this.state.nextTapeNodeId++, kernelName: t, inputs: e, outputs: r, saved: a}, l = hu(t); - l != null && (i = l.gradFunc), i != null && (c.gradient = function(u) { - return u = u.map(function(h, d) { - if (h == null) { - var p = r[d], f = Li(p.size, p.dtype); - return o.makeTensor(f, p.shape, p.dtype); - } - return h; - }), i(u.length > 1 ? u : u[0], a, s); - }), this.state.activeTape.push(c); - }, n.prototype.keep = function(t) { - return t.kept = true, t; - }, n.prototype.startTape = function() { - this.state.gradientDepth === 0 && (this.state.activeTape = []), this.state.gradientDepth++; - }, n.prototype.endTape = function() { - this.state.gradientDepth--; - }, n.prototype.startScope = function(t) { - var e = {track: [], name: "unnamed scope", id: this.state.nextScopeId++}; - t && (e.name = t), this.state.scopeStack.push(e), this.state.activeScope = e; - }, n.prototype.endScope = function(t) { - for (var e = this, r = bu(t), i = new Set(r.map(function(c) { - return c.id; - })), a = 0; a < this.state.activeScope.track.length; a++) { - var s = this.state.activeScope.track[a]; - !s.kept && !i.has(s.id) && s.dispose(); + }; + Engine2.prototype.addTapeNode = function(kernelName, inputs, outputs, gradientsFunc, saved, attrs) { + var _this = this; + var tapeNode = {id: this.state.nextTapeNodeId++, kernelName, inputs, outputs, saved}; + var gradConfig = getGradient(kernelName); + if (gradConfig != null) { + gradientsFunc = gradConfig.gradFunc; } - var o = this.state.scopeStack.pop(); - this.state.activeScope = this.state.scopeStack.length === 0 ? null : this.state.scopeStack[this.state.scopeStack.length - 1], r.forEach(function(c) { - !c.kept && c.scopeId === o.id && e.track(c); + if (gradientsFunc != null) { + tapeNode.gradient = function(dys) { + dys = dys.map(function(dy, i) { + if (dy == null) { + var output = outputs[i]; + var vals = makeZerosTypedArray(output.size, output.dtype); + return _this.makeTensor(vals, output.shape, output.dtype); + } + return dy; + }); + return gradientsFunc(dys.length > 1 ? dys : dys[0], saved, attrs); + }; + } + this.state.activeTape.push(tapeNode); + }; + Engine2.prototype.keep = function(result) { + result.kept = true; + return result; + }; + Engine2.prototype.startTape = function() { + if (this.state.gradientDepth === 0) { + this.state.activeTape = []; + } + this.state.gradientDepth++; + }; + Engine2.prototype.endTape = function() { + this.state.gradientDepth--; + }; + Engine2.prototype.startScope = function(name) { + var scopeInfo = { + track: [], + name: "unnamed scope", + id: this.state.nextScopeId++ + }; + if (name) { + scopeInfo.name = name; + } + this.state.scopeStack.push(scopeInfo); + this.state.activeScope = scopeInfo; + }; + Engine2.prototype.endScope = function(result) { + var _this = this; + var tensorsToTrackInParent = getTensorsInContainer(result); + var tensorsToTrackInParentSet = new Set(tensorsToTrackInParent.map(function(t) { + return t.id; + })); + for (var i = 0; i < this.state.activeScope.track.length; i++) { + var tensor2 = this.state.activeScope.track[i]; + if (!tensor2.kept && !tensorsToTrackInParentSet.has(tensor2.id)) { + tensor2.dispose(); + } + } + var oldScope = this.state.scopeStack.pop(); + this.state.activeScope = this.state.scopeStack.length === 0 ? null : this.state.scopeStack[this.state.scopeStack.length - 1]; + tensorsToTrackInParent.forEach(function(tensor3) { + if (!tensor3.kept && tensor3.scopeId === oldScope.id) { + _this.track(tensor3); + } }); - }, n.prototype.gradients = function(t, e, r, i) { - var a = this; - if (i === void 0 && (i = false), E(e.length > 0, function() { + }; + Engine2.prototype.gradients = function(f, xs, dy, allowNoGradients) { + var _this = this; + if (allowNoGradients === void 0) { + allowNoGradients = false; + } + assert(xs.length > 0, function() { return "gradients() received an empty list of xs."; - }), r != null && r.dtype !== "float32") - throw new Error("dy must have 'float32' dtype, but has '" + r.dtype + "'"); - var s = this.scopedRun(function() { - return a.startTape(); - }, function() { - return a.endTape(); - }, function() { - return a.tidy("forward", t); }); - E(s instanceof K, function() { + if (dy != null && dy.dtype !== "float32") { + throw new Error("dy must have 'float32' dtype, but has '" + dy.dtype + "'"); + } + var y = this.scopedRun(function() { + return _this.startTape(); + }, function() { + return _this.endTape(); + }, function() { + return _this.tidy("forward", f); + }); + assert(y instanceof Tensor, function() { return "The result y returned by f() must be a tensor."; }); - var o = VL(this.state.activeTape, e, s); - if (!i && o.length === 0 && e.length > 0) + var filteredTape = getFilteredNodesXToY(this.state.activeTape, xs, y); + if (!allowNoGradients && filteredTape.length === 0 && xs.length > 0) { throw new Error("Cannot compute gradient of y=f(x) with respect to x. Make sure that the f you passed encloses all operations that lead from x to y."); + } return this.tidy("backward", function() { - var c = {}; - c[s.id] = r == null ? eS(s.shape) : r, GL(c, o, function(u) { - return a.tidy(u); - }, tS); - var l = e.map(function(u) { - return c[u.id]; + var accumulatedGradientMap = {}; + accumulatedGradientMap[y.id] = dy == null ? ones(y.shape) : dy; + backpropagateGradients(accumulatedGradientMap, filteredTape, function(f2) { + return _this.tidy(f2); + }, add); + var grads2 = xs.map(function(x) { + return accumulatedGradientMap[x.id]; }); - return a.state.gradientDepth === 0 && (a.state.activeTape.forEach(function(u) { - for (var h = 0, d = u.saved; h < d.length; h++) { - var p = d[h]; - p.dispose(); - } - }), a.state.activeTape = null), {value: s, grads: l}; + if (_this.state.gradientDepth === 0) { + _this.state.activeTape.forEach(function(node) { + for (var _i2 = 0, _a = node.saved; _i2 < _a.length; _i2++) { + var tensor2 = _a[_i2]; + tensor2.dispose(); + } + }); + _this.state.activeTape = null; + } + return {value: y, grads: grads2}; }); - }, n.prototype.customGrad = function(t) { - var e = this; - return E(cr(t), function() { + }; + Engine2.prototype.customGrad = function(f) { + var _this = this; + assert(isFunction(f), function() { return "The f passed in customGrad(f) must be a function."; - }), function() { - for (var r = [], i = 0; i < arguments.length; i++) - r[i] = arguments[i]; - E(r.every(function(o) { - return o instanceof K; + }); + return function() { + var inputs = []; + for (var _i2 = 0; _i2 < arguments.length; _i2++) { + inputs[_i2] = arguments[_i2]; + } + assert(inputs.every(function(t) { + return t instanceof Tensor; }), function() { return "The args passed in customGrad(f)(x1, x2,...) must all be tensors"; }); - var a, s = {}; - return r.forEach(function(o, c) { - s[c] = o; - }), e.runKernelFunc(function(o, c) { - return a = t.apply(void 0, r.concat([c])), E(a.value instanceof K, function() { + var res; + var inputMap = {}; + inputs.forEach(function(input, i) { + inputMap[i] = input; + }); + return _this.runKernelFunc(function(_, save) { + res = f.apply(void 0, inputs.concat([save])); + assert(res.value instanceof Tensor, function() { return "The function f passed in customGrad(f) must return an object where `obj.value` is a tensor"; - }), E(cr(a.gradFunc), function() { + }); + assert(isFunction(res.gradFunc), function() { return "The function f passed in customGrad(f) must return an object where `obj.gradFunc` is a function."; - }), a.value; - }, s, function(o, c) { - var l = a.gradFunc(o, c), u = Array.isArray(l) ? l : [l]; - E(u.length === r.length, function() { + }); + return res.value; + }, inputMap, function(dy, saved) { + var gradRes = res.gradFunc(dy, saved); + var grads2 = Array.isArray(gradRes) ? gradRes : [gradRes]; + assert(grads2.length === inputs.length, function() { return "The function f passed in customGrad(f) must return an object where `obj.gradFunc` is a function that returns the same number of tensors as inputs passed to f(...)."; - }), E(u.every(function(d) { - return d instanceof K; + }); + assert(grads2.every(function(t) { + return t instanceof Tensor; }), function() { return "The function f passed in customGrad(f) must return an object where `obj.gradFunc` is a function that returns a list of only tensors."; }); - var h = {}; - return u.forEach(function(d, p) { - h[p] = function() { - return d; + var gradMap = {}; + grads2.forEach(function(grad2, i) { + gradMap[i] = function() { + return grad2; }; - }), h; + }); + return gradMap; }); }; - }, n.prototype.readSync = function(t) { - var e = this.state.tensorInfo.get(t); - return e.backend.readSync(t); - }, n.prototype.read = function(t) { - var e = this.state.tensorInfo.get(t); - return e.backend.read(t); - }, n.prototype.time = function(t) { - return pe(this, void 0, void 0, function() { - var e, r; - return fe(this, function(i) { - switch (i.label) { + }; + Engine2.prototype.readSync = function(dataId) { + var info = this.state.tensorInfo.get(dataId); + return info.backend.readSync(dataId); + }; + Engine2.prototype.read = function(dataId) { + var info = this.state.tensorInfo.get(dataId); + return info.backend.read(dataId); + }; + Engine2.prototype.time = function(query) { + return __awaiter(this, void 0, void 0, function() { + var start, timingInfo; + return __generator(this, function(_a) { + switch (_a.label) { case 0: - return e = pu(), [4, this.backend.time(t)]; + start = now2(); + return [4, this.backend.time(query)]; case 1: - return r = i.sent(), r.wallMs = pu() - e, [2, r]; + timingInfo = _a.sent(); + timingInfo.wallMs = now2() - start; + return [2, timingInfo]; } }); }); - }, n.prototype.track = function(t) { - return this.state.activeScope != null && (t.scopeId = this.state.activeScope.id, this.state.activeScope.track.push(t)), t; - }, Object.defineProperty(n.prototype, "registeredVariables", {get: function() { - return this.state.registeredVariables; - }, enumerable: true, configurable: true}), n.prototype.reset = function() { - this.pendingBackendInitId++, this.state.dispose(), this.ENV.reset(), this.state = new _m(); - for (var t in this.registry) - this.disposeRegisteredKernels(t), this.registry[t].dispose(), delete this.registry[t]; - this.backendName = null, this.backendInstance = null, this.pendingBackendInit = null; - }, n.nextTensorId = 0, n.nextVariableId = 0, n; + }; + Engine2.prototype.track = function(result) { + if (this.state.activeScope != null) { + result.scopeId = this.state.activeScope.id; + this.state.activeScope.track.push(result); + } + return result; + }; + Object.defineProperty(Engine2.prototype, "registeredVariables", { + get: function() { + return this.state.registeredVariables; + }, + enumerable: true, + configurable: true + }); + Engine2.prototype.reset = function() { + this.pendingBackendInitId++; + this.state.dispose(); + this.ENV.reset(); + this.state = new EngineState(); + for (var backendName in this.registry) { + this.disposeRegisteredKernels(backendName); + this.registry[backendName].dispose(); + delete this.registry[backendName]; + } + this.backendName = null; + this.backendInstance = null; + this.pendingBackendInit = null; + }; + Engine2.nextTensorId = 0; + Engine2.nextVariableId = 0; + return Engine2; }(); - function eS(n) { - var t = yc(pt(n), "float32"); - return z.makeTensor(t, n, "float32"); + function ones(shape) { + var values = makeOnesTypedArray(sizeFromShape(shape), "float32"); + return ENGINE.makeTensor(values, shape, "float32"); } - function Cm() { - var n = Sf(); - if (n._tfengine == null) { - var t = new Lf(n); - n._tfengine = new nS(t); + function getOrMakeEngine() { + var ns = getGlobalNamespace(); + if (ns._tfengine == null) { + var environment = new Environment(ns); + ns._tfengine = new Engine(environment); } - return RL(n._tfengine.ENV), KL(function() { - return n._tfengine; - }), n._tfengine; + setEnvironmentGlobal(ns._tfengine.ENV); + setTensorTracker(function() { + return ns._tfengine; + }); + return ns._tfengine; } - var z = Cm(); - function tS(n, t) { - var e = {a: n, b: t}; - return z.runKernelFunc(function(r, i) { - var a = r.add(n, t); - return i([n, t]), a; - }, e, null, Cs); + var ENGINE = getOrMakeEngine(); + function add(a, b) { + var inputs = {a, b}; + return ENGINE.runKernelFunc(function(backend2, save) { + var res = backend2.add(a, b); + save([a, b]); + return res; + }, inputs, null, Add); } - function rS() { - return typeof navigator != "undefined" && navigator != null; + /** + * @license + * Copyright 2017 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + function _isNavigatorDefined() { + return typeof navigator !== "undefined" && navigator != null; } - function iS() { - if (rS()) { - var n = navigator.userAgent || navigator.vendor || window.opera; - return /(android|bb\d+|meego).+mobile|avantgo|bada\/|blackberry|blazer|compal|elaine|fennec|hiptop|iemobile|ip(hone|od)|iris|kindle|lge |maemo|midp|mmp|mobile.+firefox|netfront|opera m(ob|in)i|palm( os)?|phone|p(ixi|re)\/|plucker|pocket|psp|series(4|6)0|symbian|treo|up\.(browser|link)|vodafone|wap|windows ce|xda|xiino/i.test(n) || /1207|6310|6590|3gso|4thp|50[1-6]i|770s|802s|a wa|abac|ac(er|oo|s\-)|ai(ko|rn)|al(av|ca|co)|amoi|an(ex|ny|yw)|aptu|ar(ch|go)|as(te|us)|attw|au(di|\-m|r |s )|avan|be(ck|ll|nq)|bi(lb|rd)|bl(ac|az)|br(e|v)w|bumb|bw\-(n|u)|c55\/|capi|ccwa|cdm\-|cell|chtm|cldc|cmd\-|co(mp|nd)|craw|da(it|ll|ng)|dbte|dc\-s|devi|dica|dmob|do(c|p)o|ds(12|\-d)|el(49|ai)|em(l2|ul)|er(ic|k0)|esl8|ez([4-7]0|os|wa|ze)|fetc|fly(\-|_)|g1 u|g560|gene|gf\-5|g\-mo|go(\.w|od)|gr(ad|un)|haie|hcit|hd\-(m|p|t)|hei\-|hi(pt|ta)|hp( i|ip)|hs\-c|ht(c(\-| |_|a|g|p|s|t)|tp)|hu(aw|tc)|i\-(20|go|ma)|i230|iac( |\-|\/)|ibro|idea|ig01|ikom|im1k|inno|ipaq|iris|ja(t|v)a|jbro|jemu|jigs|kddi|keji|kgt( |\/)|klon|kpt |kwc\-|kyo(c|k)|le(no|xi)|lg( g|\/(k|l|u)|50|54|\-[a-w])|libw|lynx|m1\-w|m3ga|m50\/|ma(te|ui|xo)|mc(01|21|ca)|m\-cr|me(rc|ri)|mi(o8|oa|ts)|mmef|mo(01|02|bi|de|do|t(\-| |o|v)|zz)|mt(50|p1|v )|mwbp|mywa|n10[0-2]|n20[2-3]|n30(0|2)|n50(0|2|5)|n7(0(0|1)|10)|ne((c|m)\-|on|tf|wf|wg|wt)|nok(6|i)|nzph|o2im|op(ti|wv)|oran|owg1|p800|pan(a|d|t)|pdxg|pg(13|\-([1-8]|c))|phil|pire|pl(ay|uc)|pn\-2|po(ck|rt|se)|prox|psio|pt\-g|qa\-a|qc(07|12|21|32|60|\-[2-7]|i\-)|qtek|r380|r600|raks|rim9|ro(ve|zo)|s55\/|sa(ge|ma|mm|ms|ny|va)|sc(01|h\-|oo|p\-)|sdk\/|se(c(\-|0|1)|47|mc|nd|ri)|sgh\-|shar|sie(\-|m)|sk\-0|sl(45|id)|sm(al|ar|b3|it|t5)|so(ft|ny)|sp(01|h\-|v\-|v )|sy(01|mb)|t2(18|50)|t6(00|10|18)|ta(gt|lk)|tcl\-|tdg\-|tel(i|m)|tim\-|t\-mo|to(pl|sh)|ts(70|m\-|m3|m5)|tx\-9|up(\.b|g1|si)|utst|v400|v750|veri|vi(rg|te)|vk(40|5[0-3]|\-v)|vm40|voda|vulc|vx(52|53|60|61|70|80|81|83|85|98)|w3c(\-| )|webc|whit|wi(g |nc|nw)|wmlb|wonu|x700|yas\-|your|zeto|zte\-/i.test(n.substr(0, 4)); + function isMobile() { + if (_isNavigatorDefined()) { + var a = navigator.userAgent || navigator.vendor || window.opera; + return /(android|bb\d+|meego).+mobile|avantgo|bada\/|blackberry|blazer|compal|elaine|fennec|hiptop|iemobile|ip(hone|od)|iris|kindle|lge |maemo|midp|mmp|mobile.+firefox|netfront|opera m(ob|in)i|palm( os)?|phone|p(ixi|re)\/|plucker|pocket|psp|series(4|6)0|symbian|treo|up\.(browser|link)|vodafone|wap|windows ce|xda|xiino/i.test(a) || /1207|6310|6590|3gso|4thp|50[1-6]i|770s|802s|a wa|abac|ac(er|oo|s\-)|ai(ko|rn)|al(av|ca|co)|amoi|an(ex|ny|yw)|aptu|ar(ch|go)|as(te|us)|attw|au(di|\-m|r |s )|avan|be(ck|ll|nq)|bi(lb|rd)|bl(ac|az)|br(e|v)w|bumb|bw\-(n|u)|c55\/|capi|ccwa|cdm\-|cell|chtm|cldc|cmd\-|co(mp|nd)|craw|da(it|ll|ng)|dbte|dc\-s|devi|dica|dmob|do(c|p)o|ds(12|\-d)|el(49|ai)|em(l2|ul)|er(ic|k0)|esl8|ez([4-7]0|os|wa|ze)|fetc|fly(\-|_)|g1 u|g560|gene|gf\-5|g\-mo|go(\.w|od)|gr(ad|un)|haie|hcit|hd\-(m|p|t)|hei\-|hi(pt|ta)|hp( i|ip)|hs\-c|ht(c(\-| |_|a|g|p|s|t)|tp)|hu(aw|tc)|i\-(20|go|ma)|i230|iac( |\-|\/)|ibro|idea|ig01|ikom|im1k|inno|ipaq|iris|ja(t|v)a|jbro|jemu|jigs|kddi|keji|kgt( |\/)|klon|kpt |kwc\-|kyo(c|k)|le(no|xi)|lg( g|\/(k|l|u)|50|54|\-[a-w])|libw|lynx|m1\-w|m3ga|m50\/|ma(te|ui|xo)|mc(01|21|ca)|m\-cr|me(rc|ri)|mi(o8|oa|ts)|mmef|mo(01|02|bi|de|do|t(\-| |o|v)|zz)|mt(50|p1|v )|mwbp|mywa|n10[0-2]|n20[2-3]|n30(0|2)|n50(0|2|5)|n7(0(0|1)|10)|ne((c|m)\-|on|tf|wf|wg|wt)|nok(6|i)|nzph|o2im|op(ti|wv)|oran|owg1|p800|pan(a|d|t)|pdxg|pg(13|\-([1-8]|c))|phil|pire|pl(ay|uc)|pn\-2|po(ck|rt|se)|prox|psio|pt\-g|qa\-a|qc(07|12|21|32|60|\-[2-7]|i\-)|qtek|r380|r600|raks|rim9|ro(ve|zo)|s55\/|sa(ge|ma|mm|ms|ny|va)|sc(01|h\-|oo|p\-)|sdk\/|se(c(\-|0|1)|47|mc|nd|ri)|sgh\-|shar|sie(\-|m)|sk\-0|sl(45|id)|sm(al|ar|b3|it|t5)|so(ft|ny)|sp(01|h\-|v\-|v )|sy(01|mb)|t2(18|50)|t6(00|10|18)|ta(gt|lk)|tcl\-|tdg\-|tel(i|m)|tim\-|t\-mo|to(pl|sh)|ts(70|m\-|m3|m5)|tx\-9|up(\.b|g1|si)|utst|v400|v750|veri|vi(rg|te)|vk(40|5[0-3]|\-v)|vm40|voda|vulc|vx(52|53|60|61|70|80|81|83|85|98)|w3c(\-| )|webc|whit|wi(g |nc|nw)|wmlb|wonu|x700|yas\-|your|zeto|zte\-/i.test(a.substr(0, 4)); } return false; } - function Rm() { - return typeof window != "undefined" && window.document != null || typeof WorkerGlobalScope != "undefined"; + function isBrowser() { + return typeof window !== "undefined" && window.document != null || typeof WorkerGlobalScope !== "undefined"; } - var aS = {__proto__: null, isMobile: iS, isBrowser: Rm}; - var Yn = Ge(); - Yn.registerFlag("DEBUG", function() { + var device_util = { + __proto__: null, + isMobile, + isBrowser + }; + /** + * @license + * Copyright 2019 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + var ENV = env(); + ENV.registerFlag("DEBUG", function() { return false; - }, function(n) { - n && console.warn("Debugging mode is ON. The output of every math call will be downloaded to CPU and checked for NaNs. This significantly impacts performance."); + }, function(debugValue) { + if (debugValue) { + console.warn("Debugging mode is ON. The output of every math call will be downloaded to CPU and checked for NaNs. This significantly impacts performance."); + } }); - Yn.registerFlag("IS_BROWSER", function() { - return Rm(); + ENV.registerFlag("IS_BROWSER", function() { + return isBrowser(); }); - Yn.registerFlag("IS_NODE", function() { - return typeof process != "undefined" && typeof process.versions != "undefined" && typeof process.versions.node != "undefined"; + ENV.registerFlag("IS_NODE", function() { + return typeof process !== "undefined" && typeof process.versions !== "undefined" && typeof process.versions.node !== "undefined"; }); - Yn.registerFlag("IS_CHROME", function() { - return typeof navigator != "undefined" && navigator != null && navigator.userAgent != null && /Chrome/.test(navigator.userAgent) && /Google Inc/.test(navigator.vendor); + ENV.registerFlag("IS_CHROME", function() { + return typeof navigator !== "undefined" && navigator != null && navigator.userAgent != null && /Chrome/.test(navigator.userAgent) && /Google Inc/.test(navigator.vendor); }); - Yn.registerFlag("PROD", function() { + ENV.registerFlag("PROD", function() { return false; }); - Yn.registerFlag("TENSORLIKE_CHECK_SHAPE_CONSISTENCY", function() { - return Yn.getBool("DEBUG"); + ENV.registerFlag("TENSORLIKE_CHECK_SHAPE_CONSISTENCY", function() { + return ENV.getBool("DEBUG"); }); - Yn.registerFlag("DEPRECATION_WARNINGS_ENABLED", function() { + ENV.registerFlag("DEPRECATION_WARNINGS_ENABLED", function() { return true; }); - Yn.registerFlag("IS_TEST", function() { + ENV.registerFlag("IS_TEST", function() { return false; }); - function Rn(n, t) { - var e = n; - if (Ft(n)) - return t === "string" ? [] : [n.length]; - if (!Array.isArray(n)) + /** + * @license + * Copyright 2018 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + function inferShape(val, dtype) { + var firstElem = val; + if (isTypedArray(val)) { + return dtype === "string" ? [] : [val.length]; + } + if (!Array.isArray(val)) { return []; - for (var r = []; Array.isArray(e) || Ft(e) && t !== "string"; ) - r.push(e.length), e = e[0]; - return Array.isArray(n) && Ge().getBool("TENSORLIKE_CHECK_SHAPE_CONSISTENCY") && Om(n, r, []), r; + } + var shape = []; + while (Array.isArray(firstElem) || isTypedArray(firstElem) && dtype !== "string") { + shape.push(firstElem.length); + firstElem = firstElem[0]; + } + if (Array.isArray(val) && env().getBool("TENSORLIKE_CHECK_SHAPE_CONSISTENCY")) { + deepAssertShapeConsistency(val, shape, []); + } + return shape; } - function Om(n, t, e) { - if (e = e || [], !Array.isArray(n) && !Ft(n)) { - E(t.length === 0, function() { - return "Element arr[" + e.join("][") + "] is a primitive, " + ("but should be an array/TypedArray of " + t[0] + " elements"); + function deepAssertShapeConsistency(val, shape, indices) { + indices = indices || []; + if (!Array.isArray(val) && !isTypedArray(val)) { + assert(shape.length === 0, function() { + return "Element arr[" + indices.join("][") + "] is a primitive, " + ("but should be an array/TypedArray of " + shape[0] + " elements"); }); return; } - E(t.length > 0, function() { - return "Element arr[" + e.join("][") + "] should be a primitive, " + ("but is an array of " + n.length + " elements"); - }), E(n.length === t[0], function() { - return "Element arr[" + e.join("][") + "] should have " + t[0] + " " + ("elements, but has " + n.length + " elements"); + assert(shape.length > 0, function() { + return "Element arr[" + indices.join("][") + "] should be a primitive, " + ("but is an array of " + val.length + " elements"); }); - for (var r = t.slice(1), i = 0; i < n.length; ++i) - Om(n[i], r, e.concat(i)); - } - function Em(n, t, e, r) { - if (n == null) - return; - if (n !== "numeric" && n !== t || n === "numeric" && t === "string") - throw new Error("Argument '" + e + "' passed to '" + r + "' must " + ("be " + n + " tensor, but got " + t + " tensor")); - } - function R(n, t, e, r) { - if (r === void 0 && (r = "numeric"), n instanceof K) - return Em(r, n.dtype, t, e), n; - var i = Ns(n); - if (i !== "string" && ["bool", "int32", "float32"].indexOf(r) >= 0 && (i = r), Em(r, i, t, e), n == null || !Ft(n) && !Array.isArray(n) && typeof n != "number" && typeof n != "boolean" && typeof n != "string") { - var a = n == null ? "null" : n.constructor.name; - throw new Error("Argument '" + t + "' passed to '" + e + "' must be a " + ("Tensor or TensorLike, but got '" + a + "'")); + assert(val.length === shape[0], function() { + return "Element arr[" + indices.join("][") + "] should have " + shape[0] + " " + ("elements, but has " + val.length + " elements"); + }); + var subShape = shape.slice(1); + for (var i = 0; i < val.length; ++i) { + deepAssertShapeConsistency(val[i], subShape, indices.concat(i)); } - var s = Rn(n, i); - !Ft(n) && !Array.isArray(n) && (n = [n]); - var o = true, c = i !== "string" ? Es(n, i) : Br(n, [], o); - return z.makeTensor(c, s, i); } - function Sa(n, t, e, r) { - if (r === void 0 && (r = "numeric"), !Array.isArray(n)) - throw new Error("Argument " + t + " passed to " + e + " must be a `Tensor[]` or `TensorLike[]`"); - var i = n; - return i.map(function(a, s) { - return R(a, t + "[" + s + "]", e); - }, r); + function assertDtype(expectedDtype, actualDType, argName, functionName) { + if (expectedDtype == null) { + return; + } + if (expectedDtype !== "numeric" && expectedDtype !== actualDType || expectedDtype === "numeric" && actualDType === "string") { + throw new Error("Argument '" + argName + "' passed to '" + functionName + "' must " + ("be " + expectedDtype + " tensor, but got " + actualDType + " tensor")); + } } - var Dm = "__op"; - function U(n) { - var t = Object.keys(n); - if (t.length !== 1) - throw new Error("Please provide an object with a single key (operation name) mapping to a function. Got an object with " + (t.length + " keys.")); - var e = t[0], r = n[e]; - e.endsWith("_") && (e = e.substring(0, e.length - 1)), e = e + Dm; - var i = function() { - for (var a = [], s = 0; s < arguments.length; s++) - a[s] = arguments[s]; - z.startScope(e); + function convertToTensor(x, argName, functionName, parseAsDtype) { + if (parseAsDtype === void 0) { + parseAsDtype = "numeric"; + } + if (x instanceof Tensor) { + assertDtype(parseAsDtype, x.dtype, argName, functionName); + return x; + } + var inferredDtype = inferDtype(x); + if (inferredDtype !== "string" && ["bool", "int32", "float32"].indexOf(parseAsDtype) >= 0) { + inferredDtype = parseAsDtype; + } + assertDtype(parseAsDtype, inferredDtype, argName, functionName); + if (x == null || !isTypedArray(x) && !Array.isArray(x) && typeof x !== "number" && typeof x !== "boolean" && typeof x !== "string") { + var type = x == null ? "null" : x.constructor.name; + throw new Error("Argument '" + argName + "' passed to '" + functionName + "' must be a " + ("Tensor or TensorLike, but got '" + type + "'")); + } + var inferredShape = inferShape(x, inferredDtype); + if (!isTypedArray(x) && !Array.isArray(x)) { + x = [x]; + } + var skipTypedArray = true; + var values = inferredDtype !== "string" ? toTypedArray(x, inferredDtype) : flatten(x, [], skipTypedArray); + return ENGINE.makeTensor(values, inferredShape, inferredDtype); + } + function convertToTensorArray(arg, argName, functionName, parseAsDtype) { + if (parseAsDtype === void 0) { + parseAsDtype = "numeric"; + } + if (!Array.isArray(arg)) { + throw new Error("Argument " + argName + " passed to " + functionName + " must be a `Tensor[]` or `TensorLike[]`"); + } + var tensors = arg; + return tensors.map(function(t, i) { + return convertToTensor(t, argName + "[" + i + "]", functionName); + }, parseAsDtype); + } + /** + * @license + * Copyright 2018 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + var OP_SCOPE_SUFFIX = "__op"; + function op(f) { + var keys = Object.keys(f); + if (keys.length !== 1) { + throw new Error("Please provide an object with a single key (operation name) mapping to a function. Got an object with " + (keys.length + " keys.")); + } + var opName = keys[0]; + var fn = f[opName]; + if (opName.endsWith("_")) { + opName = opName.substring(0, opName.length - 1); + } + opName = opName + OP_SCOPE_SUFFIX; + var f2 = function() { + var args = []; + for (var _i2 = 0; _i2 < arguments.length; _i2++) { + args[_i2] = arguments[_i2]; + } + ENGINE.startScope(opName); try { - var o = r.apply(void 0, a); - return wc(o) && console.error("Cannot return a Promise inside of tidy."), z.endScope(o), o; - } catch (c) { - throw z.endScope(null), c; + var result = fn.apply(void 0, args); + if (isPromise(result)) { + console.error("Cannot return a Promise inside of tidy."); + } + ENGINE.endScope(result); + return result; + } catch (ex) { + ENGINE.endScope(null); + throw ex; } }; - return Object.defineProperty(i, "name", {value: e, configurable: true}), i; + Object.defineProperty(f2, "name", {value: opName, configurable: true}); + return f2; } - function sS(n, t) { - var e = R(n, "real", "complex"), r = R(t, "imag", "complex"); - Pe(e.shape, r.shape, "real and imag shapes, " + e.shape + " and " + r.shape + ", must match in call to tf.complex()."); - var i = function(s) { - return s.complex(e, r); - }, a = {real: e, imag: r}; - return z.runKernelFunc(i, a, null, Cf); + /** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + function complex_(real2, imag2) { + var $real = convertToTensor(real2, "real", "complex"); + var $imag = convertToTensor(imag2, "imag", "complex"); + assertShapesMatch($real.shape, $imag.shape, "real and imag shapes, " + $real.shape + " and " + $imag.shape + ", must match in call to tf.complex()."); + var forward = function(backend2) { + return backend2.complex($real, $imag); + }; + var inputs = {real: $real, imag: $imag}; + return ENGINE.runKernelFunc(forward, inputs, null, Complex); } - var lr = U({complex_: sS}); - function ur(n, t, e, r) { - if (r == null && (r = Ns(n)), r === "complex64") + var complex = op({complex_}); + /** + * @license + * Copyright 2018 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + function makeTensor(values, shape, inferredShape, dtype) { + if (dtype == null) { + dtype = inferDtype(values); + } + if (dtype === "complex64") { throw new Error("Cannot construct a complex64 tensor directly. Please use tf.complex(real, imag)."); - if (!Ft(n) && !Array.isArray(n) && typeof n != "number" && typeof n != "boolean" && typeof n != "string") + } + if (!isTypedArray(values) && !Array.isArray(values) && typeof values !== "number" && typeof values !== "boolean" && typeof values !== "string") { throw new Error("values passed to tensor(values) must be a number/boolean/string or an array of numbers/booleans/strings, or a TypedArray"); - if (t != null) { - vc(t); - var i = pt(t), a = pt(e); - E(i === a, function() { - return "Based on the provided shape, [" + t + "], the tensor should have " + (i + " values but has " + a); + } + if (shape != null) { + assertNonNegativeIntegerDimensions(shape); + var providedSize_1 = sizeFromShape(shape); + var inferredSize_1 = sizeFromShape(inferredShape); + assert(providedSize_1 === inferredSize_1, function() { + return "Based on the provided shape, [" + shape + "], the tensor should have " + (providedSize_1 + " values but has " + inferredSize_1); }); - for (var s = 0; s < e.length; ++s) { - var o = e[s], c = s === e.length - 1 ? o !== pt(t.slice(s)) : true; - E(e[s] === t[s] || !c, function() { - return "Error creating a new Tensor. Inferred shape " + ("(" + e + ") does not match the provided ") + ("shape (" + t + "). "); + for (var i = 0; i < inferredShape.length; ++i) { + var inferred = inferredShape[i]; + var flatDimsDontMatch = i === inferredShape.length - 1 ? inferred !== sizeFromShape(shape.slice(i)) : true; + assert(inferredShape[i] === shape[i] || !flatDimsDontMatch, function() { + return "Error creating a new Tensor. Inferred shape " + ("(" + inferredShape + ") does not match the provided ") + ("shape (" + shape + "). "); }); } } - return !Ft(n) && !Array.isArray(n) && (n = [n]), t = t || e, n = r !== "string" ? Es(n, r) : Br(n, [], true), z.makeTensor(n, t, r); + if (!isTypedArray(values) && !Array.isArray(values)) { + values = [values]; + } + shape = shape || inferredShape; + values = dtype !== "string" ? toTypedArray(values, dtype) : flatten(values, [], true); + return ENGINE.makeTensor(values, shape, dtype); } - function hr(n, t, e) { - var r = Rn(n, e); - return ur(n, t, r, e); + /** + * @license + * Copyright 2018 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + function tensor(values, shape, dtype) { + var inferredShape = inferShape(values, dtype); + return makeTensor(values, shape, inferredShape, dtype); } - var xu = {float32: 4, float16: 2, int32: 4, uint16: 2, uint8: 1, bool: 1, complex64: 8}; - var Ws = 4; - function cS(n, t) { - return pe(this, void 0, void 0, function() { - var e, r, i, a, s, o, c = this; - return fe(this, function(l) { - switch (l.label) { + /** + * @license + * Copyright 2018 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + var DTYPE_VALUE_SIZE_MAP = { + float32: 4, + float16: 2, + int32: 4, + uint16: 2, + uint8: 1, + bool: 1, + complex64: 8 + }; + /** + * @license + * Copyright 2018 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + var NUM_BYTES_STRING_LENGTH = 4; + function encodeWeights(tensors, group) { + return __awaiter(this, void 0, void 0, function() { + var specs, dataPromises, names, _loop_1, i, tensorValues; + var _this = this; + return __generator(this, function(_a) { + switch (_a.label) { case 0: - for (e = [], r = [], i = Array.isArray(n) ? n.map(function(u) { - return u.name; - }) : Object.keys(n), a = function(u) { - var h = i[u], d = Array.isArray(n) ? n[u].tensor : n[h]; - if (d.dtype !== "float32" && d.dtype !== "int32" && d.dtype !== "bool" && d.dtype !== "string" && d.dtype !== "complex64") - throw new Error("Unsupported dtype in weight '" + h + "': " + d.dtype); - var p = {name: h, shape: d.shape, dtype: d.dtype}; - if (d.dtype === "string") { - var f = new Promise(function(m) { - return pe(c, void 0, void 0, function() { - var g, y, w, b, L, x, N; - return fe(this, function(I) { - switch (I.label) { + specs = []; + dataPromises = []; + names = Array.isArray(tensors) ? tensors.map(function(tensor2) { + return tensor2.name; + }) : Object.keys(tensors); + _loop_1 = function(i2) { + var name_1 = names[i2]; + var t = Array.isArray(tensors) ? tensors[i2].tensor : tensors[name_1]; + if (t.dtype !== "float32" && t.dtype !== "int32" && t.dtype !== "bool" && t.dtype !== "string" && t.dtype !== "complex64") { + throw new Error("Unsupported dtype in weight '" + name_1 + "': " + t.dtype); + } + var spec = {name: name_1, shape: t.shape, dtype: t.dtype}; + if (t.dtype === "string") { + var utf8bytes = new Promise(function(resolve) { + return __awaiter(_this, void 0, void 0, function() { + var vals, totalNumBytes, bytes, offset, i_1, val, bytesOfLength; + return __generator(this, function(_a2) { + switch (_a2.label) { case 0: - return [4, d.bytes()]; + return [4, t.bytes()]; case 1: - for (g = I.sent(), y = g.reduce(function(C, O) { - return C + O.length; - }, 0) + Ws * g.length, w = new Uint8Array(y), b = 0, L = 0; L < g.length; L++) - x = g[L], N = new Uint8Array(new Uint32Array([x.length]).buffer), w.set(N, b), b += Ws, w.set(x, b), b += x.length; - return m(w), [2]; + vals = _a2.sent(); + totalNumBytes = vals.reduce(function(p, c) { + return p + c.length; + }, 0) + NUM_BYTES_STRING_LENGTH * vals.length; + bytes = new Uint8Array(totalNumBytes); + offset = 0; + for (i_1 = 0; i_1 < vals.length; i_1++) { + val = vals[i_1]; + bytesOfLength = new Uint8Array(new Uint32Array([val.length]).buffer); + bytes.set(bytesOfLength, offset); + offset += NUM_BYTES_STRING_LENGTH; + bytes.set(val, offset); + offset += val.length; + } + resolve(bytes); + return [2]; } }); }); }); - r.push(f); - } else - r.push(d.data()); - t != null && (p.group = t), e.push(p); - }, s = 0; s < i.length; ++s) - a(s); - return [4, Promise.all(r)]; + dataPromises.push(utf8bytes); + } else { + dataPromises.push(t.data()); + } + if (group != null) { + spec.group = group; + } + specs.push(spec); + }; + for (i = 0; i < names.length; ++i) { + _loop_1(i); + } + return [4, Promise.all(dataPromises)]; case 1: - return o = l.sent(), [2, {data: oS(o), specs: e}]; + tensorValues = _a.sent(); + return [2, {data: concatenateTypedArrays(tensorValues), specs}]; } }); }); } - function km(n, t) { - for (var e = {}, r, i = 0, a = 0, s = t; a < s.length; a++) { - var o = s[a], c = o.name, l = o.dtype, u = o.shape, h = pt(u), d = void 0; - if ("quantization" in o) { - var p = o.quantization; - if (p.dtype === "uint8" || p.dtype === "uint16") { - if (!("min" in p && "scale" in p)) - throw new Error("Weight " + o.name + " with quantization " + p.dtype + " doesn't have corresponding metadata min and scale."); - } else if (p.dtype === "float16") { - if (l !== "float32") - throw new Error("Weight " + o.name + " is quantized with " + p.dtype + " " + ("which only supports weights of type float32 not " + l + ".")); - } else - throw new Error("Weight " + o.name + " has unknown " + ("quantization dtype " + p.dtype + ". ") + "Supported quantization dtypes are: 'uint8', 'uint16', and 'float16'."); - var f = xu[p.dtype], m = n.slice(i, i + h * f), g = p.dtype === "uint8" ? new Uint8Array(m) : new Uint16Array(m); - if (l === "float32") - if (p.dtype === "uint8" || p.dtype === "uint16") { - d = new Float32Array(g.length); - for (var y = 0; y < g.length; y++) { - var w = g[y]; - d[y] = w * p.scale + p.min; - } - } else if (p.dtype === "float16") - r === void 0 && (r = lS()), d = r(g); - else - throw new Error("Unsupported quantization type " + p.dtype + " for weight type float32."); - else if (l === "int32") { - if (p.dtype !== "uint8" && p.dtype !== "uint16") - throw new Error("Unsupported quantization type " + p.dtype + " for weight type int32."); - d = new Int32Array(g.length); - for (var y = 0; y < g.length; y++) { - var w = g[y]; - d[y] = Math.round(w * p.scale + p.min); + function decodeWeights(buffer2, specs) { + var out = {}; + var float16Decode; + var offset = 0; + for (var _i2 = 0, specs_1 = specs; _i2 < specs_1.length; _i2++) { + var spec = specs_1[_i2]; + var name_2 = spec.name; + var dtype = spec.dtype; + var shape = spec.shape; + var size = sizeFromShape(shape); + var values = void 0; + if ("quantization" in spec) { + var quantization = spec.quantization; + if (quantization.dtype === "uint8" || quantization.dtype === "uint16") { + if (!("min" in quantization && "scale" in quantization)) { + throw new Error("Weight " + spec.name + " with quantization " + quantization.dtype + " doesn't have corresponding metadata min and scale."); } - } else - throw new Error("Unsupported dtype in weight '" + c + "': " + l); - i += h * f; - } else if (l === "string") { - var b = pt(o.shape); - d = []; - for (var y = 0; y < b; y++) { - var L = new Uint32Array(n.slice(i, i + Ws))[0]; - i += Ws; - var x = new Uint8Array(n.slice(i, i + L)); - d.push(x), i += L; + } else if (quantization.dtype === "float16") { + if (dtype !== "float32") { + throw new Error("Weight " + spec.name + " is quantized with " + quantization.dtype + " " + ("which only supports weights of type float32 not " + dtype + ".")); + } + } else { + throw new Error("Weight " + spec.name + " has unknown " + ("quantization dtype " + quantization.dtype + ". ") + "Supported quantization dtypes are: 'uint8', 'uint16', and 'float16'."); + } + var quantizationSizeFactor = DTYPE_VALUE_SIZE_MAP[quantization.dtype]; + var byteBuffer = buffer2.slice(offset, offset + size * quantizationSizeFactor); + var quantizedArray = quantization.dtype === "uint8" ? new Uint8Array(byteBuffer) : new Uint16Array(byteBuffer); + if (dtype === "float32") { + if (quantization.dtype === "uint8" || quantization.dtype === "uint16") { + values = new Float32Array(quantizedArray.length); + for (var i = 0; i < quantizedArray.length; i++) { + var v = quantizedArray[i]; + values[i] = v * quantization.scale + quantization.min; + } + } else if (quantization.dtype === "float16") { + if (float16Decode === void 0) { + float16Decode = getFloat16Decoder(); + } + values = float16Decode(quantizedArray); + } else { + throw new Error("Unsupported quantization type " + quantization.dtype + " for weight type float32."); + } + } else if (dtype === "int32") { + if (quantization.dtype !== "uint8" && quantization.dtype !== "uint16") { + throw new Error("Unsupported quantization type " + quantization.dtype + " for weight type int32."); + } + values = new Int32Array(quantizedArray.length); + for (var i = 0; i < quantizedArray.length; i++) { + var v = quantizedArray[i]; + values[i] = Math.round(v * quantization.scale + quantization.min); + } + } else { + throw new Error("Unsupported dtype in weight '" + name_2 + "': " + dtype); + } + offset += size * quantizationSizeFactor; + } else if (dtype === "string") { + var size_1 = sizeFromShape(spec.shape); + values = []; + for (var i = 0; i < size_1; i++) { + var byteLength = new Uint32Array(buffer2.slice(offset, offset + NUM_BYTES_STRING_LENGTH))[0]; + offset += NUM_BYTES_STRING_LENGTH; + var bytes = new Uint8Array(buffer2.slice(offset, offset + byteLength)); + values.push(bytes); + offset += byteLength; } } else { - var N = xu[l], m = n.slice(i, i + h * N); - if (l === "float32") - d = new Float32Array(m); - else if (l === "int32") - d = new Int32Array(m); - else if (l === "bool") - d = new Uint8Array(m); - else if (l === "complex64") { - d = new Float32Array(m); - for (var I = new Float32Array(d.length / 2), C = new Float32Array(d.length / 2), y = 0; y < I.length; y++) - I[y] = d[y * 2], C[y] = d[y * 2 + 1]; - var O = hr(I, u, "float32"), D = hr(C, u, "float32"); - e[c] = lr(O, D), O.dispose(), D.dispose(); - } else - throw new Error("Unsupported dtype in weight '" + c + "': " + l); - i += h * N; + var dtypeFactor = DTYPE_VALUE_SIZE_MAP[dtype]; + var byteBuffer = buffer2.slice(offset, offset + size * dtypeFactor); + if (dtype === "float32") { + values = new Float32Array(byteBuffer); + } else if (dtype === "int32") { + values = new Int32Array(byteBuffer); + } else if (dtype === "bool") { + values = new Uint8Array(byteBuffer); + } else if (dtype === "complex64") { + values = new Float32Array(byteBuffer); + var real2 = new Float32Array(values.length / 2); + var image3 = new Float32Array(values.length / 2); + for (var i = 0; i < real2.length; i++) { + real2[i] = values[i * 2]; + image3[i] = values[i * 2 + 1]; + } + var realTensor = tensor(real2, shape, "float32"); + var imageTensor = tensor(image3, shape, "float32"); + out[name_2] = complex(realTensor, imageTensor); + realTensor.dispose(); + imageTensor.dispose(); + } else { + throw new Error("Unsupported dtype in weight '" + name_2 + "': " + dtype); + } + offset += size * dtypeFactor; + } + if (dtype !== "complex64") { + out[name_2] = tensor(values, shape, dtype); } - l !== "complex64" && (e[c] = hr(d, u, l)); } - return e; + return out; } - function oS(n) { - if (n === null) - throw new Error("Invalid input value: " + JSON.stringify(n)); - var t = 0, e = []; - n.forEach(function(a) { - if (t += a.byteLength, e.push(a.byteLength === a.buffer.byteLength ? a : new a.constructor(a)), !(a instanceof Float32Array || a instanceof Int32Array || a instanceof Uint8Array)) - throw new Error("Unsupported TypedArray subtype: " + a.constructor.name); - }); - var r = new Uint8Array(t), i = 0; - return e.forEach(function(a) { - r.set(new Uint8Array(a.buffer), i), i += a.byteLength; - }), r.buffer; - } - var Lu = typeof Buffer != "undefined" && (typeof Blob == "undefined" || typeof atob == "undefined" || typeof btoa == "undefined"); - function Fm(n) { - return Lu ? Buffer.byteLength(n) : new Blob([n]).size; - } - function uS(n) { - if (Lu) - return Buffer.from(n).toString("base64"); - for (var t = new Uint8Array(n), e = "", r = 0, i = t.length; r < i; r++) - e += String.fromCharCode(t[r]); - return btoa(e); - } - function hS(n) { - if (Lu) { - var t = Buffer.from(n, "base64"); - return t.buffer.slice(t.byteOffset, t.byteOffset + t.byteLength); + function concatenateTypedArrays(xs) { + if (xs === null) { + throw new Error("Invalid input value: " + JSON.stringify(xs)); } - for (var e = atob(n), r = new Uint8Array(e.length), i = 0; i < e.length; ++i) - r.set([e.charCodeAt(i)], i); - return r.buffer; - } - function Su(n) { - if (n.length === 1) - return n[0]; - var t = 0; - n.forEach(function(i) { - t += i.byteLength; + var totalByteLength = 0; + var normalizedXs = []; + xs.forEach(function(x) { + totalByteLength += x.byteLength; + normalizedXs.push(x.byteLength === x.buffer.byteLength ? x : new x.constructor(x)); + if (!(x instanceof Float32Array || x instanceof Int32Array || x instanceof Uint8Array)) { + throw new Error("Unsupported TypedArray subtype: " + x.constructor.name); + } }); - var e = new Uint8Array(t), r = 0; - return n.forEach(function(i) { - e.set(new Uint8Array(i), r), r += i.byteLength; - }), e.buffer; + var y = new Uint8Array(totalByteLength); + var offset = 0; + normalizedXs.forEach(function(x) { + y.set(new Uint8Array(x.buffer), offset); + offset += x.byteLength; + }); + return y.buffer; } - function Wm(n) { - var t = "/"; - for (n = n.trim(); n.endsWith(t); ) - n = n.slice(0, n.length - 1); - var e = n.split(t); - return e[e.length - 1]; + var useNodeBuffer = typeof Buffer !== "undefined" && (typeof Blob === "undefined" || typeof atob === "undefined" || typeof btoa === "undefined"); + function stringByteLength(str2) { + if (useNodeBuffer) { + return Buffer.byteLength(str2); + } + return new Blob([str2]).size; } - function Ia(n) { - if (n.modelTopology instanceof ArrayBuffer) + function arrayBufferToBase64String(buffer2) { + if (useNodeBuffer) { + return Buffer.from(buffer2).toString("base64"); + } + var buf = new Uint8Array(buffer2); + var s = ""; + for (var i = 0, l = buf.length; i < l; i++) { + s += String.fromCharCode(buf[i]); + } + return btoa(s); + } + function base64StringToArrayBuffer(str2) { + if (useNodeBuffer) { + var buf = Buffer.from(str2, "base64"); + return buf.buffer.slice(buf.byteOffset, buf.byteOffset + buf.byteLength); + } + var s = atob(str2); + var buffer2 = new Uint8Array(s.length); + for (var i = 0; i < s.length; ++i) { + buffer2.set([s.charCodeAt(i)], i); + } + return buffer2.buffer; + } + function concatenateArrayBuffers(buffers) { + if (buffers.length === 1) { + return buffers[0]; + } + var totalByteLength = 0; + buffers.forEach(function(buffer2) { + totalByteLength += buffer2.byteLength; + }); + var temp = new Uint8Array(totalByteLength); + var offset = 0; + buffers.forEach(function(buffer2) { + temp.set(new Uint8Array(buffer2), offset); + offset += buffer2.byteLength; + }); + return temp.buffer; + } + function basename(path) { + var SEPARATOR = "/"; + path = path.trim(); + while (path.endsWith(SEPARATOR)) { + path = path.slice(0, path.length - 1); + } + var items = path.split(SEPARATOR); + return items[items.length - 1]; + } + function getModelArtifactsInfoForJSON(modelArtifacts) { + if (modelArtifacts.modelTopology instanceof ArrayBuffer) { throw new Error("Expected JSON model topology, received ArrayBuffer."); - return {dateSaved: new Date(), modelTopologyType: "JSON", modelTopologyBytes: n.modelTopology == null ? 0 : Fm(JSON.stringify(n.modelTopology)), weightSpecsBytes: n.weightSpecs == null ? 0 : Fm(JSON.stringify(n.weightSpecs)), weightDataBytes: n.weightData == null ? 0 : n.weightData.byteLength}; - } - function dS() { - var n = function(r) { - for (var i = r << 13, a = 0; (i & 8388608) === 0; ) - a -= 8388608, i <<= 1; - return i &= ~8388608, a += 947912704, i | a; - }, t = new Uint32Array(2048); - t[0] = 0; - for (var e = 1; e < 1024; e++) - t[e] = n(e); - for (var e = 1024; e < 2048; e++) - t[e] = 939524096 + (e - 1024 << 13); - return t; - } - function pS() { - var n = new Uint32Array(64); - n[0] = 0, n[31] = 1199570944, n[32] = 2147483648, n[63] = 3347054592; - for (var t = 1; t < 31; t++) - n[t] = t << 23; - for (var t = 33; t < 63; t++) - n[t] = 2147483648 + (t - 32 << 23); - return n; - } - function fS() { - for (var n = new Uint32Array(64), t = 0; t < 64; t++) - n[t] = 1024; - return n[0] = n[32] = 0, n; - } - function lS() { - var n = dS(), t = pS(), e = fS(); - return function(r) { - for (var i = new ArrayBuffer(4 * r.length), a = new Uint32Array(i), s = 0; s < r.length; s++) { - var o = r[s], c = n[e[o >> 10] + (o & 1023)] + t[o >> 10]; - a[s] = c; - } - return new Float32Array(i); + } + return { + dateSaved: new Date(), + modelTopologyType: "JSON", + modelTopologyBytes: modelArtifacts.modelTopology == null ? 0 : stringByteLength(JSON.stringify(modelArtifacts.modelTopology)), + weightSpecsBytes: modelArtifacts.weightSpecs == null ? 0 : stringByteLength(JSON.stringify(modelArtifacts.weightSpecs)), + weightDataBytes: modelArtifacts.weightData == null ? 0 : modelArtifacts.weightData.byteLength }; } - var rn = function() { - function n() { - this.saveRouters = [], this.loadRouters = []; + function computeFloat16MantisaTable() { + var convertMantissa = function(i2) { + var m = i2 << 13; + var e = 0; + while ((m & 8388608) === 0) { + e -= 8388608; + m <<= 1; + } + m &= ~8388608; + e += 947912704; + return m | e; + }; + var mantisaTable = new Uint32Array(2048); + mantisaTable[0] = 0; + for (var i = 1; i < 1024; i++) { + mantisaTable[i] = convertMantissa(i); } - return n.getInstance = function() { - return n.instance == null && (n.instance = new n()), n.instance; - }, n.registerSaveRouter = function(t) { - n.getInstance().saveRouters.push(t); - }, n.registerLoadRouter = function(t) { - n.getInstance().loadRouters.push(t); - }, n.getSaveHandlers = function(t) { - return n.getHandlers(t, "save"); - }, n.getLoadHandlers = function(t, e) { - return n.getHandlers(t, "load", e); - }, n.getHandlers = function(t, e, r) { - var i = [], a = e === "load" ? n.getInstance().loadRouters : n.getInstance().saveRouters; - return a.forEach(function(s) { - var o = s(t, r); - o !== null && i.push(o); - }), i; - }, n; - }(), mS = function(n) { - return rn.registerSaveRouter(n); - }, gS = function(n) { - return rn.registerLoadRouter(n); - }, yS = function(n) { - return rn.getSaveHandlers(n); - }, vS = function(n, t) { - return rn.getLoadHandlers(n, t); + for (var i = 1024; i < 2048; i++) { + mantisaTable[i] = 939524096 + (i - 1024 << 13); + } + return mantisaTable; + } + function computeFloat16ExponentTable() { + var exponentTable = new Uint32Array(64); + exponentTable[0] = 0; + exponentTable[31] = 1199570944; + exponentTable[32] = 2147483648; + exponentTable[63] = 3347054592; + for (var i = 1; i < 31; i++) { + exponentTable[i] = i << 23; + } + for (var i = 33; i < 63; i++) { + exponentTable[i] = 2147483648 + (i - 32 << 23); + } + return exponentTable; + } + function computeFloat16OffsetTable() { + var offsetTable = new Uint32Array(64); + for (var i = 0; i < 64; i++) { + offsetTable[i] = 1024; + } + offsetTable[0] = offsetTable[32] = 0; + return offsetTable; + } + function getFloat16Decoder() { + var mantisaTable = computeFloat16MantisaTable(); + var exponentTable = computeFloat16ExponentTable(); + var offsetTable = computeFloat16OffsetTable(); + return function(quantizedArray) { + var buffer2 = new ArrayBuffer(4 * quantizedArray.length); + var bufferUint32View = new Uint32Array(buffer2); + for (var index = 0; index < quantizedArray.length; index++) { + var float16Bits = quantizedArray[index]; + var float32Bits = mantisaTable[offsetTable[float16Bits >> 10] + (float16Bits & 1023)] + exponentTable[float16Bits >> 10]; + bufferUint32View[index] = float32Bits; + } + return new Float32Array(buffer2); + }; + } + /** + * @license + * Copyright 2018 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + var IORouterRegistry = function() { + function IORouterRegistry2() { + this.saveRouters = []; + this.loadRouters = []; + } + IORouterRegistry2.getInstance = function() { + if (IORouterRegistry2.instance == null) { + IORouterRegistry2.instance = new IORouterRegistry2(); + } + return IORouterRegistry2.instance; + }; + IORouterRegistry2.registerSaveRouter = function(saveRouter) { + IORouterRegistry2.getInstance().saveRouters.push(saveRouter); + }; + IORouterRegistry2.registerLoadRouter = function(loadRouter) { + IORouterRegistry2.getInstance().loadRouters.push(loadRouter); + }; + IORouterRegistry2.getSaveHandlers = function(url) { + return IORouterRegistry2.getHandlers(url, "save"); + }; + IORouterRegistry2.getLoadHandlers = function(url, loadOptions) { + return IORouterRegistry2.getHandlers(url, "load", loadOptions); + }; + IORouterRegistry2.getHandlers = function(url, handlerType, loadOptions) { + var validHandlers = []; + var routers = handlerType === "load" ? IORouterRegistry2.getInstance().loadRouters : IORouterRegistry2.getInstance().saveRouters; + routers.forEach(function(router) { + var handler = router(url, loadOptions); + if (handler !== null) { + validHandlers.push(handler); + } + }); + return validHandlers; + }; + return IORouterRegistry2; + }(); + var registerSaveRouter = function(loudRouter) { + return IORouterRegistry.registerSaveRouter(loudRouter); }; - var Iu = "tensorflowjs", Au = 1, zr = "models_store", dr = "model_info_store"; - function Um() { - if (!Ge().getBool("IS_BROWSER")) + var registerLoadRouter = function(loudRouter) { + return IORouterRegistry.registerLoadRouter(loudRouter); + }; + var getSaveHandlers = function(url) { + return IORouterRegistry.getSaveHandlers(url); + }; + var getLoadHandlers = function(url, loadOptions) { + return IORouterRegistry.getLoadHandlers(url, loadOptions); + }; + /** + * @license + * Copyright 2018 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + var DATABASE_NAME = "tensorflowjs"; + var DATABASE_VERSION = 1; + var MODEL_STORE_NAME = "models_store"; + var INFO_STORE_NAME = "model_info_store"; + function getIndexedDBFactory() { + if (!env().getBool("IS_BROWSER")) { throw new Error("Failed to obtain IndexedDB factory because the current environmentis not a web browser."); - var n = typeof window == "undefined" ? self : window, t = n.indexedDB || n.mozIndexedDB || n.webkitIndexedDB || n.msIndexedDB || n.shimIndexedDB; - if (t == null) - throw new Error("The current browser does not appear to support IndexedDB."); - return t; - } - function Tu(n) { - var t = n.result; - t.createObjectStore(zr, {keyPath: "modelPath"}), t.createObjectStore(dr, {keyPath: "modelPath"}); - } - var Ai = function() { - function n(t) { - if (this.indexedDB = Um(), t == null || !t) - throw new Error("For IndexedDB, modelPath must not be null, undefined or empty."); - this.modelPath = t; } - return n.prototype.save = function(t) { - return pe(this, void 0, void 0, function() { - return fe(this, function(e) { - if (t.modelTopology instanceof ArrayBuffer) + var theWindow = typeof window === "undefined" ? self : window; + var factory = theWindow.indexedDB || theWindow.mozIndexedDB || theWindow.webkitIndexedDB || theWindow.msIndexedDB || theWindow.shimIndexedDB; + if (factory == null) { + throw new Error("The current browser does not appear to support IndexedDB."); + } + return factory; + } + function setUpDatabase(openRequest) { + var db = openRequest.result; + db.createObjectStore(MODEL_STORE_NAME, {keyPath: "modelPath"}); + db.createObjectStore(INFO_STORE_NAME, {keyPath: "modelPath"}); + } + var BrowserIndexedDB = function() { + function BrowserIndexedDB2(modelPath) { + this.indexedDB = getIndexedDBFactory(); + if (modelPath == null || !modelPath) { + throw new Error("For IndexedDB, modelPath must not be null, undefined or empty."); + } + this.modelPath = modelPath; + } + BrowserIndexedDB2.prototype.save = function(modelArtifacts) { + return __awaiter(this, void 0, void 0, function() { + return __generator(this, function(_a) { + if (modelArtifacts.modelTopology instanceof ArrayBuffer) { throw new Error("BrowserLocalStorage.save() does not support saving model topology in binary formats yet."); - return [2, this.databaseAction(this.modelPath, t)]; + } + return [2, this.databaseAction(this.modelPath, modelArtifacts)]; }); }); - }, n.prototype.load = function() { - return pe(this, void 0, void 0, function() { - return fe(this, function(t) { + }; + BrowserIndexedDB2.prototype.load = function() { + return __awaiter(this, void 0, void 0, function() { + return __generator(this, function(_a) { return [2, this.databaseAction(this.modelPath)]; }); }); - }, n.prototype.databaseAction = function(t, e) { - var r = this; - return new Promise(function(i, a) { - var s = r.indexedDB.open(Iu, Au); - s.onupgradeneeded = function() { - return Tu(s); - }, s.onsuccess = function() { - var o = s.result; - if (e == null) { - var c = o.transaction(zr, "readonly"), l = c.objectStore(zr), u = l.get(r.modelPath); - u.onsuccess = function() { - if (u.result == null) - return o.close(), a(new Error("Cannot find model with path '" + r.modelPath + "' in IndexedDB.")); - i(u.result.modelArtifacts); - }, u.onerror = function(g) { - return o.close(), a(u.error); - }, c.oncomplete = function() { - return o.close(); + }; + BrowserIndexedDB2.prototype.databaseAction = function(modelPath, modelArtifacts) { + var _this = this; + return new Promise(function(resolve, reject) { + var openRequest = _this.indexedDB.open(DATABASE_NAME, DATABASE_VERSION); + openRequest.onupgradeneeded = function() { + return setUpDatabase(openRequest); + }; + openRequest.onsuccess = function() { + var db = openRequest.result; + if (modelArtifacts == null) { + var modelTx = db.transaction(MODEL_STORE_NAME, "readonly"); + var modelStore = modelTx.objectStore(MODEL_STORE_NAME); + var getRequest_1 = modelStore.get(_this.modelPath); + getRequest_1.onsuccess = function() { + if (getRequest_1.result == null) { + db.close(); + return reject(new Error("Cannot find model with path '" + _this.modelPath + "' in IndexedDB.")); + } else { + resolve(getRequest_1.result.modelArtifacts); + } + }; + getRequest_1.onerror = function(error) { + db.close(); + return reject(getRequest_1.error); + }; + modelTx.oncomplete = function() { + return db.close(); }; } else { - var h = Ia(e), d = o.transaction(dr, "readwrite"), p = d.objectStore(dr), f = p.put({modelPath: r.modelPath, modelArtifactsInfo: h}), m; - f.onsuccess = function() { - m = o.transaction(zr, "readwrite"); - var g = m.objectStore(zr), y = g.put({modelPath: r.modelPath, modelArtifacts: e, modelArtifactsInfo: h}); - y.onsuccess = function() { - return i({modelArtifactsInfo: h}); - }, y.onerror = function(w) { - p = d.objectStore(dr); - var b = p.delete(r.modelPath); - b.onsuccess = function() { - return o.close(), a(y.error); - }, b.onerror = function(L) { - return o.close(), a(y.error); + var modelArtifactsInfo_1 = getModelArtifactsInfoForJSON(modelArtifacts); + var infoTx_1 = db.transaction(INFO_STORE_NAME, "readwrite"); + var infoStore_1 = infoTx_1.objectStore(INFO_STORE_NAME); + var putInfoRequest_1 = infoStore_1.put({modelPath: _this.modelPath, modelArtifactsInfo: modelArtifactsInfo_1}); + var modelTx_1; + putInfoRequest_1.onsuccess = function() { + modelTx_1 = db.transaction(MODEL_STORE_NAME, "readwrite"); + var modelStore2 = modelTx_1.objectStore(MODEL_STORE_NAME); + var putModelRequest = modelStore2.put({ + modelPath: _this.modelPath, + modelArtifacts, + modelArtifactsInfo: modelArtifactsInfo_1 + }); + putModelRequest.onsuccess = function() { + return resolve({modelArtifactsInfo: modelArtifactsInfo_1}); + }; + putModelRequest.onerror = function(error) { + infoStore_1 = infoTx_1.objectStore(INFO_STORE_NAME); + var deleteInfoRequest = infoStore_1.delete(_this.modelPath); + deleteInfoRequest.onsuccess = function() { + db.close(); + return reject(putModelRequest.error); + }; + deleteInfoRequest.onerror = function(error2) { + db.close(); + return reject(putModelRequest.error); }; }; - }, f.onerror = function(g) { - return o.close(), a(f.error); - }, d.oncomplete = function() { - m == null ? o.close() : m.oncomplete = function() { - return o.close(); - }; + }; + putInfoRequest_1.onerror = function(error) { + db.close(); + return reject(putInfoRequest_1.error); + }; + infoTx_1.oncomplete = function() { + if (modelTx_1 == null) { + db.close(); + } else { + modelTx_1.oncomplete = function() { + return db.close(); + }; + } }; } - }, s.onerror = function(o) { - return a(s.error); + }; + openRequest.onerror = function(error) { + return reject(openRequest.error); }; }); - }, n.URL_SCHEME = "indexeddb://", n; - }(), Bm = function(n) { - return Ge().getBool("IS_BROWSER") && (!Array.isArray(n) && n.startsWith(Ai.URL_SCHEME)) ? wS(n.slice(Ai.URL_SCHEME.length)) : null; - }; - rn.registerSaveRouter(Bm); - rn.registerLoadRouter(Bm); - function wS(n) { - return new Ai(n); - } - function bS(n) { - return n.startsWith(Ai.URL_SCHEME) ? n.slice(Ai.URL_SCHEME.length) : n; - } - var xS = function() { - function n() { - this.indexedDB = Um(); - } - return n.prototype.listModels = function() { - return pe(this, void 0, void 0, function() { - var t = this; - return fe(this, function(e) { - return [2, new Promise(function(r, i) { - var a = t.indexedDB.open(Iu, Au); - a.onupgradeneeded = function() { - return Tu(a); - }, a.onsuccess = function() { - var s = a.result, o = s.transaction(dr, "readonly"), c = o.objectStore(dr), l = c.getAll(); - l.onsuccess = function() { - for (var u = {}, h = 0, d = l.result; h < d.length; h++) { - var p = d[h]; - u[p.modelPath] = p.modelArtifactsInfo; - } - r(u); - }, l.onerror = function(u) { - return s.close(), i(l.error); - }, o.oncomplete = function() { - return s.close(); - }; - }, a.onerror = function(s) { - return i(a.error); - }; - })]; - }); - }); - }, n.prototype.removeModel = function(t) { - return pe(this, void 0, void 0, function() { - var e = this; - return fe(this, function(r) { - return t = bS(t), [2, new Promise(function(i, a) { - var s = e.indexedDB.open(Iu, Au); - s.onupgradeneeded = function() { - return Tu(s); - }, s.onsuccess = function() { - var o = s.result, c = o.transaction(dr, "readwrite"), l = c.objectStore(dr), u = l.get(t), h; - u.onsuccess = function() { - if (u.result == null) - return o.close(), a(new Error("Cannot find model with path '" + t + "' in IndexedDB.")); - var d = l.delete(t), p = function() { - h = o.transaction(zr, "readwrite"); - var f = h.objectStore(zr), m = f.delete(t); - m.onsuccess = function() { - return i(u.result.modelArtifactsInfo); - }, m.onerror = function(g) { - return a(u.error); - }; - }; - d.onsuccess = p, d.onerror = function(f) { - return p(), o.close(), a(u.error); - }; - }, u.onerror = function(d) { - return o.close(), a(u.error); - }, c.oncomplete = function() { - h == null ? o.close() : h.oncomplete = function() { - return o.close(); - }; - }; - }, s.onerror = function(o) { - return a(s.error); - }; - })]; - }); - }); - }, n; + }; + BrowserIndexedDB2.URL_SCHEME = "indexeddb://"; + return BrowserIndexedDB2; }(); - var Kn = "/", Ti = "tensorflowjs_models", zm = "info", LS = "model_topology", SS = "weight_specs", IS = "weight_data", AS = "model_metadata"; - function Pm(n) { - return {info: [Ti, n, zm].join(Kn), topology: [Ti, n, LS].join(Kn), weightSpecs: [Ti, n, SS].join(Kn), weightData: [Ti, n, IS].join(Kn), modelMetadata: [Ti, n, AS].join(Kn)}; - } - function TS(n) { - var t = n.split(Kn); - if (t.length < 3) - throw new Error("Invalid key format: " + n); - return t.slice(1, t.length - 1).join(Kn); - } - function NS(n) { - return n.startsWith(Ni.URL_SCHEME) ? n.slice(Ni.URL_SCHEME.length) : n; - } - var Ni = function() { - function n(t) { - if (!Ge().getBool("IS_BROWSER") || typeof window == "undefined" || typeof window.localStorage == "undefined") - throw new Error("The current environment does not support local storage."); - if (this.LS = window.localStorage, t == null || !t) - throw new Error("For local storage, modelPath must not be null, undefined or empty."); - this.modelPath = t, this.keys = Pm(this.modelPath); + var indexedDBRouter = function(url) { + if (!env().getBool("IS_BROWSER")) { + return null; + } else { + if (!Array.isArray(url) && url.startsWith(BrowserIndexedDB.URL_SCHEME)) { + return browserIndexedDB(url.slice(BrowserIndexedDB.URL_SCHEME.length)); + } else { + return null; + } } - return n.prototype.save = function(t) { - return pe(this, void 0, void 0, function() { - var e, r, i; - return fe(this, function(a) { - if (t.modelTopology instanceof ArrayBuffer) + }; + IORouterRegistry.registerSaveRouter(indexedDBRouter); + IORouterRegistry.registerLoadRouter(indexedDBRouter); + function browserIndexedDB(modelPath) { + return new BrowserIndexedDB(modelPath); + } + function maybeStripScheme(key) { + return key.startsWith(BrowserIndexedDB.URL_SCHEME) ? key.slice(BrowserIndexedDB.URL_SCHEME.length) : key; + } + var BrowserIndexedDBManager = function() { + function BrowserIndexedDBManager2() { + this.indexedDB = getIndexedDBFactory(); + } + BrowserIndexedDBManager2.prototype.listModels = function() { + return __awaiter(this, void 0, void 0, function() { + var _this = this; + return __generator(this, function(_a) { + return [2, new Promise(function(resolve, reject) { + var openRequest = _this.indexedDB.open(DATABASE_NAME, DATABASE_VERSION); + openRequest.onupgradeneeded = function() { + return setUpDatabase(openRequest); + }; + openRequest.onsuccess = function() { + var db = openRequest.result; + var tx = db.transaction(INFO_STORE_NAME, "readonly"); + var store = tx.objectStore(INFO_STORE_NAME); + var getAllInfoRequest = store.getAll(); + getAllInfoRequest.onsuccess = function() { + var out = {}; + for (var _i2 = 0, _a2 = getAllInfoRequest.result; _i2 < _a2.length; _i2++) { + var item = _a2[_i2]; + out[item.modelPath] = item.modelArtifactsInfo; + } + resolve(out); + }; + getAllInfoRequest.onerror = function(error) { + db.close(); + return reject(getAllInfoRequest.error); + }; + tx.oncomplete = function() { + return db.close(); + }; + }; + openRequest.onerror = function(error) { + return reject(openRequest.error); + }; + })]; + }); + }); + }; + BrowserIndexedDBManager2.prototype.removeModel = function(path) { + return __awaiter(this, void 0, void 0, function() { + var _this = this; + return __generator(this, function(_a) { + path = maybeStripScheme(path); + return [2, new Promise(function(resolve, reject) { + var openRequest = _this.indexedDB.open(DATABASE_NAME, DATABASE_VERSION); + openRequest.onupgradeneeded = function() { + return setUpDatabase(openRequest); + }; + openRequest.onsuccess = function() { + var db = openRequest.result; + var infoTx = db.transaction(INFO_STORE_NAME, "readwrite"); + var infoStore = infoTx.objectStore(INFO_STORE_NAME); + var getInfoRequest = infoStore.get(path); + var modelTx; + getInfoRequest.onsuccess = function() { + if (getInfoRequest.result == null) { + db.close(); + return reject(new Error("Cannot find model with path '" + path + "' in IndexedDB.")); + } else { + var deleteInfoRequest = infoStore.delete(path); + var deleteModelData_1 = function() { + modelTx = db.transaction(MODEL_STORE_NAME, "readwrite"); + var modelStore = modelTx.objectStore(MODEL_STORE_NAME); + var deleteModelRequest = modelStore.delete(path); + deleteModelRequest.onsuccess = function() { + return resolve(getInfoRequest.result.modelArtifactsInfo); + }; + deleteModelRequest.onerror = function(error) { + return reject(getInfoRequest.error); + }; + }; + deleteInfoRequest.onsuccess = deleteModelData_1; + deleteInfoRequest.onerror = function(error) { + deleteModelData_1(); + db.close(); + return reject(getInfoRequest.error); + }; + } + }; + getInfoRequest.onerror = function(error) { + db.close(); + return reject(getInfoRequest.error); + }; + infoTx.oncomplete = function() { + if (modelTx == null) { + db.close(); + } else { + modelTx.oncomplete = function() { + return db.close(); + }; + } + }; + }; + openRequest.onerror = function(error) { + return reject(openRequest.error); + }; + })]; + }); + }); + }; + return BrowserIndexedDBManager2; + }(); + /** + * @license + * Copyright 2018 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + var PATH_SEPARATOR = "/"; + var PATH_PREFIX = "tensorflowjs_models"; + var INFO_SUFFIX = "info"; + var MODEL_TOPOLOGY_SUFFIX = "model_topology"; + var WEIGHT_SPECS_SUFFIX = "weight_specs"; + var WEIGHT_DATA_SUFFIX = "weight_data"; + var MODEL_METADATA_SUFFIX = "model_metadata"; + function getModelKeys(path) { + return { + info: [PATH_PREFIX, path, INFO_SUFFIX].join(PATH_SEPARATOR), + topology: [PATH_PREFIX, path, MODEL_TOPOLOGY_SUFFIX].join(PATH_SEPARATOR), + weightSpecs: [PATH_PREFIX, path, WEIGHT_SPECS_SUFFIX].join(PATH_SEPARATOR), + weightData: [PATH_PREFIX, path, WEIGHT_DATA_SUFFIX].join(PATH_SEPARATOR), + modelMetadata: [PATH_PREFIX, path, MODEL_METADATA_SUFFIX].join(PATH_SEPARATOR) + }; + } + function getModelPathFromKey(key) { + var items = key.split(PATH_SEPARATOR); + if (items.length < 3) { + throw new Error("Invalid key format: " + key); + } + return items.slice(1, items.length - 1).join(PATH_SEPARATOR); + } + function maybeStripScheme$1(key) { + return key.startsWith(BrowserLocalStorage.URL_SCHEME) ? key.slice(BrowserLocalStorage.URL_SCHEME.length) : key; + } + var BrowserLocalStorage = function() { + function BrowserLocalStorage2(modelPath) { + if (!env().getBool("IS_BROWSER") || typeof window === "undefined" || typeof window.localStorage === "undefined") { + throw new Error("The current environment does not support local storage."); + } + this.LS = window.localStorage; + if (modelPath == null || !modelPath) { + throw new Error("For local storage, modelPath must not be null, undefined or empty."); + } + this.modelPath = modelPath; + this.keys = getModelKeys(this.modelPath); + } + BrowserLocalStorage2.prototype.save = function(modelArtifacts) { + return __awaiter(this, void 0, void 0, function() { + var topology, weightSpecs, modelArtifactsInfo; + return __generator(this, function(_a) { + if (modelArtifacts.modelTopology instanceof ArrayBuffer) { throw new Error("BrowserLocalStorage.save() does not support saving model topology in binary formats yet."); - e = JSON.stringify(t.modelTopology), r = JSON.stringify(t.weightSpecs), i = Ia(t); - try { - return this.LS.setItem(this.keys.info, JSON.stringify(i)), this.LS.setItem(this.keys.topology, e), this.LS.setItem(this.keys.weightSpecs, r), this.LS.setItem(this.keys.weightData, uS(t.weightData)), this.LS.setItem(this.keys.modelMetadata, JSON.stringify({format: t.format, generatedBy: t.generatedBy, convertedBy: t.convertedBy, userDefinedMetadata: t.userDefinedMetadata})), [2, {modelArtifactsInfo: i}]; - } catch (s) { - throw this.LS.removeItem(this.keys.info), this.LS.removeItem(this.keys.topology), this.LS.removeItem(this.keys.weightSpecs), this.LS.removeItem(this.keys.weightData), this.LS.removeItem(this.keys.modelMetadata), new Error("Failed to save model '" + this.modelPath + "' to local storage: size quota being exceeded is a possible cause of this failure: " + ("modelTopologyBytes=" + i.modelTopologyBytes + ", ") + ("weightSpecsBytes=" + i.weightSpecsBytes + ", ") + ("weightDataBytes=" + i.weightDataBytes + ".")); + } else { + topology = JSON.stringify(modelArtifacts.modelTopology); + weightSpecs = JSON.stringify(modelArtifacts.weightSpecs); + modelArtifactsInfo = getModelArtifactsInfoForJSON(modelArtifacts); + try { + this.LS.setItem(this.keys.info, JSON.stringify(modelArtifactsInfo)); + this.LS.setItem(this.keys.topology, topology); + this.LS.setItem(this.keys.weightSpecs, weightSpecs); + this.LS.setItem(this.keys.weightData, arrayBufferToBase64String(modelArtifacts.weightData)); + this.LS.setItem(this.keys.modelMetadata, JSON.stringify({ + format: modelArtifacts.format, + generatedBy: modelArtifacts.generatedBy, + convertedBy: modelArtifacts.convertedBy, + userDefinedMetadata: modelArtifacts.userDefinedMetadata + })); + return [2, {modelArtifactsInfo}]; + } catch (err) { + this.LS.removeItem(this.keys.info); + this.LS.removeItem(this.keys.topology); + this.LS.removeItem(this.keys.weightSpecs); + this.LS.removeItem(this.keys.weightData); + this.LS.removeItem(this.keys.modelMetadata); + throw new Error("Failed to save model '" + this.modelPath + "' to local storage: size quota being exceeded is a possible cause of this failure: " + ("modelTopologyBytes=" + modelArtifactsInfo.modelTopologyBytes + ", ") + ("weightSpecsBytes=" + modelArtifactsInfo.weightSpecsBytes + ", ") + ("weightDataBytes=" + modelArtifactsInfo.weightDataBytes + ".")); + } } return [2]; }); }); - }, n.prototype.load = function() { - return pe(this, void 0, void 0, function() { - var t, e, r, i, a, s, o; - return fe(this, function(c) { - if (t = JSON.parse(this.LS.getItem(this.keys.info)), t == null) + }; + BrowserLocalStorage2.prototype.load = function() { + return __awaiter(this, void 0, void 0, function() { + var info, out, topology, weightSpecs, metadataString, metadata, weightDataBase64; + return __generator(this, function(_a) { + info = JSON.parse(this.LS.getItem(this.keys.info)); + if (info == null) { throw new Error("In local storage, there is no model with name '" + this.modelPath + "'"); - if (t.modelTopologyType !== "JSON") + } + if (info.modelTopologyType !== "JSON") { throw new Error("BrowserLocalStorage does not support loading non-JSON model topology yet."); - if (e = {}, r = JSON.parse(this.LS.getItem(this.keys.topology)), r == null) + } + out = {}; + topology = JSON.parse(this.LS.getItem(this.keys.topology)); + if (topology == null) { throw new Error("In local storage, the topology of model '" + this.modelPath + "' is missing."); - if (e.modelTopology = r, i = JSON.parse(this.LS.getItem(this.keys.weightSpecs)), i == null) + } + out.modelTopology = topology; + weightSpecs = JSON.parse(this.LS.getItem(this.keys.weightSpecs)); + if (weightSpecs == null) { throw new Error("In local storage, the weight specs of model '" + this.modelPath + "' are missing."); - if (e.weightSpecs = i, a = this.LS.getItem(this.keys.modelMetadata), a != null && (s = JSON.parse(a), e.format = s.format, e.generatedBy = s.generatedBy, e.convertedBy = s.convertedBy, e.userDefinedMetadata = s.userDefinedMetadata), o = this.LS.getItem(this.keys.weightData), o == null) + } + out.weightSpecs = weightSpecs; + metadataString = this.LS.getItem(this.keys.modelMetadata); + if (metadataString != null) { + metadata = JSON.parse(metadataString); + out.format = metadata["format"]; + out.generatedBy = metadata["generatedBy"]; + out.convertedBy = metadata["convertedBy"]; + out.userDefinedMetadata = metadata["userDefinedMetadata"]; + } + weightDataBase64 = this.LS.getItem(this.keys.weightData); + if (weightDataBase64 == null) { throw new Error("In local storage, the binary weight values of model " + ("'" + this.modelPath + "' are missing.")); - return e.weightData = hS(o), [2, e]; + } + out.weightData = base64StringToArrayBuffer(weightDataBase64); + return [2, out]; }); }); - }, n.URL_SCHEME = "localstorage://", n; - }(), Mm = function(n) { - return Ge().getBool("IS_BROWSER") && (!Array.isArray(n) && n.startsWith(Ni.URL_SCHEME)) ? _S(n.slice(Ni.URL_SCHEME.length)) : null; - }; - rn.registerSaveRouter(Mm); - rn.registerLoadRouter(Mm); - function _S(n) { - return new Ni(n); - } - var CS = function() { - function n() { - E(Ge().getBool("IS_BROWSER"), function() { - return "Current environment is not a web browser"; - }), E(typeof window == "undefined" || typeof window.localStorage != "undefined", function() { - return "Current browser does not appear to support localStorage"; - }), this.LS = window.localStorage; - } - return n.prototype.listModels = function() { - return pe(this, void 0, void 0, function() { - var t, e, r, i, a, s; - return fe(this, function(o) { - for (t = {}, e = Ti + Kn, r = Kn + zm, i = 0; i < this.LS.length; ++i) - a = this.LS.key(i), a.startsWith(e) && a.endsWith(r) && (s = TS(a), t[s] = JSON.parse(this.LS.getItem(a))); - return [2, t]; - }); - }); - }, n.prototype.removeModel = function(t) { - return pe(this, void 0, void 0, function() { - var e, r; - return fe(this, function(i) { - if (t = NS(t), e = Pm(t), this.LS.getItem(e.info) == null) - throw new Error("Cannot find model at path '" + t + "'"); - return r = JSON.parse(this.LS.getItem(e.info)), this.LS.removeItem(e.info), this.LS.removeItem(e.topology), this.LS.removeItem(e.weightSpecs), this.LS.removeItem(e.weightData), [2, r]; - }); - }); - }, n; + }; + BrowserLocalStorage2.URL_SCHEME = "localstorage://"; + return BrowserLocalStorage2; }(); - var _i = "://", pr = function() { - function n() { + var localStorageRouter = function(url) { + if (!env().getBool("IS_BROWSER")) { + return null; + } else { + if (!Array.isArray(url) && url.startsWith(BrowserLocalStorage.URL_SCHEME)) { + return browserLocalStorage(url.slice(BrowserLocalStorage.URL_SCHEME.length)); + } else { + return null; + } + } + }; + IORouterRegistry.registerSaveRouter(localStorageRouter); + IORouterRegistry.registerLoadRouter(localStorageRouter); + function browserLocalStorage(modelPath) { + return new BrowserLocalStorage(modelPath); + } + var BrowserLocalStorageManager = function() { + function BrowserLocalStorageManager2() { + assert(env().getBool("IS_BROWSER"), function() { + return "Current environment is not a web browser"; + }); + assert(typeof window === "undefined" || typeof window.localStorage !== "undefined", function() { + return "Current browser does not appear to support localStorage"; + }); + this.LS = window.localStorage; + } + BrowserLocalStorageManager2.prototype.listModels = function() { + return __awaiter(this, void 0, void 0, function() { + var out, prefix, suffix, i, key, modelPath; + return __generator(this, function(_a) { + out = {}; + prefix = PATH_PREFIX + PATH_SEPARATOR; + suffix = PATH_SEPARATOR + INFO_SUFFIX; + for (i = 0; i < this.LS.length; ++i) { + key = this.LS.key(i); + if (key.startsWith(prefix) && key.endsWith(suffix)) { + modelPath = getModelPathFromKey(key); + out[modelPath] = JSON.parse(this.LS.getItem(key)); + } + } + return [2, out]; + }); + }); + }; + BrowserLocalStorageManager2.prototype.removeModel = function(path) { + return __awaiter(this, void 0, void 0, function() { + var keys, info; + return __generator(this, function(_a) { + path = maybeStripScheme$1(path); + keys = getModelKeys(path); + if (this.LS.getItem(keys.info) == null) { + throw new Error("Cannot find model at path '" + path + "'"); + } + info = JSON.parse(this.LS.getItem(keys.info)); + this.LS.removeItem(keys.info); + this.LS.removeItem(keys.topology); + this.LS.removeItem(keys.weightSpecs); + this.LS.removeItem(keys.weightData); + return [2, info]; + }); + }); + }; + return BrowserLocalStorageManager2; + }(); + /** + * @license + * Copyright 2018 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + var URL_SCHEME_SUFFIX = "://"; + var ModelStoreManagerRegistry = function() { + function ModelStoreManagerRegistry2() { this.managers = {}; } - return n.getInstance = function() { - return n.instance == null && (n.instance = new n()), n.instance; - }, n.registerManager = function(t, e) { - E(t != null, function() { + ModelStoreManagerRegistry2.getInstance = function() { + if (ModelStoreManagerRegistry2.instance == null) { + ModelStoreManagerRegistry2.instance = new ModelStoreManagerRegistry2(); + } + return ModelStoreManagerRegistry2.instance; + }; + ModelStoreManagerRegistry2.registerManager = function(scheme, manager) { + assert(scheme != null, function() { return "scheme must not be undefined or null."; - }), t.endsWith(_i) && (t = t.slice(0, t.indexOf(_i))), E(t.length > 0, function() { + }); + if (scheme.endsWith(URL_SCHEME_SUFFIX)) { + scheme = scheme.slice(0, scheme.indexOf(URL_SCHEME_SUFFIX)); + } + assert(scheme.length > 0, function() { return "scheme must not be an empty string."; }); - var r = n.getInstance(); - E(r.managers[t] == null, function() { - return "A model store manager is already registered for scheme '" + t + "'."; - }), r.managers[t] = e; - }, n.getManager = function(t) { - var e = this.getInstance().managers[t]; - if (e == null) - throw new Error("Cannot find model manager for scheme '" + t + "'"); - return e; - }, n.getSchemes = function() { + var registry = ModelStoreManagerRegistry2.getInstance(); + assert(registry.managers[scheme] == null, function() { + return "A model store manager is already registered for scheme '" + scheme + "'."; + }); + registry.managers[scheme] = manager; + }; + ModelStoreManagerRegistry2.getManager = function(scheme) { + var manager = this.getInstance().managers[scheme]; + if (manager == null) { + throw new Error("Cannot find model manager for scheme '" + scheme + "'"); + } + return manager; + }; + ModelStoreManagerRegistry2.getSchemes = function() { return Object.keys(this.getInstance().managers); - }, n; + }; + return ModelStoreManagerRegistry2; }(); - function Us(n) { - if (n.indexOf(_i) === -1) - throw new Error("The url string provided does not contain a scheme. Supported schemes are: " + ("" + pr.getSchemes().join(","))); - return {scheme: n.split(_i)[0], path: n.split(_i)[1]}; + function parseURL(url) { + if (url.indexOf(URL_SCHEME_SUFFIX) === -1) { + throw new Error("The url string provided does not contain a scheme. Supported schemes are: " + ("" + ModelStoreManagerRegistry.getSchemes().join(","))); + } + return { + scheme: url.split(URL_SCHEME_SUFFIX)[0], + path: url.split(URL_SCHEME_SUFFIX)[1] + }; } - function Hm(n, t, e) { - return e === void 0 && (e = false), pe(this, void 0, void 0, function() { - var r, i, a, s, o, c, l, u, h; - return fe(this, function(d) { - switch (d.label) { + function cloneModelInternal(sourceURL, destURL, deleteSource) { + if (deleteSource === void 0) { + deleteSource = false; + } + return __awaiter(this, void 0, void 0, function() { + var loadHandlers, loadHandler, saveHandlers, saveHandler, sourceScheme, sourcePath, sameMedium, modelArtifacts, saveResult; + return __generator(this, function(_a) { + switch (_a.label) { case 0: - return E(n !== t, function() { - return "Old path and new path are the same: '" + n + "'"; - }), r = rn.getLoadHandlers(n), E(r.length > 0, function() { - return "Copying failed because no load handler is found for source URL " + n + "."; - }), E(r.length < 2, function() { - return "Copying failed because more than one (" + r.length + ") " + ("load handlers for source URL " + n + "."); - }), i = r[0], a = rn.getSaveHandlers(t), E(a.length > 0, function() { - return "Copying failed because no save handler is found for destination " + ("URL " + t + "."); - }), E(a.length < 2, function() { - return "Copying failed because more than one (" + r.length + ") " + ("save handlers for destination URL " + t + "."); - }), s = a[0], o = Us(n).scheme, c = Us(n).path, l = o === Us(n).scheme, [4, i.load()]; + assert(sourceURL !== destURL, function() { + return "Old path and new path are the same: '" + sourceURL + "'"; + }); + loadHandlers = IORouterRegistry.getLoadHandlers(sourceURL); + assert(loadHandlers.length > 0, function() { + return "Copying failed because no load handler is found for source URL " + sourceURL + "."; + }); + assert(loadHandlers.length < 2, function() { + return "Copying failed because more than one (" + loadHandlers.length + ") " + ("load handlers for source URL " + sourceURL + "."); + }); + loadHandler = loadHandlers[0]; + saveHandlers = IORouterRegistry.getSaveHandlers(destURL); + assert(saveHandlers.length > 0, function() { + return "Copying failed because no save handler is found for destination " + ("URL " + destURL + "."); + }); + assert(saveHandlers.length < 2, function() { + return "Copying failed because more than one (" + loadHandlers.length + ") " + ("save handlers for destination URL " + destURL + "."); + }); + saveHandler = saveHandlers[0]; + sourceScheme = parseURL(sourceURL).scheme; + sourcePath = parseURL(sourceURL).path; + sameMedium = sourceScheme === parseURL(sourceURL).scheme; + return [4, loadHandler.load()]; case 1: - return u = d.sent(), e && l ? [4, pr.getManager(o).removeModel(c)] : [3, 3]; + modelArtifacts = _a.sent(); + if (!(deleteSource && sameMedium)) + return [3, 3]; + return [4, ModelStoreManagerRegistry.getManager(sourceScheme).removeModel(sourcePath)]; case 2: - d.sent(), d.label = 3; + _a.sent(); + _a.label = 3; case 3: - return [4, s.save(u)]; + return [4, saveHandler.save(modelArtifacts)]; case 4: - return h = d.sent(), e && !l ? [4, pr.getManager(o).removeModel(c)] : [3, 6]; + saveResult = _a.sent(); + if (!(deleteSource && !sameMedium)) + return [3, 6]; + return [4, ModelStoreManagerRegistry.getManager(sourceScheme).removeModel(sourcePath)]; case 5: - d.sent(), d.label = 6; + _a.sent(); + _a.label = 6; case 6: - return [2, h.modelArtifactsInfo]; + return [2, saveResult.modelArtifactsInfo]; } }); }); } - function RS() { - return pe(this, void 0, void 0, function() { - var n, t, e, r, i, a, s, o; - return fe(this, function(c) { - switch (c.label) { + function listModels() { + return __awaiter(this, void 0, void 0, function() { + var schemes, out, _i2, schemes_1, scheme, schemeOut, path, url; + return __generator(this, function(_a) { + switch (_a.label) { case 0: - n = pr.getSchemes(), t = {}, e = 0, r = n, c.label = 1; + schemes = ModelStoreManagerRegistry.getSchemes(); + out = {}; + _i2 = 0, schemes_1 = schemes; + _a.label = 1; case 1: - return e < r.length ? (i = r[e], [4, pr.getManager(i).listModels()]) : [3, 4]; + if (!(_i2 < schemes_1.length)) + return [3, 4]; + scheme = schemes_1[_i2]; + return [4, ModelStoreManagerRegistry.getManager(scheme).listModels()]; case 2: - a = c.sent(); - for (s in a) - o = i + _i + s, t[o] = a[s]; - c.label = 3; + schemeOut = _a.sent(); + for (path in schemeOut) { + url = scheme + URL_SCHEME_SUFFIX + path; + out[url] = schemeOut[path]; + } + _a.label = 3; case 3: - return e++, [3, 1]; + _i2++; + return [3, 1]; case 4: - return [2, t]; + return [2, out]; } }); }); } - function OS(n) { - return pe(this, void 0, void 0, function() { - var t, e; - return fe(this, function(r) { - return t = Us(n), e = pr.getManager(t.scheme), [2, e.removeModel(t.path)]; + function removeModel(url) { + return __awaiter(this, void 0, void 0, function() { + var schemeAndPath, manager; + return __generator(this, function(_a) { + schemeAndPath = parseURL(url); + manager = ModelStoreManagerRegistry.getManager(schemeAndPath.scheme); + return [2, manager.removeModel(schemeAndPath.path)]; }); }); } - function ES(n, t) { - return pe(this, void 0, void 0, function() { - var e; - return fe(this, function(r) { - return e = false, [2, Hm(n, t, e)]; + function copyModel(sourceURL, destURL) { + return __awaiter(this, void 0, void 0, function() { + var deleteSource; + return __generator(this, function(_a) { + deleteSource = false; + return [2, cloneModelInternal(sourceURL, destURL, deleteSource)]; }); }); } - function DS(n, t) { - return pe(this, void 0, void 0, function() { - var e; - return fe(this, function(r) { - return e = true, [2, Hm(n, t, e)]; + function moveModel(sourceURL, destURL) { + return __awaiter(this, void 0, void 0, function() { + var deleteSource; + return __generator(this, function(_a) { + deleteSource = true; + return [2, cloneModelInternal(sourceURL, destURL, deleteSource)]; }); }); } - var kS = function() { - function n() { + /** + * @license + * Copyright 2019 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + var PlatformBrowser = function() { + function PlatformBrowser2() { } - return n.prototype.fetch = function(t, e) { - return fetch(t, e); - }, n.prototype.now = function() { + PlatformBrowser2.prototype.fetch = function(path, init) { + return fetch(path, init); + }; + PlatformBrowser2.prototype.now = function() { return performance.now(); - }, n.prototype.encode = function(t, e) { - if (e !== "utf-8" && e !== "utf8") - throw new Error("Browser's encoder only supports utf-8, but got " + e); - return this.textEncoder == null && (this.textEncoder = new TextEncoder()), this.textEncoder.encode(t); - }, n.prototype.decode = function(t, e) { - return new TextDecoder(e).decode(t); - }, n; + }; + PlatformBrowser2.prototype.encode = function(text, encoding) { + if (encoding !== "utf-8" && encoding !== "utf8") { + throw new Error("Browser's encoder only supports utf-8, but got " + encoding); + } + if (this.textEncoder == null) { + this.textEncoder = new TextEncoder(); + } + return this.textEncoder.encode(text); + }; + PlatformBrowser2.prototype.decode = function(bytes, encoding) { + return new TextDecoder(encoding).decode(bytes); + }; + return PlatformBrowser2; }(); - if (Ge().get("IS_BROWSER")) { - Ge().setPlatform("browser", new kS()); + if (env().get("IS_BROWSER")) { + env().setPlatform("browser", new PlatformBrowser()); try { - pr.registerManager(Ni.URL_SCHEME, new CS()); - } catch (n) { + ModelStoreManagerRegistry.registerManager(BrowserLocalStorage.URL_SCHEME, new BrowserLocalStorageManager()); + } catch (err) { } try { - pr.registerManager(Ai.URL_SCHEME, new xS()); - } catch (n) { + ModelStoreManagerRegistry.registerManager(BrowserIndexedDB.URL_SCHEME, new BrowserIndexedDBManager()); + } catch (err) { } } - var FS = {importFetch: function() { - return sf(); - }}, Nu, WS = function() { - function n() { - this.util = of(), this.textEncoder = new this.util.TextEncoder(); + /** + * @license + * Copyright 2019 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + var getNodeFetch = { + importFetch: function() { + return require_browser(); } - return n.prototype.fetch = function(t, e) { - return Ge().global.fetch != null ? Ge().global.fetch(t, e) : (Nu == null && (Nu = FS.importFetch()), Nu(t, e)); - }, n.prototype.now = function() { - var t = process.hrtime(); - return t[0] * 1e3 + t[1] / 1e6; - }, n.prototype.encode = function(t, e) { - if (e !== "utf-8" && e !== "utf8") - throw new Error("Node built-in encoder only supports utf-8, but got " + e); - return this.textEncoder.encode(t); - }, n.prototype.decode = function(t, e) { - return t.length === 0 ? "" : new this.util.TextDecoder(e).decode(t); - }, n; + }; + var systemFetch; + var PlatformNode = function() { + function PlatformNode2() { + this.util = require_util(); + this.textEncoder = new this.util.TextEncoder(); + } + PlatformNode2.prototype.fetch = function(path, requestInits) { + if (env().global.fetch != null) { + return env().global.fetch(path, requestInits); + } + if (systemFetch == null) { + systemFetch = getNodeFetch.importFetch(); + } + return systemFetch(path, requestInits); + }; + PlatformNode2.prototype.now = function() { + var time2 = process.hrtime(); + return time2[0] * 1e3 + time2[1] / 1e6; + }; + PlatformNode2.prototype.encode = function(text, encoding) { + if (encoding !== "utf-8" && encoding !== "utf8") { + throw new Error("Node built-in encoder only supports utf-8, but got " + encoding); + } + return this.textEncoder.encode(text); + }; + PlatformNode2.prototype.decode = function(bytes, encoding) { + if (bytes.length === 0) { + return ""; + } + return new this.util.TextDecoder(encoding).decode(bytes); + }; + return PlatformNode2; }(); - Ge().get("IS_NODE") && Ge().setPlatform("node", new WS()); - function On(n, t, e) { - return t === void 0 && (t = "float32"), t = t || "float32", vc(n), new ks(n, t, e); + if (env().get("IS_NODE")) { + env().setPlatform("node", new PlatformNode()); } - function US(n, t) { - var e = R(n, "x", "cast"); - if (!ff(t)) - throw new Error("Failed to cast to unknown dtype " + t); - if (t === "string" && e.dtype !== "string" || t !== "string" && e.dtype === "string") + /** + * @license + * Copyright 2020 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + function buffer(shape, dtype, values) { + if (dtype === void 0) { + dtype = "float32"; + } + dtype = dtype || "float32"; + assertNonNegativeIntegerDimensions(shape); + return new TensorBuffer(shape, dtype, values); + } + /** + * @license + * Copyright 2020 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + function cast_(x, dtype) { + var $x = convertToTensor(x, "x", "cast"); + if (!isValidDtype(dtype)) { + throw new Error("Failed to cast to unknown dtype " + dtype); + } + if (dtype === "string" && $x.dtype !== "string" || dtype !== "string" && $x.dtype === "string") { throw new Error("Only strings can be casted to strings"); - var r = {x: e}, i = {dtype: t}; - return z.runKernelFunc(function(a) { - return a.cast(e, t); - }, r, null, Rs, i); - } - var he = U({cast_: US}); - function BS(n) { - var t = R(n, "x", "clone", null), e = function() { - return z.makeTensorFromDataId(t.dataId, t.shape, t.dtype); - }, r = {x: t}; - return z.runKernelFunc(e, r, null, il); - } - var Pr = U({clone_: BS}); - function Vm(n, t) { - t === void 0 && (t = false), console.log(n.toString(t)); - } - Cm(); - var zS = {buffer: On, cast: he, clone: Pr, print: Vm}; - jL(zS); - var PS = "model", MS = ".json", HS = ".weights.bin"; - function Gm(n) { - return new Promise(function(t) { - return setTimeout(t); - }).then(n); - } - var _u = function() { - function n(t) { - if (!Ge().getBool("IS_BROWSER")) - throw new Error("browserDownloads() cannot proceed because the current environment is not a browser."); - t.startsWith(n.URL_SCHEME) && (t = t.slice(n.URL_SCHEME.length)), (t == null || t.length === 0) && (t = PS), this.modelTopologyFileName = t + MS, this.weightDataFileName = t + HS; } - return n.prototype.save = function(t) { - return pe(this, void 0, void 0, function() { - var e, r, i, a, s, o; - return fe(this, function(c) { - switch (c.label) { + var inputs = {x: $x}; + var attrs = {dtype}; + return ENGINE.runKernelFunc(function(backend2) { + return backend2.cast($x, dtype); + }, inputs, null, Cast, attrs); + } + var cast = op({cast_}); + /** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + function clone_(x) { + var $x = convertToTensor(x, "x", "clone", null); + var forward = function() { + return ENGINE.makeTensorFromDataId($x.dataId, $x.shape, $x.dtype); + }; + var inputs = {x: $x}; + return ENGINE.runKernelFunc(forward, inputs, null, Identity); + } + var clone = op({clone_}); + /** + * @license + * Copyright 2020 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + function print(x, verbose) { + if (verbose === void 0) { + verbose = false; + } + console.log(x.toString(verbose)); + } + /** + * @license + * Copyright 2020 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + getOrMakeEngine(); + var opHandler$1 = { + buffer, + cast, + clone, + print + }; + setOpHandler(opHandler$1); + /** + * @license + * Copyright 2018 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + var DEFAULT_FILE_NAME_PREFIX = "model"; + var DEFAULT_JSON_EXTENSION_NAME = ".json"; + var DEFAULT_WEIGHT_DATA_EXTENSION_NAME = ".weights.bin"; + function defer(f) { + return new Promise(function(resolve) { + return setTimeout(resolve); + }).then(f); + } + var BrowserDownloads = function() { + function BrowserDownloads2(fileNamePrefix) { + if (!env().getBool("IS_BROWSER")) { + throw new Error("browserDownloads() cannot proceed because the current environment is not a browser."); + } + if (fileNamePrefix.startsWith(BrowserDownloads2.URL_SCHEME)) { + fileNamePrefix = fileNamePrefix.slice(BrowserDownloads2.URL_SCHEME.length); + } + if (fileNamePrefix == null || fileNamePrefix.length === 0) { + fileNamePrefix = DEFAULT_FILE_NAME_PREFIX; + } + this.modelTopologyFileName = fileNamePrefix + DEFAULT_JSON_EXTENSION_NAME; + this.weightDataFileName = fileNamePrefix + DEFAULT_WEIGHT_DATA_EXTENSION_NAME; + } + BrowserDownloads2.prototype.save = function(modelArtifacts) { + return __awaiter(this, void 0, void 0, function() { + var weightsURL, weightsManifest, modelTopologyAndWeightManifest, modelTopologyAndWeightManifestURL, jsonAnchor_1, weightDataAnchor_1; + return __generator(this, function(_a) { + switch (_a.label) { case 0: - if (typeof document == "undefined") + if (typeof document === "undefined") { throw new Error("Browser downloads are not supported in this environment since `document` is not present"); - if (e = window.URL.createObjectURL(new Blob([t.weightData], {type: "application/octet-stream"})), !(t.modelTopology instanceof ArrayBuffer)) + } + weightsURL = window.URL.createObjectURL(new Blob([modelArtifacts.weightData], {type: "application/octet-stream"})); + if (!(modelArtifacts.modelTopology instanceof ArrayBuffer)) return [3, 1]; throw new Error("BrowserDownloads.save() does not support saving model topology in binary formats yet."); case 1: - return r = [{paths: ["./" + this.weightDataFileName], weights: t.weightSpecs}], i = {modelTopology: t.modelTopology, format: t.format, generatedBy: t.generatedBy, convertedBy: t.convertedBy, weightsManifest: r}, a = window.URL.createObjectURL(new Blob([JSON.stringify(i)], {type: "application/json"})), s = this.jsonAnchor == null ? document.createElement("a") : this.jsonAnchor, s.download = this.modelTopologyFileName, s.href = a, [4, Gm(function() { - return s.dispatchEvent(new MouseEvent("click")); + weightsManifest = [{ + paths: ["./" + this.weightDataFileName], + weights: modelArtifacts.weightSpecs + }]; + modelTopologyAndWeightManifest = { + modelTopology: modelArtifacts.modelTopology, + format: modelArtifacts.format, + generatedBy: modelArtifacts.generatedBy, + convertedBy: modelArtifacts.convertedBy, + weightsManifest + }; + modelTopologyAndWeightManifestURL = window.URL.createObjectURL(new Blob([JSON.stringify(modelTopologyAndWeightManifest)], {type: "application/json"})); + jsonAnchor_1 = this.jsonAnchor == null ? document.createElement("a") : this.jsonAnchor; + jsonAnchor_1.download = this.modelTopologyFileName; + jsonAnchor_1.href = modelTopologyAndWeightManifestURL; + return [4, defer(function() { + return jsonAnchor_1.dispatchEvent(new MouseEvent("click")); })]; case 2: - return c.sent(), t.weightData != null ? (o = this.weightDataAnchor == null ? document.createElement("a") : this.weightDataAnchor, o.download = this.weightDataFileName, o.href = e, [4, Gm(function() { - return o.dispatchEvent(new MouseEvent("click")); - })]) : [3, 4]; + _a.sent(); + if (!(modelArtifacts.weightData != null)) + return [3, 4]; + weightDataAnchor_1 = this.weightDataAnchor == null ? document.createElement("a") : this.weightDataAnchor; + weightDataAnchor_1.download = this.weightDataFileName; + weightDataAnchor_1.href = weightsURL; + return [4, defer(function() { + return weightDataAnchor_1.dispatchEvent(new MouseEvent("click")); + })]; case 3: - c.sent(), c.label = 4; + _a.sent(); + _a.label = 4; case 4: - return [2, {modelArtifactsInfo: Ia(t)}]; + return [2, {modelArtifactsInfo: getModelArtifactsInfoForJSON(modelArtifacts)}]; } }); }); - }, n.URL_SCHEME = "downloads://", n; - }(), VS = function() { - function n(t) { - if (t == null || t.length < 1) - throw new Error("When calling browserFiles, at least 1 file is required, " + ("but received " + t)); - this.files = t; + }; + BrowserDownloads2.URL_SCHEME = "downloads://"; + return BrowserDownloads2; + }(); + var BrowserFiles = function() { + function BrowserFiles2(files) { + if (files == null || files.length < 1) { + throw new Error("When calling browserFiles, at least 1 file is required, " + ("but received " + files)); + } + this.files = files; } - return n.prototype.load = function() { - return pe(this, void 0, void 0, function() { - var t, e, r = this; - return fe(this, function(i) { - return t = this.files[0], e = this.files.slice(1), [2, new Promise(function(a, s) { - var o = new FileReader(); - o.onload = function(c) { - var l = JSON.parse(c.target.result), u = l.modelTopology; - if (u == null) { - s(new Error("modelTopology field is missing from file " + t.name)); + BrowserFiles2.prototype.load = function() { + return __awaiter(this, void 0, void 0, function() { + var jsonFile, weightFiles; + var _this = this; + return __generator(this, function(_a) { + jsonFile = this.files[0]; + weightFiles = this.files.slice(1); + return [2, new Promise(function(resolve, reject) { + var jsonReader = new FileReader(); + jsonReader.onload = function(event) { + var modelJSON = JSON.parse(event.target.result); + var modelTopology = modelJSON.modelTopology; + if (modelTopology == null) { + reject(new Error("modelTopology field is missing from file " + jsonFile.name)); return; } - e.length === 0 && a({modelTopology: u}); - var h = l.weightsManifest; - if (h == null) { - s(new Error("weightManifest field is missing from file " + t.name)); + if (weightFiles.length === 0) { + resolve({modelTopology}); + } + var weightsManifest = modelJSON.weightsManifest; + if (weightsManifest == null) { + reject(new Error("weightManifest field is missing from file " + jsonFile.name)); return; } - var d; + var pathToFile; try { - d = r.checkManifestAndWeightFiles(h, e); - } catch (g) { - s(g); + pathToFile = _this.checkManifestAndWeightFiles(weightsManifest, weightFiles); + } catch (err) { + reject(err); return; } - var p = [], f = [], m = []; - h.forEach(function(g) { - g.paths.forEach(function(y) { - f.push(y), m.push(null); - }), p.push.apply(p, g.weights); - }), h.forEach(function(g) { - g.paths.forEach(function(y) { - var w = new FileReader(); - w.onload = function(b) { - var L = b.target.result, x = f.indexOf(y); - m[x] = L, m.indexOf(null) === -1 && a({modelTopology: u, weightSpecs: p, weightData: Su(m), format: l.format, generatedBy: l.generatedBy, convertedBy: l.convertedBy, userDefinedMetadata: l.userDefinedMetadata}); - }, w.onerror = function(b) { - return s("Failed to weights data from file of path '" + y + "'."); - }, w.readAsArrayBuffer(d[y]); + var weightSpecs = []; + var paths = []; + var perFileBuffers = []; + weightsManifest.forEach(function(weightsGroup) { + weightsGroup.paths.forEach(function(path) { + paths.push(path); + perFileBuffers.push(null); + }); + weightSpecs.push.apply(weightSpecs, weightsGroup.weights); + }); + weightsManifest.forEach(function(weightsGroup) { + weightsGroup.paths.forEach(function(path) { + var weightFileReader = new FileReader(); + weightFileReader.onload = function(event2) { + var weightData = event2.target.result; + var index = paths.indexOf(path); + perFileBuffers[index] = weightData; + if (perFileBuffers.indexOf(null) === -1) { + resolve({ + modelTopology, + weightSpecs, + weightData: concatenateArrayBuffers(perFileBuffers), + format: modelJSON.format, + generatedBy: modelJSON.generatedBy, + convertedBy: modelJSON.convertedBy, + userDefinedMetadata: modelJSON.userDefinedMetadata + }); + } + }; + weightFileReader.onerror = function(error) { + return reject("Failed to weights data from file of path '" + path + "'."); + }; + weightFileReader.readAsArrayBuffer(pathToFile[path]); }); }); - }, o.onerror = function(c) { - return s("Failed to read model topology and weights manifest JSON " + ("from file '" + t.name + "'. BrowserFiles supports loading ") + "Keras-style tf.Model artifacts only."); - }, o.readAsText(t); + }; + jsonReader.onerror = function(error) { + return reject("Failed to read model topology and weights manifest JSON " + ("from file '" + jsonFile.name + "'. BrowserFiles supports loading ") + "Keras-style tf.Model artifacts only."); + }; + jsonReader.readAsText(jsonFile); })]; }); }); - }, n.prototype.checkManifestAndWeightFiles = function(t, e) { - for (var r = [], i = e.map(function(l) { - return Wm(l.name); - }), a = {}, s = 0, o = t; s < o.length; s++) { - var c = o[s]; - c.paths.forEach(function(l) { - var u = Wm(l); - if (r.indexOf(u) !== -1) - throw new Error("Duplicate file basename found in weights manifest: " + ("'" + u + "'")); - if (r.push(u), i.indexOf(u) === -1) - throw new Error("Weight file with basename '" + u + "' is not provided."); - a[l] = e[i.indexOf(u)]; + }; + BrowserFiles2.prototype.checkManifestAndWeightFiles = function(manifest, files) { + var basenames = []; + var fileNames = files.map(function(file) { + return basename(file.name); + }); + var pathToFile = {}; + for (var _i2 = 0, manifest_1 = manifest; _i2 < manifest_1.length; _i2++) { + var group = manifest_1[_i2]; + group.paths.forEach(function(path) { + var pathBasename = basename(path); + if (basenames.indexOf(pathBasename) !== -1) { + throw new Error("Duplicate file basename found in weights manifest: " + ("'" + pathBasename + "'")); + } + basenames.push(pathBasename); + if (fileNames.indexOf(pathBasename) === -1) { + throw new Error("Weight file with basename '" + pathBasename + "' is not provided."); + } else { + pathToFile[path] = files[fileNames.indexOf(pathBasename)]; + } }); } - if (r.length !== e.length) - throw new Error("Mismatch in the number of files in weights manifest " + ("(" + r.length + ") and the number of weight files provided ") + ("(" + e.length + ").")); - return a; - }, n; - }(), qS = function(n) { - return Ge().getBool("IS_BROWSER") && (!Array.isArray(n) && n.startsWith(_u.URL_SCHEME)) ? GS(n.slice(_u.URL_SCHEME.length)) : null; - }; - rn.registerSaveRouter(qS); - function GS(n) { - return n === void 0 && (n = "model"), new _u(n); - } - function YS(n) { - return new VS(n); - } - function qm(n, t, e, r) { - s(n), e = e == null ? 0 : e, r = r == null ? 1 : r, o(e, r); - var i = 0, a = function(c) { - return c.then(function(l) { - var u = e + ++i / n.length * (r - e); - return t(u), l; - }), c; + if (basenames.length !== files.length) { + throw new Error("Mismatch in the number of files in weights manifest " + ("(" + basenames.length + ") and the number of weight files provided ") + ("(" + files.length + ").")); + } + return pathToFile; }; - function s(c) { - E(c != null && Array.isArray(c) && c.length > 0, function() { + return BrowserFiles2; + }(); + var browserDownloadsRouter = function(url) { + if (!env().getBool("IS_BROWSER")) { + return null; + } else { + if (!Array.isArray(url) && url.startsWith(BrowserDownloads.URL_SCHEME)) { + return browserDownloads(url.slice(BrowserDownloads.URL_SCHEME.length)); + } else { + return null; + } + } + }; + IORouterRegistry.registerSaveRouter(browserDownloadsRouter); + function browserDownloads(fileNamePrefix) { + if (fileNamePrefix === void 0) { + fileNamePrefix = "model"; + } + return new BrowserDownloads(fileNamePrefix); + } + function browserFiles(files) { + return new BrowserFiles(files); + } + /** + * @license + * Copyright 2019 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + function monitorPromisesProgress(promises, onProgress, startFraction, endFraction) { + checkPromises(promises); + startFraction = startFraction == null ? 0 : startFraction; + endFraction = endFraction == null ? 1 : endFraction; + checkFraction(startFraction, endFraction); + var resolvedPromise = 0; + var registerMonitor = function(promise) { + promise.then(function(value) { + var fraction = startFraction + ++resolvedPromise / promises.length * (endFraction - startFraction); + onProgress(fraction); + return value; + }); + return promise; + }; + function checkPromises(promises2) { + assert(promises2 != null && Array.isArray(promises2) && promises2.length > 0, function() { return "promises must be a none empty array"; }); } - function o(c, l) { - E(c >= 0 && c <= 1, function() { - return "Progress fraction must be in range [0, 1], but " + ("got startFraction " + c); - }), E(l >= 0 && l <= 1, function() { - return "Progress fraction must be in range [0, 1], but " + ("got endFraction " + l); - }), E(l >= c, function() { - return "startFraction must be no more than endFraction, but " + ("got startFraction " + c + " and endFraction ") + ("" + l); + function checkFraction(startFraction2, endFraction2) { + assert(startFraction2 >= 0 && startFraction2 <= 1, function() { + return "Progress fraction must be in range [0, 1], but " + ("got startFraction " + startFraction2); + }); + assert(endFraction2 >= 0 && endFraction2 <= 1, function() { + return "Progress fraction must be in range [0, 1], but " + ("got endFraction " + endFraction2); + }); + assert(endFraction2 >= startFraction2, function() { + return "startFraction must be no more than endFraction, but " + ("got startFraction " + startFraction2 + " and endFraction ") + ("" + endFraction2); }); } - return Promise.all(n.map(a)); + return Promise.all(promises.map(registerMonitor)); } - function Ym(n, t) { - return pe(this, void 0, void 0, function() { - var e, r, i, a, s, o, c, l, u, h, d; - return fe(this, function(p) { - switch (p.label) { + /** + * @license + * Copyright 2018 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + function loadWeightsAsArrayBuffer(fetchURLs, loadOptions) { + return __awaiter(this, void 0, void 0, function() { + var fetchFunc, requests, fetchStartFraction, fetchEndFraction, responses, _a, bufferPromises, bufferStartFraction, bufferEndFraction, buffers, _b; + return __generator(this, function(_c) { + switch (_c.label) { case 0: - return t == null && (t = {}), e = t.fetchFunc == null ? Ge().platform.fetch : t.fetchFunc, r = n.map(function(f) { - return e(f, t.requestInit, {isBinary: true}); - }), i = 0, a = 0.5, t.onProgress == null ? [4, Promise.all(r)] : [3, 2]; + if (loadOptions == null) { + loadOptions = {}; + } + fetchFunc = loadOptions.fetchFunc == null ? env().platform.fetch : loadOptions.fetchFunc; + requests = fetchURLs.map(function(fetchURL) { + return fetchFunc(fetchURL, loadOptions.requestInit, {isBinary: true}); + }); + fetchStartFraction = 0; + fetchEndFraction = 0.5; + if (!(loadOptions.onProgress == null)) + return [3, 2]; + return [4, Promise.all(requests)]; case 1: - return o = p.sent(), [3, 4]; + _a = _c.sent(); + return [3, 4]; case 2: - return [4, qm(r, t.onProgress, i, a)]; + return [4, monitorPromisesProgress(requests, loadOptions.onProgress, fetchStartFraction, fetchEndFraction)]; case 3: - o = p.sent(), p.label = 4; + _a = _c.sent(); + _c.label = 4; case 4: - return s = o, c = s.map(function(f) { - return f.arrayBuffer(); - }), l = 0.5, u = 1, t.onProgress == null ? [4, Promise.all(c)] : [3, 6]; + responses = _a; + bufferPromises = responses.map(function(response) { + return response.arrayBuffer(); + }); + bufferStartFraction = 0.5; + bufferEndFraction = 1; + if (!(loadOptions.onProgress == null)) + return [3, 6]; + return [4, Promise.all(bufferPromises)]; case 5: - return d = p.sent(), [3, 8]; + _b = _c.sent(); + return [3, 8]; case 6: - return [4, qm(c, t.onProgress, l, u)]; + return [4, monitorPromisesProgress(bufferPromises, loadOptions.onProgress, bufferStartFraction, bufferEndFraction)]; case 7: - d = p.sent(), p.label = 8; + _b = _c.sent(); + _c.label = 8; case 8: - return h = d, [2, h]; + buffers = _b; + return [2, buffers]; } }); }); } - function KS(n, t, e, r) { - return t === void 0 && (t = ""), pe(this, void 0, void 0, function() { - var i, a; - return fe(this, function(s) { - return i = function(o) { - return Ym(o, {requestInit: r}); - }, a = Km(i), [2, a(n, t, e)]; + function loadWeights(manifest, filePathPrefix, weightNames, requestInit) { + if (filePathPrefix === void 0) { + filePathPrefix = ""; + } + return __awaiter(this, void 0, void 0, function() { + var fetchWeights, loadWeights2; + return __generator(this, function(_a) { + fetchWeights = function(fetchUrls) { + return loadWeightsAsArrayBuffer(fetchUrls, {requestInit}); + }; + loadWeights2 = weightsLoaderFactory(fetchWeights); + return [2, loadWeights2(manifest, filePathPrefix, weightNames)]; }); }); } - function Km(n) { - var t = this; - return function(e, r, i) { - return r === void 0 && (r = ""), pe(t, void 0, void 0, function() { - var a, s, o, c, l, u, h, d, p, f; - return fe(this, function(m) { - switch (m.label) { + function weightsLoaderFactory(fetchWeightsFunction) { + var _this = this; + return function(manifest, filePathPrefix, weightNames) { + if (filePathPrefix === void 0) { + filePathPrefix = ""; + } + return __awaiter(_this, void 0, void 0, function() { + var groupIndicesToFetchMap, groupWeightsToFetch, weightsFound, allManifestWeightNames, weightsNotFound, groupIndicesToFetch, fetchUrls, buffers, weightsTensorMap, bufferIndexOffset; + return __generator(this, function(_a) { + switch (_a.label) { case 0: - if (a = e.map(function() { + groupIndicesToFetchMap = manifest.map(function() { return false; - }), s = {}, o = i != null ? i.map(function() { + }); + groupWeightsToFetch = {}; + weightsFound = weightNames != null ? weightNames.map(function() { return false; - }) : [], c = [], e.forEach(function(g, y) { - var w = 0; - g.weights.forEach(function(b) { - var L = "quantization" in b ? b.quantization.dtype : b.dtype, x = xu[L] * pt(b.shape), N = function() { - a[y] = true, s[y] == null && (s[y] = []), s[y].push({manifestEntry: b, groupOffset: w, sizeBytes: x}); + }) : []; + allManifestWeightNames = []; + manifest.forEach(function(manifestGroupConfig, groupIndex) { + var groupOffset = 0; + manifestGroupConfig.weights.forEach(function(weightsEntry) { + var rawDtype = "quantization" in weightsEntry ? weightsEntry.quantization.dtype : weightsEntry.dtype; + var weightsBytes = DTYPE_VALUE_SIZE_MAP[rawDtype] * sizeFromShape(weightsEntry.shape); + var enqueueWeightsForFetchingFn = function() { + groupIndicesToFetchMap[groupIndex] = true; + if (groupWeightsToFetch[groupIndex] == null) { + groupWeightsToFetch[groupIndex] = []; + } + groupWeightsToFetch[groupIndex].push({ + manifestEntry: weightsEntry, + groupOffset, + sizeBytes: weightsBytes + }); }; - i != null ? i.forEach(function(I, C) { - I === b.name && (N(), o[C] = true); - }) : N(), c.push(b.name), w += x; + if (weightNames != null) { + weightNames.forEach(function(weightName, weightIndex) { + if (weightName === weightsEntry.name) { + enqueueWeightsForFetchingFn(); + weightsFound[weightIndex] = true; + } + }); + } else { + enqueueWeightsForFetchingFn(); + } + allManifestWeightNames.push(weightsEntry.name); + groupOffset += weightsBytes; }); - }), !o.every(function(g) { - return g; - })) - throw l = i.filter(function(g, y) { - return !o[y]; - }), new Error("Could not find weights in manifest with names: " + (l.join(", ") + `. -`) + "Manifest JSON has weights with names: " + (c.join(", ") + ".")); - return u = a.reduce(function(g, y, w) { - return y && g.push(w), g; - }, []), h = [], u.forEach(function(g) { - e[g].paths.forEach(function(y) { - var w = r + (r.endsWith("/") ? "" : "/") + y; - h.push(w); + }); + if (!weightsFound.every(function(found) { + return found; + })) { + weightsNotFound = weightNames.filter(function(_, i) { + return !weightsFound[i]; }); - }), [4, n(h)]; - case 1: - return d = m.sent(), p = {}, f = 0, u.forEach(function(g) { - for (var y = e[g].paths.length, w = 0, b = 0; b < y; b++) - w += d[f + b].byteLength; - for (var L = new ArrayBuffer(w), x = new Uint8Array(L), N = 0, I = 0; I < y; I++) { - var C = new Uint8Array(d[f + I]); - x.set(C, N), N += C.byteLength; + throw new Error("Could not find weights in manifest with names: " + (weightsNotFound.join(", ") + ". \n") + "Manifest JSON has weights with names: " + (allManifestWeightNames.join(", ") + ".")); + } + groupIndicesToFetch = groupIndicesToFetchMap.reduce(function(accumulator, shouldFetch, i) { + if (shouldFetch) { + accumulator.push(i); } - var O = s[g]; - O.forEach(function(D) { - var F = L.slice(D.groupOffset, D.groupOffset + D.sizeBytes), k = km(F, [D.manifestEntry]); - for (var B in k) - p[B] = k[B]; - }), f += y; - }), [2, p]; + return accumulator; + }, []); + fetchUrls = []; + groupIndicesToFetch.forEach(function(i) { + manifest[i].paths.forEach(function(filepath) { + var fetchUrl = filePathPrefix + (!filePathPrefix.endsWith("/") ? "/" : "") + filepath; + fetchUrls.push(fetchUrl); + }); + }); + return [4, fetchWeightsFunction(fetchUrls)]; + case 1: + buffers = _a.sent(); + weightsTensorMap = {}; + bufferIndexOffset = 0; + groupIndicesToFetch.forEach(function(i) { + var numBuffers = manifest[i].paths.length; + var groupBytes = 0; + for (var i_1 = 0; i_1 < numBuffers; i_1++) { + groupBytes += buffers[bufferIndexOffset + i_1].byteLength; + } + var groupBuffer = new ArrayBuffer(groupBytes); + var groupByteBuffer = new Uint8Array(groupBuffer); + var groupBufferOffset = 0; + for (var i_2 = 0; i_2 < numBuffers; i_2++) { + var buffer2 = new Uint8Array(buffers[bufferIndexOffset + i_2]); + groupByteBuffer.set(buffer2, groupBufferOffset); + groupBufferOffset += buffer2.byteLength; + } + var weightsEntries = groupWeightsToFetch[i]; + weightsEntries.forEach(function(weightsEntry) { + var byteBuffer = groupBuffer.slice(weightsEntry.groupOffset, weightsEntry.groupOffset + weightsEntry.sizeBytes); + var nameToTensorMap = decodeWeights(byteBuffer, [weightsEntry.manifestEntry]); + for (var name_1 in nameToTensorMap) { + weightsTensorMap[name_1] = nameToTensorMap[name_1]; + } + }); + bufferIndexOffset += numBuffers; + }); + return [2, weightsTensorMap]; } }); }); }; } - var jS = "application/octet-stream", $S = "application/json", jm = function() { - function n(t, e) { - if (this.DEFAULT_METHOD = "POST", e == null && (e = {}), this.weightPathPrefix = e.weightPathPrefix, this.onProgress = e.onProgress, this.weightUrlConverter = e.weightUrlConverter, e.fetchFunc != null ? (E(typeof e.fetchFunc == "function", function() { - return "Must pass a function that matches the signature of `fetch` (see https://developer.mozilla.org/en-US/docs/Web/API/Fetch_API)"; - }), this.fetch = e.fetchFunc) : this.fetch = Ge().platform.fetch, E(t != null && t.length > 0, function() { + /** + * @license + * Copyright 2018 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + var OCTET_STREAM_MIME_TYPE = "application/octet-stream"; + var JSON_TYPE = "application/json"; + var HTTPRequest = function() { + function HTTPRequest2(path, loadOptions) { + this.DEFAULT_METHOD = "POST"; + if (loadOptions == null) { + loadOptions = {}; + } + this.weightPathPrefix = loadOptions.weightPathPrefix; + this.onProgress = loadOptions.onProgress; + this.weightUrlConverter = loadOptions.weightUrlConverter; + if (loadOptions.fetchFunc != null) { + assert(typeof loadOptions.fetchFunc === "function", function() { + return "Must pass a function that matches the signature of `fetch` (see https://developer.mozilla.org/en-US/docs/Web/API/Fetch_API)"; + }); + this.fetch = loadOptions.fetchFunc; + } else { + this.fetch = env().platform.fetch; + } + assert(path != null && path.length > 0, function() { return "URL path for http must not be null, undefined or empty."; - }), Array.isArray(t) && E(t.length === 2, function() { - return "URL paths for http must have a length of 2, " + ("(actual length is " + t.length + ")."); - }), this.path = t, e.requestInit != null && e.requestInit.body != null) + }); + if (Array.isArray(path)) { + assert(path.length === 2, function() { + return "URL paths for http must have a length of 2, " + ("(actual length is " + path.length + ")."); + }); + } + this.path = path; + if (loadOptions.requestInit != null && loadOptions.requestInit.body != null) { throw new Error("requestInit is expected to have no pre-existing body, but has one."); - this.requestInit = e.requestInit || {}; + } + this.requestInit = loadOptions.requestInit || {}; } - return n.prototype.save = function(t) { - return pe(this, void 0, void 0, function() { - var e, r, i, a; - return fe(this, function(s) { - switch (s.label) { + HTTPRequest2.prototype.save = function(modelArtifacts) { + return __awaiter(this, void 0, void 0, function() { + var init, weightsManifest, modelTopologyAndWeightManifest, response; + return __generator(this, function(_a) { + switch (_a.label) { case 0: - if (t.modelTopology instanceof ArrayBuffer) + if (modelArtifacts.modelTopology instanceof ArrayBuffer) { throw new Error("BrowserHTTPRequest.save() does not support saving model topology in binary formats yet."); - return e = Object.assign({method: this.DEFAULT_METHOD}, this.requestInit), e.body = new FormData(), r = [{paths: ["./model.weights.bin"], weights: t.weightSpecs}], i = {modelTopology: t.modelTopology, format: t.format, generatedBy: t.generatedBy, convertedBy: t.convertedBy, userDefinedMetadata: t.userDefinedMetadata, weightsManifest: r}, e.body.append("model.json", new Blob([JSON.stringify(i)], {type: $S}), "model.json"), t.weightData != null && e.body.append("model.weights.bin", new Blob([t.weightData], {type: jS}), "model.weights.bin"), [4, this.fetch(this.path, e)]; + } + init = Object.assign({method: this.DEFAULT_METHOD}, this.requestInit); + init.body = new FormData(); + weightsManifest = [{ + paths: ["./model.weights.bin"], + weights: modelArtifacts.weightSpecs + }]; + modelTopologyAndWeightManifest = { + modelTopology: modelArtifacts.modelTopology, + format: modelArtifacts.format, + generatedBy: modelArtifacts.generatedBy, + convertedBy: modelArtifacts.convertedBy, + userDefinedMetadata: modelArtifacts.userDefinedMetadata, + weightsManifest + }; + init.body.append("model.json", new Blob([JSON.stringify(modelTopologyAndWeightManifest)], {type: JSON_TYPE}), "model.json"); + if (modelArtifacts.weightData != null) { + init.body.append("model.weights.bin", new Blob([modelArtifacts.weightData], {type: OCTET_STREAM_MIME_TYPE}), "model.weights.bin"); + } + return [4, this.fetch(this.path, init)]; case 1: - if (a = s.sent(), a.ok) - return [2, {modelArtifactsInfo: Ia(t), responses: [a]}]; - throw new Error("BrowserHTTPRequest.save() failed due to HTTP response status " + (a.status + ".")); + response = _a.sent(); + if (response.ok) { + return [2, { + modelArtifactsInfo: getModelArtifactsInfoForJSON(modelArtifacts), + responses: [response] + }]; + } else { + throw new Error("BrowserHTTPRequest.save() failed due to HTTP response status " + (response.status + ".")); + } } }); }); - }, n.prototype.load = function() { - return pe(this, void 0, void 0, function() { - var t, e, r, i, a, s, o, c, l, u, h, d, p, f, m; - return fe(this, function(g) { - switch (g.label) { + }; + HTTPRequest2.prototype.load = function() { + return __awaiter(this, void 0, void 0, function() { + var modelConfigRequest, modelConfig, e_1, message, modelTopology, weightsManifest, generatedBy, convertedBy, format, userDefinedMetadata, weightSpecs, weightData, results, artifacts, initializer; + return __generator(this, function(_a) { + switch (_a.label) { case 0: return [4, this.fetch(this.path, this.requestInit)]; case 1: - if (t = g.sent(), !t.ok) - throw new Error("Request to " + this.path + " failed with status code " + (t.status + ". Please verify this URL points to ") + "the model JSON of the model to load."); - g.label = 2; + modelConfigRequest = _a.sent(); + if (!modelConfigRequest.ok) { + throw new Error("Request to " + this.path + " failed with status code " + (modelConfigRequest.status + ". Please verify this URL points to ") + "the model JSON of the model to load."); + } + _a.label = 2; case 2: - return g.trys.push([2, 4, , 5]), [4, t.json()]; + _a.trys.push([2, 4, , 5]); + return [4, modelConfigRequest.json()]; case 3: - return e = g.sent(), [3, 5]; + modelConfig = _a.sent(); + return [3, 5]; case 4: - throw r = g.sent(), i = "Failed to parse model JSON of response from " + this.path + ".", this.path.endsWith(".pb") ? i += " Your path contains a .pb file extension. Support for .pb models have been removed in TensorFlow.js 1.0 in favor of .json models. You can re-convert your Python TensorFlow model using the TensorFlow.js 1.0 conversion scripts or you can convert your.pb models with the 'pb2json'NPM script in the tensorflow/tfjs-converter repository." : i += " Please make sure the server is serving valid JSON for this request.", new Error(i); + e_1 = _a.sent(); + message = "Failed to parse model JSON of response from " + this.path + "."; + if (this.path.endsWith(".pb")) { + message += " Your path contains a .pb file extension. Support for .pb models have been removed in TensorFlow.js 1.0 in favor of .json models. You can re-convert your Python TensorFlow model using the TensorFlow.js 1.0 conversion scripts or you can convert your.pb models with the 'pb2json'NPM script in the tensorflow/tfjs-converter repository."; + } else { + message += " Please make sure the server is serving valid JSON for this request."; + } + throw new Error(message); case 5: - if (a = e.modelTopology, s = e.weightsManifest, o = e.generatedBy, c = e.convertedBy, l = e.format, u = e.userDefinedMetadata, a == null && s == null) + modelTopology = modelConfig.modelTopology; + weightsManifest = modelConfig.weightsManifest; + generatedBy = modelConfig.generatedBy; + convertedBy = modelConfig.convertedBy; + format = modelConfig.format; + userDefinedMetadata = modelConfig.userDefinedMetadata; + if (modelTopology == null && weightsManifest == null) { throw new Error("The JSON from HTTP path " + this.path + " contains neither model topology or manifest for weights."); - return s != null ? [4, this.loadWeights(s)] : [3, 7]; + } + if (!(weightsManifest != null)) + return [3, 7]; + return [4, this.loadWeights(weightsManifest)]; case 6: - p = g.sent(), h = p[0], d = p[1], g.label = 7; + results = _a.sent(); + weightSpecs = results[0], weightData = results[1]; + _a.label = 7; case 7: - return f = {modelTopology: a, weightSpecs: h, weightData: d, userDefinedMetadata: u, generatedBy: o, convertedBy: c, format: l}, m = e.modelInitializer, m && (f.modelInitializer = m), [2, f]; + artifacts = { + modelTopology, + weightSpecs, + weightData, + userDefinedMetadata, + generatedBy, + convertedBy, + format + }; + initializer = modelConfig.modelInitializer; + if (initializer) { + artifacts.modelInitializer = initializer; + } + return [2, artifacts]; } }); }); - }, n.prototype.loadWeights = function(t) { - return pe(this, void 0, void 0, function() { - var e, r, i, a, s, o, c, l, u, h, d, p, f, m, g, y, w, b, L, x, N; - return fe(this, function(I) { - switch (I.label) { + }; + HTTPRequest2.prototype.loadWeights = function(weightsManifest) { + return __awaiter(this, void 0, void 0, function() { + var weightPath, _a, prefix, suffix, pathPrefix, weightSpecs, _i2, weightsManifest_1, entry, fetchURLs, urlPromises, _b, weightsManifest_2, weightsGroup, _c, _d, path, _e, _f, _g, buffers; + return __generator(this, function(_h) { + switch (_h.label) { case 0: - for (e = Array.isArray(this.path) ? this.path[1] : this.path, r = XS(e), i = r[0], a = r[1], s = this.weightPathPrefix || i, o = [], c = 0, l = t; c < l.length; c++) - u = l[c], o.push.apply(o, u.weights); - for (h = [], d = [], p = 0, f = t; p < f.length; p++) - for (m = f[p], g = 0, y = m.paths; g < y.length; g++) - w = y[g], this.weightUrlConverter != null ? d.push(this.weightUrlConverter(w)) : h.push(s + w + a); - return this.weightUrlConverter ? (L = (b = h.push).apply, x = [h], [4, Promise.all(d)]) : [3, 2]; + weightPath = Array.isArray(this.path) ? this.path[1] : this.path; + _a = parseUrl(weightPath), prefix = _a[0], suffix = _a[1]; + pathPrefix = this.weightPathPrefix || prefix; + weightSpecs = []; + for (_i2 = 0, weightsManifest_1 = weightsManifest; _i2 < weightsManifest_1.length; _i2++) { + entry = weightsManifest_1[_i2]; + weightSpecs.push.apply(weightSpecs, entry.weights); + } + fetchURLs = []; + urlPromises = []; + for (_b = 0, weightsManifest_2 = weightsManifest; _b < weightsManifest_2.length; _b++) { + weightsGroup = weightsManifest_2[_b]; + for (_c = 0, _d = weightsGroup.paths; _c < _d.length; _c++) { + path = _d[_c]; + if (this.weightUrlConverter != null) { + urlPromises.push(this.weightUrlConverter(path)); + } else { + fetchURLs.push(pathPrefix + path + suffix); + } + } + } + if (!this.weightUrlConverter) + return [3, 2]; + _f = (_e = fetchURLs.push).apply; + _g = [fetchURLs]; + return [4, Promise.all(urlPromises)]; case 1: - L.apply(b, x.concat([I.sent()])), I.label = 2; + _f.apply(_e, _g.concat([_h.sent()])); + _h.label = 2; case 2: - return [4, Ym(h, {requestInit: this.requestInit, fetchFunc: this.fetch, onProgress: this.onProgress})]; + return [4, loadWeightsAsArrayBuffer(fetchURLs, { + requestInit: this.requestInit, + fetchFunc: this.fetch, + onProgress: this.onProgress + })]; case 3: - return N = I.sent(), [2, [o, Su(N)]]; + buffers = _h.sent(); + return [2, [weightSpecs, concatenateArrayBuffers(buffers)]]; } }); }); - }, n.URL_SCHEME_REGEX = /^https?:\/\//, n; + }; + HTTPRequest2.URL_SCHEME_REGEX = /^https?:\/\//; + return HTTPRequest2; }(); - function XS(n) { - var t = n.lastIndexOf("/"), e = n.lastIndexOf("?"), r = n.substring(0, t), i = e > t ? n.substring(e) : ""; - return [r + "/", i]; + function parseUrl(url) { + var lastSlash = url.lastIndexOf("/"); + var lastSearchParam = url.lastIndexOf("?"); + var prefix = url.substring(0, lastSlash); + var suffix = lastSearchParam > lastSlash ? url.substring(lastSearchParam) : ""; + return [prefix + "/", suffix]; } - function Cu(n) { - return n.match(jm.URL_SCHEME_REGEX) != null; + function isHTTPScheme(url) { + return url.match(HTTPRequest.URL_SCHEME_REGEX) != null; } - var $m = function(n, t) { - if (typeof fetch == "undefined" && (t == null || t.fetchFunc == null)) + var httpRouter = function(url, loadOptions) { + if (typeof fetch === "undefined" && (loadOptions == null || loadOptions.fetchFunc == null)) { return null; - var e = true; - return Array.isArray(n) ? e = n.every(function(r) { - return Cu(r); - }) : e = Cu(n), e ? Ru(n, t) : null; - }; - rn.registerSaveRouter($m); - rn.registerLoadRouter($m); - function Ru(n, t) { - return new jm(n, t); - } - function JS(n, t) { - return Ru(n, t); - } - var Ou = function() { - function n(t) { - this.modelArtifacts = t; + } else { + var isHTTP = true; + if (Array.isArray(url)) { + isHTTP = url.every(function(urlItem) { + return isHTTPScheme(urlItem); + }); + } else { + isHTTP = isHTTPScheme(url); + } + if (isHTTP) { + return http(url, loadOptions); + } } - return n.prototype.load = function() { - return pe(this, void 0, void 0, function() { - return fe(this, function(t) { + return null; + }; + IORouterRegistry.registerSaveRouter(httpRouter); + IORouterRegistry.registerLoadRouter(httpRouter); + function http(path, loadOptions) { + return new HTTPRequest(path, loadOptions); + } + function browserHTTPRequest(path, loadOptions) { + return http(path, loadOptions); + } + /** + * @license + * Copyright 2018 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + var PassthroughLoader = function() { + function PassthroughLoader2(modelArtifacts) { + this.modelArtifacts = modelArtifacts; + } + PassthroughLoader2.prototype.load = function() { + return __awaiter(this, void 0, void 0, function() { + return __generator(this, function(_a) { return [2, this.modelArtifacts]; }); }); - }, n; - }(), ZS = function() { - function n(t) { - this.saveHandler = t; + }; + return PassthroughLoader2; + }(); + var PassthroughSaver = function() { + function PassthroughSaver2(saveHandler) { + this.saveHandler = saveHandler; } - return n.prototype.save = function(t) { - return pe(this, void 0, void 0, function() { - return fe(this, function(e) { - return [2, this.saveHandler(t)]; + PassthroughSaver2.prototype.save = function(modelArtifacts) { + return __awaiter(this, void 0, void 0, function() { + return __generator(this, function(_a) { + return [2, this.saveHandler(modelArtifacts)]; }); }); - }, n; - }(); - function QS(n, t, e, r) { - if (arguments.length === 1) { - var i = n.modelTopology != null || n.weightSpecs != null; - return i ? new Ou(n) : (console.warn("Please call tf.io.fromMemory() with only one argument. The argument should be of type ModelArtifacts. The multi-argument signature of tf.io.fromMemory() has been deprecated and will be removed in a future release."), new Ou({modelTopology: n})); - } else - return console.warn("Please call tf.io.fromMemory() with only one argument. The argument should be of type ModelArtifacts. The multi-argument signature of tf.io.fromMemory() has been deprecated and will be removed in a future release."), new Ou({modelTopology: n, weightSpecs: t, weightData: e, trainingConfig: r}); - } - function eI(n) { - return new ZS(n); - } - var tI = {__proto__: null, browserFiles: YS, browserHTTPRequest: JS, concatenateArrayBuffers: Su, decodeWeights: km, encodeWeights: cS, fromMemory: QS, getLoadHandlers: vS, getModelArtifactsInfoForJSON: Ia, getSaveHandlers: yS, http: Ru, isHTTPScheme: Cu, loadWeights: KS, registerLoadRouter: gS, registerSaveRouter: mS, weightsLoaderFactory: Km, withSaveHandler: eI, copyModel: ES, listModels: RS, moveModel: DS, removeModel: OS}; - function nI(n, t) { - var e = R(n, "x", "reshape", null), r = {x: e}, i = {shape: t}, a = function(s, o) { - return t = uf(t, e.size), E(e.size === pt(t), function() { - return "new shape and old shape must have the same number of elements."; - }), o([e]), s.reshape(e, t); }; - return z.runKernelFunc(a, r, null, Cl, i); - } - var Y = U({reshape_: nI}); - function rI(n, t, e, r) { - var i; - e === void 0 && (e = false), r === void 0 && (r = false); - var a = R(n, "a", "matMul"), s = R(t, "b", "matMul"); - i = ct(a, s), a = i[0], s = i[1]; - var o = function(u, h) { - h([a, s]); - var d = e ? a.shape[a.rank - 2] : a.shape[a.rank - 1], p = r ? s.shape[s.rank - 1] : s.shape[s.rank - 2], f = e ? a.shape[a.rank - 1] : a.shape[a.rank - 2], m = r ? s.shape[s.rank - 2] : s.shape[s.rank - 1], g = a.shape.slice(0, -2), y = s.shape.slice(0, -2), w = pt(g), b = pt(y), L = w === b || w === 1 || b === 1; - E(a.rank >= 2 && s.rank >= 2 && L, function() { - return "Error in matMul: the input batch dimensions must either be the same or at least one input batch dimension must be 1. Got input " + ("batch dimensions of (" + g + ") and (" + y + ")."); - }), E(d === p, function() { - return "Error in matMul: inner shapes (" + d + ") and (" + (p + ") of Tensors with shapes " + a.shape + " and ") + (s.shape + " and transposeA=" + e) + (" and transposeB=" + r + " must match."); + return PassthroughSaver2; + }(); + function fromMemory(modelArtifacts, weightSpecs, weightData, trainingConfig) { + if (arguments.length === 1) { + var isModelArtifacts = modelArtifacts.modelTopology != null || modelArtifacts.weightSpecs != null; + if (isModelArtifacts) { + return new PassthroughLoader(modelArtifacts); + } else { + console.warn("Please call tf.io.fromMemory() with only one argument. The argument should be of type ModelArtifacts. The multi-argument signature of tf.io.fromMemory() has been deprecated and will be removed in a future release."); + return new PassthroughLoader({modelTopology: modelArtifacts}); + } + } else { + console.warn("Please call tf.io.fromMemory() with only one argument. The argument should be of type ModelArtifacts. The multi-argument signature of tf.io.fromMemory() has been deprecated and will be removed in a future release."); + return new PassthroughLoader({ + modelTopology: modelArtifacts, + weightSpecs, + weightData, + trainingConfig }); - var x = w > b ? g : y, N = x.concat([f, m]), I = e ? Y(a, [w, d, f]) : Y(a, [w, f, d]), C = r ? Y(s, [b, m, p]) : Y(s, [b, p, m]), O = u.batchMatMul(I, C, e, r); - return Y(O, N); - }, c = {a, b: s}, l = {transposeA: e, transposeB: r}; - return z.runKernelFunc(o, c, null, kc, l); + } } - var Ue = U({matMul_: rI}); - function iI(n, t, e, r) { - if (e === void 0 && (e = 1), r === void 0 && (r = 0), t < 2) - throw new Error("Error in oneHot: depth must be >=2, but it is " + t); - var i = R(n, "indices", "oneHot", "int32"), a = i.shape.concat([t]), s = function(l, u) { - return u([i]), Y(l.oneHot(Y(i, [i.size]), t, e, r), a); - }, o = {indices: i}, c = {depth: t, onValue: e, offValue: r}; - return z.runKernelFunc(s, o, null, Sl, c); + function withSaveHandler(saveHandler) { + return new PassthroughSaver(saveHandler); } - var Bs = U({oneHot_: iI}); - function aI(n, t) { - var e = R(n, "x", "transpose"); - if (t == null && (t = e.shape.map(function(a, s) { - return s; - }).reverse()), E(e.rank === t.length, function() { - return "Error in transpose: rank of input " + e.rank + " " + ("must match length of perm " + t + "."); - }), t.forEach(function(a) { - E(a >= 0 && a < e.rank, function() { - return "All entries in 'perm' must be between 0 and " + (e.rank - 1) + (" but got " + t); + /** + * @license + * Copyright 2018 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + var io = { + __proto__: null, + browserFiles, + browserHTTPRequest, + concatenateArrayBuffers, + decodeWeights, + encodeWeights, + fromMemory, + getLoadHandlers, + getModelArtifactsInfoForJSON, + getSaveHandlers, + http, + isHTTPScheme, + loadWeights, + registerLoadRouter, + registerSaveRouter, + weightsLoaderFactory, + withSaveHandler, + copyModel, + listModels, + moveModel, + removeModel + }; + /** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + function reshape_(x, shape) { + var $x = convertToTensor(x, "x", "reshape", null); + var inputs = {x: $x}; + var attrs = {shape}; + var forward = function(backend2, save) { + shape = inferFromImplicitShape(shape, $x.size); + assert($x.size === sizeFromShape(shape), function() { + return "new shape and old shape must have the same number of elements."; }); - }), e.rank <= 1) - return e.clone(); - var r = {x: e}, i = {perm: t}; - return z.runKernelFunc(function(a) { - return a.transpose(e, t); - }, r, null, eu, i); + save([$x]); + return backend2.reshape($x, shape); + }; + return ENGINE.runKernelFunc(forward, inputs, null, Reshape, attrs); } - var Tt = U({transpose_: aI}); - function sI(n, t, e) { - var r = R(n, "labels", "confusionMatrix"), i = R(t, "predictions", "confusionMatrix"); - E(e == null || e > 0 && Number.isInteger(e), function() { - return "If provided, numClasses must be a positive integer, " + ("but got " + e); - }), E(r.rank === 1, function() { - return "Expected the rank of labels to be 1, but got " + r.rank; - }), E(i.rank === 1, function() { - return "Expected the rank of predictions to be 1, " + ("but got " + i.rank); - }), E(r.shape[0] === i.shape[0], function() { - return "Mismatch in the number of examples: " + (r.shape[0] + " vs. " + i.shape[0] + ". ") + "Labels and predictions should have the same number of elements."; - }), E(e > 0 && Number.isInteger(e), function() { - return "numClasses is required to be a positive integer, but got " + ("" + e); + var reshape = op({reshape_}); + /** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + function matMul_(a, b, transposeA, transposeB) { + var _a; + if (transposeA === void 0) { + transposeA = false; + } + if (transposeB === void 0) { + transposeB = false; + } + var $a = convertToTensor(a, "a", "matMul"); + var $b = convertToTensor(b, "b", "matMul"); + _a = makeTypesMatch($a, $b), $a = _a[0], $b = _a[1]; + var forward = function(backend2, save) { + save([$a, $b]); + var innerShapeA = transposeA ? $a.shape[$a.rank - 2] : $a.shape[$a.rank - 1]; + var innerShapeB = transposeB ? $b.shape[$b.rank - 1] : $b.shape[$b.rank - 2]; + var outerShapeA = transposeA ? $a.shape[$a.rank - 1] : $a.shape[$a.rank - 2]; + var outerShapeB = transposeB ? $b.shape[$b.rank - 2] : $b.shape[$b.rank - 1]; + var outerDimsA = $a.shape.slice(0, -2); + var outerDimsB = $b.shape.slice(0, -2); + var batchDimA = sizeFromShape(outerDimsA); + var batchDimB = sizeFromShape(outerDimsB); + var batchDimsCompatible = batchDimA === batchDimB || batchDimA === 1 || batchDimB === 1; + assert($a.rank >= 2 && $b.rank >= 2 && batchDimsCompatible, function() { + return "Error in matMul: the input batch dimensions must either be the same or at least one input batch dimension must be 1. Got input " + ("batch dimensions of (" + outerDimsA + ") and (" + outerDimsB + ")."); + }); + assert(innerShapeA === innerShapeB, function() { + return "Error in matMul: inner shapes (" + innerShapeA + ") and (" + (innerShapeB + ") of Tensors with shapes " + $a.shape + " and ") + ($b.shape + " and transposeA=" + transposeA) + (" and transposeB=" + transposeB + " must match."); + }); + var outShapeOuterDims = batchDimA > batchDimB ? outerDimsA : outerDimsB; + var outShape = outShapeOuterDims.concat([outerShapeA, outerShapeB]); + var a3D = transposeA ? reshape($a, [batchDimA, innerShapeA, outerShapeA]) : reshape($a, [batchDimA, outerShapeA, innerShapeA]); + var b3D = transposeB ? reshape($b, [batchDimB, outerShapeB, innerShapeB]) : reshape($b, [batchDimB, innerShapeB, outerShapeB]); + var res3d = backend2.batchMatMul(a3D, b3D, transposeA, transposeB); + return reshape(res3d, outShape); + }; + var inputs = {a: $a, b: $b}; + var attrs = {transposeA, transposeB}; + return ENGINE.runKernelFunc(forward, inputs, null, BatchMatMul, attrs); + } + var matMul = op({matMul_}); + /** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + function oneHot_(indices, depth, onValue, offValue) { + if (onValue === void 0) { + onValue = 1; + } + if (offValue === void 0) { + offValue = 0; + } + if (depth < 2) { + throw new Error("Error in oneHot: depth must be >=2, but it is " + depth); + } + var $indices = convertToTensor(indices, "indices", "oneHot", "int32"); + var outShape = $indices.shape.concat([depth]); + var forward = function(backend2, save) { + save([$indices]); + return reshape(backend2.oneHot(reshape($indices, [$indices.size]), depth, onValue, offValue), outShape); + }; + var inputs = {indices: $indices}; + var attrs = {depth, onValue, offValue}; + return ENGINE.runKernelFunc(forward, inputs, null, OneHot, attrs); + } + var oneHot = op({oneHot_}); + /** + * @license + * Copyright 2018 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + function transpose_(x, perm) { + var $x = convertToTensor(x, "x", "transpose"); + if (perm == null) { + perm = $x.shape.map(function(s, i) { + return i; + }).reverse(); + } + assert($x.rank === perm.length, function() { + return "Error in transpose: rank of input " + $x.rank + " " + ("must match length of perm " + perm + "."); }); - var a = Bs(he(r, "int32"), e), s = Bs(he(i, "int32"), e), o = Tt(a), c = Ue(o, s); - return he(c, "int32"); + perm.forEach(function(axis) { + assert(axis >= 0 && axis < $x.rank, function() { + return "All entries in 'perm' must be between 0 and " + ($x.rank - 1) + (" but got " + perm); + }); + }); + if ($x.rank <= 1) { + return $x.clone(); + } + var inputs = {x: $x}; + var attrs = {perm}; + return ENGINE.runKernelFunc(function(backend2) { + return backend2.transpose($x, perm); + }, inputs, null, Transpose, attrs); } - var oI = U({confusionMatrix_: sI}); - var cI = {__proto__: null, confusionMatrix: oI}; - function Xm(n, t, e) { - if (Ur(n), t != null && t.length !== 3) + var transpose = op({transpose_}); + /** + * @license + * Copyright 2018 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + function confusionMatrix_(labels, predictions, numClasses) { + var $labels = convertToTensor(labels, "labels", "confusionMatrix"); + var $predictions = convertToTensor(predictions, "predictions", "confusionMatrix"); + assert(numClasses == null || numClasses > 0 && Number.isInteger(numClasses), function() { + return "If provided, numClasses must be a positive integer, " + ("but got " + numClasses); + }); + assert($labels.rank === 1, function() { + return "Expected the rank of labels to be 1, but got " + $labels.rank; + }); + assert($predictions.rank === 1, function() { + return "Expected the rank of predictions to be 1, " + ("but got " + $predictions.rank); + }); + assert($labels.shape[0] === $predictions.shape[0], function() { + return "Mismatch in the number of examples: " + ($labels.shape[0] + " vs. " + $predictions.shape[0] + ". ") + "Labels and predictions should have the same number of elements."; + }); + assert(numClasses > 0 && Number.isInteger(numClasses), function() { + return "numClasses is required to be a positive integer, but got " + ("" + numClasses); + }); + var oneHotLabels = oneHot(cast($labels, "int32"), numClasses); + var oneHotPredictions = oneHot(cast($predictions, "int32"), numClasses); + var oneHotLabelsT = transpose(oneHotLabels); + var product = matMul(oneHotLabelsT, oneHotPredictions); + return cast(product, "int32"); + } + var confusionMatrix = op({confusionMatrix_}); + /** + * @license + * Copyright 2018 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + var math = { + __proto__: null, + confusionMatrix + }; + /** + * @license + * Copyright 2018 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + function tensor3d(values, shape, dtype) { + assertNonNull(values); + if (shape != null && shape.length !== 3) { throw new Error("tensor3d() requires shape to have three numbers"); - var r = Rn(n, e); - if (r.length !== 3 && r.length !== 1) + } + var inferredShape = inferShape(values, dtype); + if (inferredShape.length !== 3 && inferredShape.length !== 1) { throw new Error("tensor3d() requires values to be number[][][] or flat/TypedArray"); - if (r.length === 1 && t == null) + } + if (inferredShape.length === 1 && shape == null) { throw new Error("tensor3d() requires shape to be provided when `values` are a flat array"); - return ur(n, t, r, e); + } + return makeTensor(values, shape, inferredShape, dtype); } - var Ci; - function lI(n, t) { - if (t === void 0 && (t = 3), t > 4) + /** + * @license + * Copyright 2019 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + var fromPixels2DContext; + function fromPixels_(pixels, numChannels) { + if (numChannels === void 0) { + numChannels = 3; + } + if (numChannels > 4) { throw new Error("Cannot construct Tensor with more than 4 channels from pixels."); - if (n == null) + } + if (pixels == null) { throw new Error("pixels passed to tf.browser.fromPixels() can not be null"); - var e = false, r = false, i = false, a = false, s = false; - if (n.data instanceof Uint8Array) - e = true; - else if (typeof ImageData != "undefined" && n instanceof ImageData) - r = true; - else if (typeof HTMLVideoElement != "undefined" && n instanceof HTMLVideoElement) - i = true; - else if (typeof HTMLImageElement != "undefined" && n instanceof HTMLImageElement) - a = true; - else if (n.getContext != null) - s = true; - else - throw new Error("pixels passed to tf.browser.fromPixels() must be either an HTMLVideoElement, HTMLImageElement, HTMLCanvasElement, ImageData in browser, or OffscreenCanvas, ImageData in webworker or {data: Uint32Array, width: number, height: number}, " + ("but was " + n.constructor.name)); - if (i) { - var o = 2; - if (i && n.readyState < o) + } + var isPixelData = false; + var isImageData = false; + var isVideo = false; + var isImage = false; + var isCanvasLike = false; + if (pixels.data instanceof Uint8Array) { + isPixelData = true; + } else if (typeof ImageData !== "undefined" && pixels instanceof ImageData) { + isImageData = true; + } else if (typeof HTMLVideoElement !== "undefined" && pixels instanceof HTMLVideoElement) { + isVideo = true; + } else if (typeof HTMLImageElement !== "undefined" && pixels instanceof HTMLImageElement) { + isImage = true; + } else if (pixels.getContext != null) { + isCanvasLike = true; + } else { + throw new Error("pixels passed to tf.browser.fromPixels() must be either an HTMLVideoElement, HTMLImageElement, HTMLCanvasElement, ImageData in browser, or OffscreenCanvas, ImageData in webworker or {data: Uint32Array, width: number, height: number}, " + ("but was " + pixels.constructor.name)); + } + if (isVideo) { + var HAVE_CURRENT_DATA_READY_STATE = 2; + if (isVideo && pixels.readyState < HAVE_CURRENT_DATA_READY_STATE) { throw new Error("The video element has not loaded data yet. Please wait for `loadeddata` event on the