var __create = Object.create; var __defProp = Object.defineProperty; var __getProtoOf = Object.getPrototypeOf; var __hasOwnProp = Object.prototype.hasOwnProperty; var __getOwnPropNames = Object.getOwnPropertyNames; var __getOwnPropDesc = Object.getOwnPropertyDescriptor; var __markAsModule = (target) => __defProp(target, "__esModule", {value: true}); var __commonJS = (callback, module2) => () => { if (!module2) { module2 = {exports: {}}; callback(module2.exports, module2); } return module2.exports; }; var __export = (target, all) => { __markAsModule(target); for (var name in all) __defProp(target, name, {get: all[name], enumerable: true}); }; var __exportStar = (target, module2, desc) => { __markAsModule(target); if (typeof module2 === "object" || typeof module2 === "function") { for (let key of __getOwnPropNames(module2)) if (!__hasOwnProp.call(target, key) && key !== "default") __defProp(target, key, {get: () => module2[key], enumerable: !(desc = __getOwnPropDesc(module2, key)) || desc.enumerable}); } return target; }; var __toModule = (module2) => { if (module2 && module2.__esModule) return module2; return __exportStar(__defProp(__create(__getProtoOf(module2)), "default", {value: module2, enumerable: true}), module2); }; // node_modules/node-fetch/lib/index.mjs var require_lib = __commonJS((exports2) => { __export(exports2, { FetchError: () => FetchError, Headers: () => Headers, Request: () => Request, Response: () => Response, default: () => lib_default }); const stream = __toModule(require("stream")); const http2 = __toModule(require("http")); const url = __toModule(require("url")); const https2 = __toModule(require("https")); const zlib2 = __toModule(require("zlib")); const Readable = stream.default.Readable; const BUFFER = Symbol("buffer"); const TYPE = Symbol("type"); class Blob2 { constructor() { this[TYPE] = ""; const blobParts = arguments[0]; const options = arguments[1]; const buffers = []; let size = 0; if (blobParts) { const a = blobParts; const length = Number(a.length); for (let i = 0; i < length; i++) { const element = a[i]; let buffer2; if (element instanceof Buffer) { buffer2 = element; } else if (ArrayBuffer.isView(element)) { buffer2 = Buffer.from(element.buffer, element.byteOffset, element.byteLength); } else if (element instanceof ArrayBuffer) { buffer2 = Buffer.from(element); } else if (element instanceof Blob2) { buffer2 = element[BUFFER]; } else { buffer2 = Buffer.from(typeof element === "string" ? element : String(element)); } size += buffer2.length; buffers.push(buffer2); } } this[BUFFER] = Buffer.concat(buffers); let type = options && options.type !== void 0 && String(options.type).toLowerCase(); if (type && !/[^\u0020-\u007E]/.test(type)) { this[TYPE] = type; } } get size() { return this[BUFFER].length; } get type() { return this[TYPE]; } text() { return Promise.resolve(this[BUFFER].toString()); } arrayBuffer() { const buf = this[BUFFER]; const ab = buf.buffer.slice(buf.byteOffset, buf.byteOffset + buf.byteLength); return Promise.resolve(ab); } stream() { const readable = new Readable(); readable._read = function() { }; readable.push(this[BUFFER]); readable.push(null); return readable; } toString() { return "[object Blob]"; } slice() { const size = this.size; const start = arguments[0]; const end = arguments[1]; let relativeStart, relativeEnd; if (start === void 0) { relativeStart = 0; } else if (start < 0) { relativeStart = Math.max(size + start, 0); } else { relativeStart = Math.min(start, size); } if (end === void 0) { relativeEnd = size; } else if (end < 0) { relativeEnd = Math.max(size + end, 0); } else { relativeEnd = Math.min(end, size); } const span = Math.max(relativeEnd - relativeStart, 0); const buffer2 = this[BUFFER]; const slicedBuffer = buffer2.slice(relativeStart, relativeStart + span); const blob = new Blob2([], {type: arguments[2]}); blob[BUFFER] = slicedBuffer; return blob; } } Object.defineProperties(Blob2.prototype, { size: {enumerable: true}, type: {enumerable: true}, slice: {enumerable: true} }); Object.defineProperty(Blob2.prototype, Symbol.toStringTag, { value: "Blob", writable: false, enumerable: false, configurable: true }); function FetchError(message, type, systemError) { Error.call(this, message); this.message = message; this.type = type; if (systemError) { this.code = this.errno = systemError.code; } Error.captureStackTrace(this, this.constructor); } FetchError.prototype = Object.create(Error.prototype); FetchError.prototype.constructor = FetchError; FetchError.prototype.name = "FetchError"; let convert; try { convert = require("encoding").convert; } catch (e) { } const INTERNALS = Symbol("Body internals"); const PassThrough = stream.default.PassThrough; function Body(body2) { var _this = this; var _ref = arguments.length > 1 && arguments[1] !== void 0 ? arguments[1] : {}, _ref$size = _ref.size; let size = _ref$size === void 0 ? 0 : _ref$size; var _ref$timeout = _ref.timeout; let timeout = _ref$timeout === void 0 ? 0 : _ref$timeout; if (body2 == null) { body2 = null; } else if (isURLSearchParams(body2)) { body2 = Buffer.from(body2.toString()); } else if (isBlob(body2)) ; else if (Buffer.isBuffer(body2)) ; else if (Object.prototype.toString.call(body2) === "[object ArrayBuffer]") { body2 = Buffer.from(body2); } else if (ArrayBuffer.isView(body2)) { body2 = Buffer.from(body2.buffer, body2.byteOffset, body2.byteLength); } else if (body2 instanceof stream.default) ; else { body2 = Buffer.from(String(body2)); } this[INTERNALS] = { body: body2, disturbed: false, error: null }; this.size = size; this.timeout = timeout; if (body2 instanceof stream.default) { body2.on("error", function(err) { const error = err.name === "AbortError" ? err : new FetchError(`Invalid response body while trying to fetch ${_this.url}: ${err.message}`, "system", err); _this[INTERNALS].error = error; }); } } Body.prototype = { get body() { return this[INTERNALS].body; }, get bodyUsed() { return this[INTERNALS].disturbed; }, arrayBuffer() { return consumeBody.call(this).then(function(buf) { return buf.buffer.slice(buf.byteOffset, buf.byteOffset + buf.byteLength); }); }, blob() { let ct = this.headers && this.headers.get("content-type") || ""; return consumeBody.call(this).then(function(buf) { return Object.assign(new Blob2([], { type: ct.toLowerCase() }), { [BUFFER]: buf }); }); }, json() { var _this2 = this; return consumeBody.call(this).then(function(buffer2) { try { return JSON.parse(buffer2.toString()); } catch (err) { return Body.Promise.reject(new FetchError(`invalid json response body at ${_this2.url} reason: ${err.message}`, "invalid-json")); } }); }, text() { return consumeBody.call(this).then(function(buffer2) { return buffer2.toString(); }); }, buffer() { return consumeBody.call(this); }, textConverted() { var _this3 = this; return consumeBody.call(this).then(function(buffer2) { return convertBody(buffer2, _this3.headers); }); } }; Object.defineProperties(Body.prototype, { body: {enumerable: true}, bodyUsed: {enumerable: true}, arrayBuffer: {enumerable: true}, blob: {enumerable: true}, json: {enumerable: true}, text: {enumerable: true} }); Body.mixIn = function(proto) { for (const name of Object.getOwnPropertyNames(Body.prototype)) { if (!(name in proto)) { const desc = Object.getOwnPropertyDescriptor(Body.prototype, name); Object.defineProperty(proto, name, desc); } } }; function consumeBody() { var _this4 = this; if (this[INTERNALS].disturbed) { return Body.Promise.reject(new TypeError(`body used already for: ${this.url}`)); } this[INTERNALS].disturbed = true; if (this[INTERNALS].error) { return Body.Promise.reject(this[INTERNALS].error); } let body2 = this.body; if (body2 === null) { return Body.Promise.resolve(Buffer.alloc(0)); } if (isBlob(body2)) { body2 = body2.stream(); } if (Buffer.isBuffer(body2)) { return Body.Promise.resolve(body2); } if (!(body2 instanceof stream.default)) { return Body.Promise.resolve(Buffer.alloc(0)); } let accum = []; let accumBytes = 0; let abort = false; return new Body.Promise(function(resolve, reject) { let resTimeout; if (_this4.timeout) { resTimeout = setTimeout(function() { abort = true; reject(new FetchError(`Response timeout while trying to fetch ${_this4.url} (over ${_this4.timeout}ms)`, "body-timeout")); }, _this4.timeout); } body2.on("error", function(err) { if (err.name === "AbortError") { abort = true; reject(err); } else { reject(new FetchError(`Invalid response body while trying to fetch ${_this4.url}: ${err.message}`, "system", err)); } }); body2.on("data", function(chunk) { if (abort || chunk === null) { return; } if (_this4.size && accumBytes + chunk.length > _this4.size) { abort = true; reject(new FetchError(`content size at ${_this4.url} over limit: ${_this4.size}`, "max-size")); return; } accumBytes += chunk.length; accum.push(chunk); }); body2.on("end", function() { if (abort) { return; } clearTimeout(resTimeout); try { resolve(Buffer.concat(accum, accumBytes)); } catch (err) { reject(new FetchError(`Could not create Buffer from response body for ${_this4.url}: ${err.message}`, "system", err)); } }); }); } function convertBody(buffer2, headers) { if (typeof convert !== "function") { throw new Error("The package `encoding` must be installed to use the textConverted() function"); } const ct = headers.get("content-type"); let charset = "utf-8"; let res, str; if (ct) { res = /charset=([^;]*)/i.exec(ct); } str = buffer2.slice(0, 1024).toString(); if (!res && str) { res = / 0 && arguments[0] !== void 0 ? arguments[0] : void 0; this[MAP] = Object.create(null); if (init2 instanceof Headers) { const rawHeaders = init2.raw(); const headerNames = Object.keys(rawHeaders); for (const headerName of headerNames) { for (const value of rawHeaders[headerName]) { this.append(headerName, value); } } return; } if (init2 == null) ; else if (typeof init2 === "object") { const method = init2[Symbol.iterator]; if (method != null) { if (typeof method !== "function") { throw new TypeError("Header pairs must be iterable"); } const pairs = []; for (const pair of init2) { if (typeof pair !== "object" || typeof pair[Symbol.iterator] !== "function") { throw new TypeError("Each header pair must be iterable"); } pairs.push(Array.from(pair)); } for (const pair of pairs) { if (pair.length !== 2) { throw new TypeError("Each header pair must be a name/value tuple"); } this.append(pair[0], pair[1]); } } else { for (const key of Object.keys(init2)) { const value = init2[key]; this.append(key, value); } } } else { throw new TypeError("Provided initializer must be an object"); } } get(name) { name = `${name}`; validateName(name); const key = find(this[MAP], name); if (key === void 0) { return null; } return this[MAP][key].join(", "); } forEach(callback) { let thisArg = arguments.length > 1 && arguments[1] !== void 0 ? arguments[1] : void 0; let pairs = getHeaders(this); let i = 0; while (i < pairs.length) { var _pairs$i = pairs[i]; const name = _pairs$i[0], value = _pairs$i[1]; callback.call(thisArg, value, name, this); pairs = getHeaders(this); i++; } } set(name, value) { name = `${name}`; value = `${value}`; validateName(name); validateValue(value); const key = find(this[MAP], name); this[MAP][key !== void 0 ? key : name] = [value]; } append(name, value) { name = `${name}`; value = `${value}`; validateName(name); validateValue(value); const key = find(this[MAP], name); if (key !== void 0) { this[MAP][key].push(value); } else { this[MAP][name] = [value]; } } has(name) { name = `${name}`; validateName(name); return find(this[MAP], name) !== void 0; } delete(name) { name = `${name}`; validateName(name); const key = find(this[MAP], name); if (key !== void 0) { delete this[MAP][key]; } } raw() { return this[MAP]; } keys() { return createHeadersIterator(this, "key"); } values() { return createHeadersIterator(this, "value"); } [Symbol.iterator]() { return createHeadersIterator(this, "key+value"); } } Headers.prototype.entries = Headers.prototype[Symbol.iterator]; Object.defineProperty(Headers.prototype, Symbol.toStringTag, { value: "Headers", writable: false, enumerable: false, configurable: true }); Object.defineProperties(Headers.prototype, { get: {enumerable: true}, forEach: {enumerable: true}, set: {enumerable: true}, append: {enumerable: true}, has: {enumerable: true}, delete: {enumerable: true}, keys: {enumerable: true}, values: {enumerable: true}, entries: {enumerable: true} }); function getHeaders(headers) { let kind = arguments.length > 1 && arguments[1] !== void 0 ? arguments[1] : "key+value"; const keys = Object.keys(headers[MAP]).sort(); return keys.map(kind === "key" ? function(k) { return k.toLowerCase(); } : kind === "value" ? function(k) { return headers[MAP][k].join(", "); } : function(k) { return [k.toLowerCase(), headers[MAP][k].join(", ")]; }); } const INTERNAL = Symbol("internal"); function createHeadersIterator(target, kind) { const iterator = Object.create(HeadersIteratorPrototype); iterator[INTERNAL] = { target, kind, index: 0 }; return iterator; } const HeadersIteratorPrototype = Object.setPrototypeOf({ next() { if (!this || Object.getPrototypeOf(this) !== HeadersIteratorPrototype) { throw new TypeError("Value of `this` is not a HeadersIterator"); } var _INTERNAL = this[INTERNAL]; const target = _INTERNAL.target, kind = _INTERNAL.kind, index = _INTERNAL.index; const values = getHeaders(target, kind); const len = values.length; if (index >= len) { return { value: void 0, done: true }; } this[INTERNAL].index = index + 1; return { value: values[index], done: false }; } }, Object.getPrototypeOf(Object.getPrototypeOf([][Symbol.iterator]()))); Object.defineProperty(HeadersIteratorPrototype, Symbol.toStringTag, { value: "HeadersIterator", writable: false, enumerable: false, configurable: true }); function exportNodeCompatibleHeaders(headers) { const obj = Object.assign({__proto__: null}, headers[MAP]); const hostHeaderKey = find(headers[MAP], "Host"); if (hostHeaderKey !== void 0) { obj[hostHeaderKey] = obj[hostHeaderKey][0]; } return obj; } function createHeadersLenient(obj) { const headers = new Headers(); for (const name of Object.keys(obj)) { if (invalidTokenRegex.test(name)) { continue; } if (Array.isArray(obj[name])) { for (const val of obj[name]) { if (invalidHeaderCharRegex.test(val)) { continue; } if (headers[MAP][name] === void 0) { headers[MAP][name] = [val]; } else { headers[MAP][name].push(val); } } } else if (!invalidHeaderCharRegex.test(obj[name])) { headers[MAP][name] = [obj[name]]; } } return headers; } const INTERNALS$1 = Symbol("Response internals"); const STATUS_CODES = http2.default.STATUS_CODES; class Response { constructor() { let body2 = arguments.length > 0 && arguments[0] !== void 0 ? arguments[0] : null; let opts = arguments.length > 1 && arguments[1] !== void 0 ? arguments[1] : {}; Body.call(this, body2, opts); const status = opts.status || 200; const headers = new Headers(opts.headers); if (body2 != null && !headers.has("Content-Type")) { const contentType = extractContentType(body2); if (contentType) { headers.append("Content-Type", contentType); } } this[INTERNALS$1] = { url: opts.url, status, statusText: opts.statusText || STATUS_CODES[status], headers, counter: opts.counter }; } get url() { return this[INTERNALS$1].url || ""; } get status() { return this[INTERNALS$1].status; } get ok() { return this[INTERNALS$1].status >= 200 && this[INTERNALS$1].status < 300; } get redirected() { return this[INTERNALS$1].counter > 0; } get statusText() { return this[INTERNALS$1].statusText; } get headers() { return this[INTERNALS$1].headers; } clone() { return new Response(clone(this), { url: this.url, status: this.status, statusText: this.statusText, headers: this.headers, ok: this.ok, redirected: this.redirected }); } } Body.mixIn(Response.prototype); Object.defineProperties(Response.prototype, { url: {enumerable: true}, status: {enumerable: true}, ok: {enumerable: true}, redirected: {enumerable: true}, statusText: {enumerable: true}, headers: {enumerable: true}, clone: {enumerable: true} }); Object.defineProperty(Response.prototype, Symbol.toStringTag, { value: "Response", writable: false, enumerable: false, configurable: true }); const INTERNALS$2 = Symbol("Request internals"); const parse_url = url.default.parse; const format_url = url.default.format; const streamDestructionSupported = "destroy" in stream.default.Readable.prototype; function isRequest(input) { return typeof input === "object" && typeof input[INTERNALS$2] === "object"; } function isAbortSignal(signal) { const proto = signal && typeof signal === "object" && Object.getPrototypeOf(signal); return !!(proto && proto.constructor.name === "AbortSignal"); } class Request { constructor(input) { let init2 = arguments.length > 1 && arguments[1] !== void 0 ? arguments[1] : {}; let parsedURL; if (!isRequest(input)) { if (input && input.href) { parsedURL = parse_url(input.href); } else { parsedURL = parse_url(`${input}`); } input = {}; } else { parsedURL = parse_url(input.url); } let method = init2.method || input.method || "GET"; method = method.toUpperCase(); if ((init2.body != null || isRequest(input) && input.body !== null) && (method === "GET" || method === "HEAD")) { throw new TypeError("Request with GET/HEAD method cannot have body"); } let inputBody = init2.body != null ? init2.body : isRequest(input) && input.body !== null ? clone(input) : null; Body.call(this, inputBody, { timeout: init2.timeout || input.timeout || 0, size: init2.size || input.size || 0 }); const headers = new Headers(init2.headers || input.headers || {}); if (inputBody != null && !headers.has("Content-Type")) { const contentType = extractContentType(inputBody); if (contentType) { headers.append("Content-Type", contentType); } } let signal = isRequest(input) ? input.signal : null; if ("signal" in init2) signal = init2.signal; if (signal != null && !isAbortSignal(signal)) { throw new TypeError("Expected signal to be an instanceof AbortSignal"); } this[INTERNALS$2] = { method, redirect: init2.redirect || input.redirect || "follow", headers, parsedURL, signal }; this.follow = init2.follow !== void 0 ? init2.follow : input.follow !== void 0 ? input.follow : 20; this.compress = init2.compress !== void 0 ? init2.compress : input.compress !== void 0 ? input.compress : true; this.counter = init2.counter || input.counter || 0; this.agent = init2.agent || input.agent; } get method() { return this[INTERNALS$2].method; } get url() { return format_url(this[INTERNALS$2].parsedURL); } get headers() { return this[INTERNALS$2].headers; } get redirect() { return this[INTERNALS$2].redirect; } get signal() { return this[INTERNALS$2].signal; } clone() { return new Request(this); } } Body.mixIn(Request.prototype); Object.defineProperty(Request.prototype, Symbol.toStringTag, { value: "Request", writable: false, enumerable: false, configurable: true }); Object.defineProperties(Request.prototype, { method: {enumerable: true}, url: {enumerable: true}, headers: {enumerable: true}, redirect: {enumerable: true}, clone: {enumerable: true}, signal: {enumerable: true} }); function getNodeRequestOptions(request) { const parsedURL = request[INTERNALS$2].parsedURL; const headers = new Headers(request[INTERNALS$2].headers); if (!headers.has("Accept")) { headers.set("Accept", "*/*"); } if (!parsedURL.protocol || !parsedURL.hostname) { throw new TypeError("Only absolute URLs are supported"); } if (!/^https?:$/.test(parsedURL.protocol)) { throw new TypeError("Only HTTP(S) protocols are supported"); } if (request.signal && request.body instanceof stream.default.Readable && !streamDestructionSupported) { throw new Error("Cancellation of streamed requests with AbortSignal is not supported in node < 8"); } let contentLengthValue = null; if (request.body == null && /^(POST|PUT)$/i.test(request.method)) { contentLengthValue = "0"; } if (request.body != null) { const totalBytes = getTotalBytes(request); if (typeof totalBytes === "number") { contentLengthValue = String(totalBytes); } } if (contentLengthValue) { headers.set("Content-Length", contentLengthValue); } if (!headers.has("User-Agent")) { headers.set("User-Agent", "node-fetch/1.0 (+https://github.com/bitinn/node-fetch)"); } if (request.compress && !headers.has("Accept-Encoding")) { headers.set("Accept-Encoding", "gzip,deflate"); } let agent = request.agent; if (typeof agent === "function") { agent = agent(parsedURL); } if (!headers.has("Connection") && !agent) { headers.set("Connection", "close"); } return Object.assign({}, parsedURL, { method: request.method, headers: exportNodeCompatibleHeaders(headers), agent }); } function AbortError(message) { Error.call(this, message); this.type = "aborted"; this.message = message; Error.captureStackTrace(this, this.constructor); } AbortError.prototype = Object.create(Error.prototype); AbortError.prototype.constructor = AbortError; AbortError.prototype.name = "AbortError"; const PassThrough$1 = stream.default.PassThrough; const resolve_url = url.default.resolve; function fetch2(url2, opts) { if (!fetch2.Promise) { throw new Error("native promise missing, set fetch.Promise to your favorite alternative"); } Body.Promise = fetch2.Promise; return new fetch2.Promise(function(resolve, reject) { const request = new Request(url2, opts); const options = getNodeRequestOptions(request); const send = (options.protocol === "https:" ? https2.default : http2.default).request; const signal = request.signal; let response = null; const abort = function abort2() { let error = new AbortError("The user aborted a request."); reject(error); if (request.body && request.body instanceof stream.default.Readable) { request.body.destroy(error); } if (!response || !response.body) return; response.body.emit("error", error); }; if (signal && signal.aborted) { abort(); return; } const abortAndFinalize = function abortAndFinalize2() { abort(); finalize(); }; const req = send(options); let reqTimeout; if (signal) { signal.addEventListener("abort", abortAndFinalize); } function finalize() { req.abort(); if (signal) signal.removeEventListener("abort", abortAndFinalize); clearTimeout(reqTimeout); } if (request.timeout) { req.once("socket", function(socket) { reqTimeout = setTimeout(function() { reject(new FetchError(`network timeout at: ${request.url}`, "request-timeout")); finalize(); }, request.timeout); }); } req.on("error", function(err) { reject(new FetchError(`request to ${request.url} failed, reason: ${err.message}`, "system", err)); finalize(); }); req.on("response", function(res) { clearTimeout(reqTimeout); const headers = createHeadersLenient(res.headers); if (fetch2.isRedirect(res.statusCode)) { const location = headers.get("Location"); const locationURL = location === null ? null : resolve_url(request.url, location); switch (request.redirect) { case "error": reject(new FetchError(`uri requested responds with a redirect, redirect mode is set to error: ${request.url}`, "no-redirect")); finalize(); return; case "manual": if (locationURL !== null) { try { headers.set("Location", locationURL); } catch (err) { reject(err); } } break; case "follow": if (locationURL === null) { break; } if (request.counter >= request.follow) { reject(new FetchError(`maximum redirect reached at: ${request.url}`, "max-redirect")); finalize(); return; } const requestOpts = { headers: new Headers(request.headers), follow: request.follow, counter: request.counter + 1, agent: request.agent, compress: request.compress, method: request.method, body: request.body, signal: request.signal, timeout: request.timeout, size: request.size }; if (res.statusCode !== 303 && request.body && getTotalBytes(request) === null) { reject(new FetchError("Cannot follow redirect with body being a readable stream", "unsupported-redirect")); finalize(); return; } if (res.statusCode === 303 || (res.statusCode === 301 || res.statusCode === 302) && request.method === "POST") { requestOpts.method = "GET"; requestOpts.body = void 0; requestOpts.headers.delete("content-length"); } resolve(fetch2(new Request(locationURL, requestOpts))); finalize(); return; } } res.once("end", function() { if (signal) signal.removeEventListener("abort", abortAndFinalize); }); let body2 = res.pipe(new PassThrough$1()); const response_options = { url: request.url, status: res.statusCode, statusText: res.statusMessage, headers, size: request.size, timeout: request.timeout, counter: request.counter }; const codings = headers.get("Content-Encoding"); if (!request.compress || request.method === "HEAD" || codings === null || res.statusCode === 204 || res.statusCode === 304) { response = new Response(body2, response_options); resolve(response); return; } const zlibOptions = { flush: zlib2.default.Z_SYNC_FLUSH, finishFlush: zlib2.default.Z_SYNC_FLUSH }; if (codings == "gzip" || codings == "x-gzip") { body2 = body2.pipe(zlib2.default.createGunzip(zlibOptions)); response = new Response(body2, response_options); resolve(response); return; } if (codings == "deflate" || codings == "x-deflate") { const raw = res.pipe(new PassThrough$1()); raw.once("data", function(chunk) { if ((chunk[0] & 15) === 8) { body2 = body2.pipe(zlib2.default.createInflate()); } else { body2 = body2.pipe(zlib2.default.createInflateRaw()); } response = new Response(body2, response_options); resolve(response); }); return; } if (codings == "br" && typeof zlib2.default.createBrotliDecompress === "function") { body2 = body2.pipe(zlib2.default.createBrotliDecompress()); response = new Response(body2, response_options); resolve(response); return; } response = new Response(body2, response_options); resolve(response); }); writeToStream(req, request); }); } fetch2.isRedirect = function(code) { return code === 301 || code === 302 || code === 303 || code === 307 || code === 308; }; fetch2.Promise = global.Promise; var lib_default = fetch2; }); // node_modules/@tensorflow/tfjs/dist/tf.es2017.js var require_tf_es2017 = __commonJS((exports2, module2) => { /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ (function(global2, factory) { typeof exports2 === "object" && typeof module2 !== "undefined" ? factory(exports2) : typeof define === "function" && define.amd ? define(["exports"], factory) : (global2 = global2 || self, factory(global2.tf = global2.tf || {})); })(exports2, function(exports3) { "use strict"; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ const EPSILON_FLOAT32 = 1e-7; const EPSILON_FLOAT16 = 1e-4; class DataStorage2 { constructor(backend2, dataMover) { this.backend = backend2; this.dataMover = dataMover; this.data = new WeakMap(); this.dataIdsCount = 0; } get(dataId) { if (!this.data.has(dataId)) { this.dataMover.moveData(this.backend, dataId); } return this.data.get(dataId); } set(dataId, value) { this.dataIdsCount++; this.data.set(dataId, value); } has(dataId) { return this.data.has(dataId); } delete(dataId) { this.dataIdsCount--; return this.data.delete(dataId); } numDataIds() { return this.dataIdsCount; } } class KernelBackend2 { time(f) { return notYetImplemented("time"); } read(dataId) { return notYetImplemented("read"); } readSync(dataId) { return notYetImplemented("readSync"); } numDataIds() { return notYetImplemented("numDataIds"); } disposeData(dataId) { return notYetImplemented("disposeData"); } write(values, shape, dtype) { return notYetImplemented("write"); } move(dataId, values, shape, dtype) { return notYetImplemented("move"); } memory() { return notYetImplemented("memory"); } floatPrecision() { return notYetImplemented("floatPrecision"); } epsilon() { return this.floatPrecision() === 32 ? EPSILON_FLOAT32 : EPSILON_FLOAT16; } batchMatMul(a, b, transposeA, transposeB) { return notYetImplemented("batchMatMul"); } fusedBatchMatMul({a, b, transposeA, transposeB, bias, activation: activation2, preluActivationWeights}) { return notYetImplemented("fusedBatchMatMul"); } slice(x, begin, size) { return notYetImplemented("slice"); } stridedSlice(x, begin, end, strides) { return notYetImplemented("stridedSlice"); } unstack(x, axis) { return notYetImplemented("unstack"); } reverse(a, axis) { return notYetImplemented("reverse"); } concat(tensors, axis) { return notYetImplemented("concat"); } neg(a) { return notYetImplemented("neg"); } add(a, b) { return notYetImplemented("add"); } addN(tensors) { return notYetImplemented("addN"); } subtract(a, b) { return notYetImplemented("subtract"); } multiply(a, b) { return notYetImplemented("multiply"); } realDivide(a, b) { return notYetImplemented("realDivide"); } floorDiv(a, b) { return notYetImplemented("floorDiv"); } sum(x, axes) { return notYetImplemented("sum"); } prod(x, axes) { return notYetImplemented("prod"); } unsortedSegmentSum(x, segmentIds, numSegments) { return notYetImplemented("unsortedSegmentSum"); } argMin(x, axis) { return notYetImplemented("argMin"); } argMax(x, axis) { return notYetImplemented("argMax"); } equal(a, b) { return notYetImplemented("equal"); } notEqual(a, b) { return notYetImplemented("notEqual"); } less(a, b) { return notYetImplemented("less"); } lessEqual(a, b) { return notYetImplemented("lessEqual"); } greater(a, b) { return notYetImplemented("greater"); } greaterEqual(a, b) { return notYetImplemented("greaterEqual"); } logicalNot(a) { return notYetImplemented("logicalNot"); } logicalAnd(a, b) { return notYetImplemented("logicalAnd"); } logicalOr(a, b) { return notYetImplemented("logicalOr"); } where(condition) { return notYetImplemented("where"); } select(condition, a, b) { return notYetImplemented("select"); } topk(x, k, sorted) { return notYetImplemented("topk"); } min(x, axes) { return notYetImplemented("min"); } minimum(a, b) { return notYetImplemented("minimum"); } mod(a, b) { return notYetImplemented("mod"); } max(x, axes) { return notYetImplemented("max"); } maximum(a, b) { return notYetImplemented("maximum"); } all(x, axes) { return notYetImplemented("all"); } any(x, axes) { return notYetImplemented("any"); } squaredDifference(a, b) { return notYetImplemented("squaredDifference"); } ceil(x) { return notYetImplemented("ceil"); } floor(x) { return notYetImplemented("floor"); } round(x) { return notYetImplemented("round"); } sign(x) { return notYetImplemented("sign"); } isNaN(x) { return notYetImplemented("isNaN"); } isInf(x) { return notYetImplemented("isInf"); } isFinite(x) { return notYetImplemented("isFinite"); } pow(a, b) { return notYetImplemented("pow"); } exp(x) { return notYetImplemented("exp"); } expm1(x) { return notYetImplemented("expm1"); } softmax(x, dim) { return notYetImplemented("softmax"); } log(x) { return notYetImplemented("log"); } log1p(x) { return notYetImplemented("log1p"); } sqrt(x) { return notYetImplemented("sqrt"); } rsqrt(x) { return notYetImplemented("rsqrt"); } square(x) { return notYetImplemented("square"); } reciprocal(x) { return notYetImplemented("reciprocal"); } relu(x) { return notYetImplemented("relu"); } relu6(x) { return notYetImplemented("relu6"); } prelu(x, a) { return notYetImplemented("prelu"); } elu(x) { return notYetImplemented("elu"); } eluDer(dy, y) { return notYetImplemented("eluDer"); } selu(x) { return notYetImplemented("selu"); } int(x) { return notYetImplemented("int"); } clip(x, min3, max3) { return notYetImplemented("clip"); } abs(x) { return notYetImplemented("abs"); } complexAbs(x) { return notYetImplemented("complexAbs"); } sigmoid(x) { return notYetImplemented("sigmoid"); } softplus(x) { return notYetImplemented("softplus"); } sin(x) { return notYetImplemented("sin"); } cos(x) { return notYetImplemented("cos"); } tan(x) { return notYetImplemented("tan"); } asin(x) { return notYetImplemented("asin"); } acos(x) { return notYetImplemented("acos"); } atan(x) { return notYetImplemented("atan"); } atan2(a, b) { return notYetImplemented("atan2"); } sinh(x) { return notYetImplemented("sinh"); } cosh(x) { return notYetImplemented("cosh"); } tanh(x) { return notYetImplemented("tanh"); } asinh(x) { return notYetImplemented("asinh"); } acosh(x) { return notYetImplemented("acosh"); } atanh(x) { return notYetImplemented("atanh"); } erf(x) { return notYetImplemented("erf"); } step(x, alpha) { return notYetImplemented("step"); } fusedConv2d({input: input2, filter, convInfo, bias, activation: activation2, preluActivationWeights}) { return notYetImplemented("fusedConv2d"); } conv2d(x, filter, convInfo) { return notYetImplemented("conv2d"); } conv2dDerInput(dy, filter, convInfo) { return notYetImplemented("conv2dDerInput"); } conv2dDerFilter(x, dY, convInfo) { return notYetImplemented("conv2dDerFilter"); } fusedDepthwiseConv2D({input: input2, filter, convInfo, bias, activation: activation2, preluActivationWeights}) { return notYetImplemented("fusedDepthwiseConv2D"); } depthwiseConv2D(input2, filter, convInfo) { return notYetImplemented("depthwiseConv2D"); } depthwiseConv2DDerInput(dy, filter, convInfo) { return notYetImplemented("depthwiseConv2DDerInput"); } depthwiseConv2DDerFilter(x, dY, convInfo) { return notYetImplemented("depthwiseConv2DDerFilter"); } conv3d(x, filter, convInfo) { return notYetImplemented("conv3d"); } conv3dDerInput(dy, filter, convInfo) { return notYetImplemented("conv3dDerInput"); } conv3dDerFilter(x, dY, convInfo) { return notYetImplemented("conv3dDerFilter"); } maxPool(x, convInfo) { return notYetImplemented("maxPool"); } maxPoolBackprop(dy, x, y, convInfo) { return notYetImplemented("maxPoolBackprop"); } avgPool(x, convInfo) { return notYetImplemented("avgPool"); } avgPoolBackprop(dy, x, convInfo) { return notYetImplemented("avgPoolBackprop"); } avgPool3d(x, convInfo) { return notYetImplemented("avgPool3d"); } avgPool3dBackprop(dy, x, convInfo) { return notYetImplemented("avgPool3dBackprop"); } maxPool3d(x, convInfo) { return notYetImplemented("maxPool3d"); } maxPool3dBackprop(dy, x, y, convInfo) { return notYetImplemented("maxPool3dBackprop"); } reshape(x, shape) { return notYetImplemented("reshape"); } cast(x, dtype) { return notYetImplemented("cast"); } tile(x, reps) { return notYetImplemented("tile"); } pad(x, paddings, constantValue) { return notYetImplemented("pad"); } transpose(x, perm) { return notYetImplemented("transpose"); } gather(x, indices, axis) { return notYetImplemented("gather"); } gatherND(x, indices) { return notYetImplemented("gatherND"); } scatterND(indices, updates, shape) { return notYetImplemented("scatterND"); } batchToSpaceND(x, blockShape, crops) { return notYetImplemented("batchToSpaceND"); } spaceToBatchND(x, blockShape, paddings) { return notYetImplemented("spaceToBatchND"); } resizeBilinear(x, newHeight, newWidth, alignCorners) { return notYetImplemented("resizeBilinear"); } resizeBilinearBackprop(dy, x, alignCorners) { return notYetImplemented("resizeBilinearBackprop"); } resizeNearestNeighbor(x, newHEight, newWidth, alignCorners) { return notYetImplemented("resizeNearestNeighbor"); } resizeNearestNeighborBackprop(dy, x, alignCorners) { return notYetImplemented("resizeNearestNeighborBackprop"); } batchNorm(x, mean2, variance2, offset, scale2, varianceEpsilon) { return notYetImplemented("batchNorm"); } localResponseNormalization4D(x, radius, bias, alpha, beta) { return notYetImplemented("localResponseNormalization4D"); } LRNGrad(dy, inputImage, outputImage, radius, bias, alpha, beta) { return notYetImplemented("LRNGrad"); } multinomial(logits, normalized, numSamples, seed) { return notYetImplemented("multinomial"); } oneHot(indices, depth, onValue, offValue) { return notYetImplemented("oneHot"); } cumsum(x, axis, exclusive, reverse3) { return notYetImplemented("cumsum"); } nonMaxSuppression(boxes, scores, maxOutputSize, iouThreshold, scoreThreshold) { return notYetImplemented("nonMaxSuppression"); } fft(x) { return notYetImplemented("fft"); } ifft(x) { return notYetImplemented("ifft"); } complex(real2, imag2) { return notYetImplemented("complex"); } real(input2) { return notYetImplemented("real"); } imag(input2) { return notYetImplemented("imag"); } cropAndResize(image3, boxes, boxIndex, cropSize, method, extrapolationValue) { return notYetImplemented("cropAndResize"); } depthToSpace(x, blockSize, dataFormat) { return notYetImplemented("depthToSpace"); } split(value, sizeSplits, axis) { return notYetImplemented("split"); } sparseToDense(sparseIndices, sparseValues, outputShape, defaultValue) { return notYetImplemented("sparseToDense"); } diag(x) { return notYetImplemented("diag"); } fill(shape, value, dtype) { return notYetImplemented("fill"); } onesLike(x) { return notYetImplemented("onesLike"); } zerosLike(x) { return notYetImplemented("zerosLike"); } linspace(start, stop, num) { return notYetImplemented("linspace"); } dispose() { return notYetImplemented("dispose"); } } function notYetImplemented(kernelName) { throw new Error(`'${kernelName}' not yet implemented or not found in the registry. This kernel may not be supported by the tfjs backend you have chosen`); } /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ function shuffle(array2) { let counter = array2.length; let temp = 0; let index2 = 0; while (counter > 0) { index2 = Math.random() * counter | 0; counter--; temp = array2[counter]; array2[counter] = array2[index2]; array2[index2] = temp; } } function clamp(min3, x, max3) { return Math.max(min3, Math.min(x, max3)); } function nearestLargerEven(val) { return val % 2 === 0 ? val : val + 1; } function sum2(arr) { let sum3 = 0; for (let i = 0; i < arr.length; i++) { sum3 += arr[i]; } return sum3; } function randUniform(a, b) { const r = Math.random(); return b * r + (1 - r) * a; } function distSquared(a, b) { let result = 0; for (let i = 0; i < a.length; i++) { const diff = Number(a[i]) - Number(b[i]); result += diff * diff; } return result; } function assert(expr, msg) { if (!expr) { throw new Error(typeof msg === "string" ? msg : msg()); } } function assertShapesMatch(shapeA, shapeB, errorMessagePrefix = "") { assert(arraysEqual(shapeA, shapeB), () => errorMessagePrefix + ` Shapes ${shapeA} and ${shapeB} must match`); } function assertNonNull(a) { assert(a != null, () => `The input to the tensor constructor must be a non-null value.`); } function flatten(arr, result = [], skipTypedArray = false) { if (result == null) { result = []; } if (Array.isArray(arr) || isTypedArray(arr) && !skipTypedArray) { for (let i = 0; i < arr.length; ++i) { flatten(arr[i], result, skipTypedArray); } } else { result.push(arr); } return result; } function sizeFromShape(shape) { if (shape.length === 0) { return 1; } let size = shape[0]; for (let i = 1; i < shape.length; i++) { size *= shape[i]; } return size; } function isScalarShape(shape) { return shape.length === 0; } function arraysEqual(n1, n2) { if (n1 === n2) { return true; } if (n1 == null || n2 == null) { return false; } if (n1.length !== n2.length) { return false; } for (let i = 0; i < n1.length; i++) { if (n1[i] !== n2[i]) { return false; } } return true; } function isInt(a) { return a % 1 === 0; } function tanh(x) { if (Math.tanh != null) { return Math.tanh(x); } if (x === Infinity) { return 1; } else if (x === -Infinity) { return -1; } else { const e2x = Math.exp(2 * x); return (e2x - 1) / (e2x + 1); } } function sizeToSquarishShape(size) { const width = Math.ceil(Math.sqrt(size)); return [width, Math.ceil(size / width)]; } function createShuffledIndices(n) { const shuffledIndices = new Uint32Array(n); for (let i = 0; i < n; ++i) { shuffledIndices[i] = i; } shuffle(shuffledIndices); return shuffledIndices; } function rightPad(a, size) { if (size <= a.length) { return a; } return a + " ".repeat(size - a.length); } function repeatedTry(checkFn, delayFn = (counter) => 0, maxCounter) { return new Promise((resolve, reject) => { let tryCount = 0; const tryFn = () => { if (checkFn()) { resolve(); return; } tryCount++; const nextBackoff = delayFn(tryCount); if (maxCounter != null && tryCount >= maxCounter) { reject(); return; } setTimeout(tryFn, nextBackoff); }; tryFn(); }); } function inferFromImplicitShape(shape, size) { let shapeProd = 1; let implicitIdx = -1; for (let i = 0; i < shape.length; ++i) { if (shape[i] >= 0) { shapeProd *= shape[i]; } else if (shape[i] === -1) { if (implicitIdx !== -1) { throw Error(`Shapes can only have 1 implicit size. Found -1 at dim ${implicitIdx} and dim ${i}`); } implicitIdx = i; } else if (shape[i] < 0) { throw Error(`Shapes can not be < 0. Found ${shape[i]} at dim ${i}`); } } if (implicitIdx === -1) { if (size > 0 && size !== shapeProd) { throw Error(`Size(${size}) must match the product of shape ${shape}`); } return shape; } if (shapeProd === 0) { throw Error(`Cannot infer the missing size in [${shape}] when there are 0 elements`); } if (size % shapeProd !== 0) { throw Error(`The implicit shape can't be a fractional number. Got ${size} / ${shapeProd}`); } const newShape = shape.slice(); newShape[implicitIdx] = size / shapeProd; return newShape; } function parseAxisParam(axis, shape) { const rank = shape.length; axis = axis == null ? shape.map((s, i) => i) : [].concat(axis); assert(axis.every((ax) => ax >= -rank && ax < rank), () => `All values in axis param must be in range [-${rank}, ${rank}) but got axis ${axis}`); assert(axis.every((ax) => isInt(ax)), () => `All values in axis param must be integers but got axis ${axis}`); return axis.map((a) => a < 0 ? rank + a : a); } function squeezeShape(shape, axis) { const newShape = []; const keptDims = []; const isEmptyArray = axis != null && Array.isArray(axis) && axis.length === 0; const axes = axis == null || isEmptyArray ? null : parseAxisParam(axis, shape).sort(); let j = 0; for (let i = 0; i < shape.length; ++i) { if (axes != null) { if (axes[j] === i && shape[i] !== 1) { throw new Error(`Can't squeeze axis ${i} since its dim '${shape[i]}' is not 1`); } if ((axes[j] == null || axes[j] > i) && shape[i] === 1) { newShape.push(shape[i]); keptDims.push(i); } if (axes[j] <= i) { j++; } } if (shape[i] !== 1) { newShape.push(shape[i]); keptDims.push(i); } } return {newShape, keptDims}; } function getTypedArrayFromDType(dtype, size) { let values = null; if (dtype == null || dtype === "float32") { values = new Float32Array(size); } else if (dtype === "int32") { values = new Int32Array(size); } else if (dtype === "bool") { values = new Uint8Array(size); } else { throw new Error(`Unknown data type ${dtype}`); } return values; } function getArrayFromDType(dtype, size) { let values = null; if (dtype == null || dtype === "float32") { values = new Float32Array(size); } else if (dtype === "int32") { values = new Int32Array(size); } else if (dtype === "bool") { values = new Uint8Array(size); } else if (dtype === "string") { values = new Array(size); } else { throw new Error(`Unknown data type ${dtype}`); } return values; } function checkConversionForErrors(vals, dtype) { for (let i = 0; i < vals.length; i++) { const num = vals[i]; if (isNaN(num) || !isFinite(num)) { throw Error(`A tensor of type ${dtype} being uploaded contains ${num}.`); } } } function isValidDtype(dtype) { return dtype === "bool" || dtype === "complex64" || dtype === "float32" || dtype === "int32" || dtype === "string"; } function hasEncodingLoss(oldType, newType) { if (newType === "complex64") { return false; } if (newType === "float32" && oldType !== "complex64") { return false; } if (newType === "int32" && oldType !== "float32" && oldType !== "complex64") { return false; } if (newType === "bool" && oldType === "bool") { return false; } return true; } function isTypedArray(a) { return a instanceof Float32Array || a instanceof Int32Array || a instanceof Uint8Array; } function bytesPerElement(dtype) { if (dtype === "float32" || dtype === "int32") { return 4; } else if (dtype === "complex64") { return 8; } else if (dtype === "bool") { return 1; } else { throw new Error(`Unknown dtype ${dtype}`); } } function bytesFromStringArray(arr) { if (arr == null) { return 0; } let bytes = 0; arr.forEach((x) => bytes += x.length); return bytes; } function isString(value) { return typeof value === "string" || value instanceof String; } function isBoolean(value) { return typeof value === "boolean"; } function isNumber(value) { return typeof value === "number"; } function inferDtype(values) { if (Array.isArray(values)) { return inferDtype(values[0]); } if (values instanceof Float32Array) { return "float32"; } else if (values instanceof Int32Array || values instanceof Uint8Array) { return "int32"; } else if (isNumber(values)) { return "float32"; } else if (isString(values)) { return "string"; } else if (isBoolean(values)) { return "bool"; } return "float32"; } function isFunction(f) { return !!(f && f.constructor && f.call && f.apply); } function nearestDivisor(size, start) { for (let i = start; i < size; ++i) { if (size % i === 0) { return i; } } return size; } function computeStrides(shape) { const rank = shape.length; if (rank < 2) { return []; } const strides = new Array(rank - 1); strides[rank - 2] = shape[rank - 1]; for (let i = rank - 3; i >= 0; --i) { strides[i] = strides[i + 1] * shape[i + 1]; } return strides; } function createNestedArray(offset, shape, a) { const ret = new Array(); if (shape.length === 1) { const d = shape[0]; for (let i = 0; i < d; i++) { ret[i] = a[offset + i]; } } else { const d = shape[0]; const rest = shape.slice(1); const len = rest.reduce((acc, c) => acc * c); for (let i = 0; i < d; i++) { ret[i] = createNestedArray(offset + i * len, rest, a); } } return ret; } function toNestedArray(shape, a) { if (shape.length === 0) { return a[0]; } const size = shape.reduce((acc, c) => acc * c); if (size === 0) { return []; } if (size !== a.length) { throw new Error(`[${shape}] does not match the input size ${a.length}.`); } return createNestedArray(0, shape, a); } function makeOnesTypedArray(size, dtype) { const array2 = makeZerosTypedArray(size, dtype); for (let i = 0; i < array2.length; i++) { array2[i] = 1; } return array2; } function makeZerosTypedArray(size, dtype) { if (dtype == null || dtype === "float32" || dtype === "complex64") { return new Float32Array(size); } else if (dtype === "int32") { return new Int32Array(size); } else if (dtype === "bool") { return new Uint8Array(size); } else { throw new Error(`Unknown data type ${dtype}`); } } function makeZerosNestedTypedArray(shape, dtype) { const size = shape.reduce((prev, curr) => prev * curr, 1); if (dtype == null || dtype === "float32") { return toNestedArray(shape, new Float32Array(size)); } else if (dtype === "int32") { return toNestedArray(shape, new Int32Array(size)); } else if (dtype === "bool") { return toNestedArray(shape, new Uint8Array(size)); } else { throw new Error(`Unknown data type ${dtype}`); } } function assertNonNegativeIntegerDimensions(shape) { shape.forEach((dimSize) => { assert(Number.isInteger(dimSize) && dimSize >= 0, () => `Tensor must have a shape comprised of positive integers but got shape [${shape}].`); }); } function locToIndex(locs, rank, strides) { if (rank === 0) { return 0; } else if (rank === 1) { return locs[0]; } let index2 = locs[locs.length - 1]; for (let i = 0; i < locs.length - 1; ++i) { index2 += strides[i] * locs[i]; } return index2; } function indexToLoc(index2, rank, strides) { if (rank === 0) { return []; } else if (rank === 1) { return [index2]; } const locs = new Array(rank); for (let i = 0; i < locs.length - 1; ++i) { locs[i] = Math.floor(index2 / strides[i]); index2 -= locs[i] * strides[i]; } locs[locs.length - 1] = index2; return locs; } function isPromise(object) { return object && object.then && typeof object.then === "function"; } /** * @license * Copyright 2017 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ const TENSORFLOWJS_FLAGS_PREFIX = "tfjsflags"; class Environment { constructor(global2) { this.global = global2; this.flags = {}; this.flagRegistry = {}; this.urlFlags = {}; this.populateURLFlags(); } setPlatform(platformName, platform) { if (this.platform != null) { console.warn(`Platform ${this.platformName} has already been set. Overwriting the platform with ${platform}.`); } this.platformName = platformName; this.platform = platform; } registerFlag(flagName, evaluationFn, setHook) { this.flagRegistry[flagName] = {evaluationFn, setHook}; if (this.urlFlags[flagName] != null) { const flagValue = this.urlFlags[flagName]; console.warn(`Setting feature override from URL ${flagName}: ${flagValue}.`); this.set(flagName, flagValue); } } async getAsync(flagName) { if (flagName in this.flags) { return this.flags[flagName]; } this.flags[flagName] = await this.evaluateFlag(flagName); return this.flags[flagName]; } get(flagName) { if (flagName in this.flags) { return this.flags[flagName]; } const flagValue = this.evaluateFlag(flagName); if (isPromise(flagValue)) { throw new Error(`Flag ${flagName} cannot be synchronously evaluated. Please use getAsync() instead.`); } this.flags[flagName] = flagValue; return this.flags[flagName]; } getNumber(flagName) { return this.get(flagName); } getBool(flagName) { return this.get(flagName); } getFlags() { return this.flags; } get features() { return this.flags; } set(flagName, value) { if (this.flagRegistry[flagName] == null) { throw new Error(`Cannot set flag ${flagName} as it has not been registered.`); } this.flags[flagName] = value; if (this.flagRegistry[flagName].setHook != null) { this.flagRegistry[flagName].setHook(value); } } evaluateFlag(flagName) { if (this.flagRegistry[flagName] == null) { throw new Error(`Cannot evaluate flag '${flagName}': no evaluation function found.`); } return this.flagRegistry[flagName].evaluationFn(); } setFlags(flags) { this.flags = Object.assign({}, flags); } reset() { this.flags = {}; this.urlFlags = {}; this.populateURLFlags(); } populateURLFlags() { if (typeof this.global === "undefined" || typeof this.global.location === "undefined" || typeof this.global.location.search === "undefined") { return; } const urlParams = getQueryParams(this.global.location.search); if (TENSORFLOWJS_FLAGS_PREFIX in urlParams) { const keyValues = urlParams[TENSORFLOWJS_FLAGS_PREFIX].split(","); keyValues.forEach((keyValue) => { const [key, value] = keyValue.split(":"); this.urlFlags[key] = parseValue(key, value); }); } } } function getQueryParams(queryString) { const params = {}; queryString.replace(/[?&]([^=?&]+)(?:=([^&]*))?/g, (s, ...t) => { decodeParam(params, t[0], t[1]); return t.join("="); }); return params; } function decodeParam(params, name, value) { params[decodeURIComponent(name)] = decodeURIComponent(value || ""); } function parseValue(flagName, value) { value = value.toLowerCase(); if (value === "true" || value === "false") { return value === "true"; } else if (`${+value}` === value) { return +value; } throw new Error(`Could not parse value flag value ${value} for flag ${flagName}.`); } function env3() { return exports3.ENV; } exports3.ENV = null; function setEnvironmentGlobal(environment) { exports3.ENV = environment; } /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ let globalNameSpace; function getGlobalNamespace() { if (globalNameSpace == null) { let ns; if (typeof window !== "undefined") { ns = window; } else if (typeof global !== "undefined") { ns = global; } else if (typeof process !== "undefined") { ns = process; } else if (typeof self !== "undefined") { ns = self; } else { throw new Error("Could not find a global object"); } globalNameSpace = ns; } return globalNameSpace; } function getGlobalMap() { const ns = getGlobalNamespace(); if (ns._tfGlobals == null) { ns._tfGlobals = new Map(); } return ns._tfGlobals; } function getGlobal(key, init2) { const globalMap = getGlobalMap(); if (globalMap.has(key)) { return globalMap.get(key); } else { const singleton = init2(); globalMap.set(key, singleton); return globalMap.get(key); } } const Abs3 = "Abs"; const Acos = "Acos"; const Acosh = "Acosh"; const Add3 = "Add"; const AddN3 = "AddN"; const All = "All"; const Any = "Any"; const ArgMax3 = "ArgMax"; const ArgMin = "ArgMin"; const Asin = "Asin"; const Asinh = "Asinh"; const Atan = "Atan"; const Atanh = "Atanh"; const Atan2 = "Atan2"; const AvgPool3 = "AvgPool"; const AvgPoolBackprop = "AvgPoolBackprop"; const AvgPool3D = "AvgPool3D"; const AvgPool3DBackprop = "AvgPool3DBackprop"; const BatchMatMul3 = "BatchMatMul"; const BatchToSpaceND = "BatchToSpaceND"; const BroadcastTo = "BroadcastTo"; const Cast5 = "Cast"; const Ceil = "Ceil"; const ClipByValue3 = "ClipByValue"; const Complex = "Complex"; const Concat3 = "Concat"; const Conv2D3 = "Conv2D"; const Conv2DBackpropFilter = "Conv2DBackpropFilter"; const Conv2DBackpropInput3 = "Conv2DBackpropInput"; const Conv3D = "Conv3D"; const Conv3DBackpropFilterV2 = "Conv3DBackpropFilterV2"; const Conv3DBackpropInputV2 = "Conv3DBackpropInputV2"; const Cos3 = "Cos"; const Cosh = "Cosh"; const Cumsum3 = "Cumsum"; const CropAndResize3 = "CropAndResize"; const DepthToSpace3 = "DepthToSpace"; const DepthwiseConv2dNative3 = "DepthwiseConv2dNative"; const DepthwiseConv2dNativeBackpropFilter = "DepthwiseConv2dNativeBackpropFilter"; const DepthwiseConv2dNativeBackpropInput = "DepthwiseConv2dNativeBackpropInput"; const Diag = "Diag"; const Dilation2D = "Dilation2D"; const Dilation2DBackpropInput = "Dilation2DBackpropInput"; const Dilation2DBackpropFilter = "Dilation2DBackpropFilter"; const Div3 = "Div"; const Elu = "Elu"; const EluGrad = "EluGrad"; const Erf = "Erf"; const Equal3 = "Equal"; const Exp3 = "Exp"; const Expm1 = "Expm1"; const FFT = "FFT"; const Fill3 = "Fill"; const FlipLeftRight3 = "FlipLeftRight"; const Floor = "Floor"; const FloorDiv3 = "FloorDiv"; const FusedBatchNorm3 = "FusedBatchNorm"; const GatherV23 = "GatherV2"; const GatherNd3 = "GatherNd"; const Greater3 = "Greater"; const GreaterEqual3 = "GreaterEqual"; const Identity5 = "Identity"; const IFFT = "IFFT"; const Imag = "Imag"; const IsFinite = "IsFinite"; const IsInf = "IsInf"; const IsNan = "IsNan"; const Less3 = "Less"; const LessEqual3 = "LessEqual"; const LinSpace = "LinSpace"; const Log3 = "Log"; const Log1p = "Log1p"; const LogicalAnd3 = "LogicalAnd"; const LogicalNot = "LogicalNot"; const LogicalOr = "LogicalOr"; const LogSoftmax = "LogSoftmax"; const LRN = "LRN"; const LRNBackprop = "LRNBackprop"; const Max3 = "Max"; const Maximum3 = "Maximum"; const MaxPool3 = "MaxPool"; const MaxPoolBackprop = "MaxPoolBackprop"; const MaxPool3D = "MaxPool3D"; const MaxPool3DBackprop = "MaxPool3DBackprop"; const MaxPoolWithArgmax = "MaxPoolWithArgmax"; const Mean = "Mean"; const Min3 = "Min"; const Minimum3 = "Minimum"; const MirrorPad = "MirrorPad"; const Mod = "Mod"; const Multiply3 = "Multiply"; const Negate3 = "Negate"; const NotEqual3 = "NotEqual"; const NonMaxSuppressionV33 = "NonMaxSuppressionV3"; const NonMaxSuppressionV43 = "NonMaxSuppressionV4"; const NonMaxSuppressionV53 = "NonMaxSuppressionV5"; const OnesLike3 = "OnesLike"; const OneHot3 = "OneHot"; const PadV23 = "PadV2"; const Pool = "Pool"; const Pow3 = "Pow"; const Prelu3 = "Prelu"; const Prod = "Prod"; const Range = "Range"; const Real = "Real"; const Reciprocal = "Reciprocal"; const Relu3 = "Relu"; const Reshape6 = "Reshape"; const ResizeNearestNeighbor = "ResizeNearestNeighbor"; const ResizeNearestNeighborGrad = "ResizeNearestNeighborGrad"; const ResizeBilinear3 = "ResizeBilinear"; const ResizeBilinearGrad = "ResizeBilinearGrad"; const Relu63 = "Relu6"; const Reverse3 = "Reverse"; const Round = "Round"; const Rsqrt3 = "Rsqrt"; const ScatterNd3 = "ScatterNd"; const SelectV23 = "SelectV2"; const Selu = "Selu"; const Slice6 = "Slice"; const Sin3 = "Sin"; const Sinh = "Sinh"; const Sign = "Sign"; const Sigmoid3 = "Sigmoid"; const Softplus = "Softplus"; const Sqrt3 = "Sqrt"; const Sum3 = "Sum"; const SpaceToBatchND = "SpaceToBatchND"; const SplitV2 = "SplitV"; const Softmax3 = "Softmax"; const SquaredDifference3 = "SquaredDifference"; const Square3 = "Square"; const Sub3 = "Sub"; const SparseToDense = "SparseToDense"; const StridedSlice3 = "StridedSlice"; const Tan = "Tan"; const Tanh3 = "Tanh"; const Tile3 = "Tile"; const TopK = "TopK"; const Transpose5 = "Transpose"; const Unique = "Unique"; const Unpack3 = "Unpack"; const UnsortedSegmentSum = "UnsortedSegmentSum"; const ZerosLike3 = "ZerosLike"; const Step = "Step"; const FromPixels = "FromPixels"; const RotateWithOffset3 = "RotateWithOffset"; const _FusedMatMul2 = "_FusedMatMul"; const FusedConv2D3 = "FusedConv2D"; const FusedDepthwiseConv2D3 = "FusedDepthwiseConv2D"; /** * @license * Copyright 2019 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ const kernelRegistry = getGlobal("kernelRegistry", () => new Map()); const gradRegistry = getGlobal("gradRegistry", () => new Map()); function getKernel(kernelName, backendName) { const key = makeKey(kernelName, backendName); return kernelRegistry.get(key); } function getGradient(kernelName) { return gradRegistry.get(kernelName); } function getKernelsForBackend(backendName) { const it = kernelRegistry.entries(); const result = []; while (true) { const {done, value} = it.next(); if (done) { break; } const [key, config2] = value; const [backend2] = key.split("_"); if (backend2 === backendName) { result.push(config2); } } return result; } function registerKernel2(config2) { const {kernelName, backendName} = config2; const key = makeKey(kernelName, backendName); if (kernelRegistry.has(key)) { console.warn(`The kernel '${kernelName}' for backend '${backendName}' is already registered`); } kernelRegistry.set(key, config2); } function registerGradient(config2) { const {kernelName} = config2; if (gradRegistry.has(kernelName)) { if (env3().getBool("DEBUG")) { console.warn(`Overriding the gradient for '${kernelName}'`); } } gradRegistry.set(kernelName, config2); } function unregisterKernel(kernelName, backendName) { const key = makeKey(kernelName, backendName); if (!kernelRegistry.has(key)) { throw new Error(`The kernel '${kernelName}' for backend '${backendName}' is not registered`); } kernelRegistry.delete(key); } function unregisterGradient(kernelName) { if (!gradRegistry.has(kernelName)) { throw new Error(`The gradient '${kernelName}' for backend is not registered`); } gradRegistry.delete(kernelName); } function copyRegisteredKernels(registeredBackendName, newBackendName) { const kernels = getKernelsForBackend(registeredBackendName); kernels.forEach((kernelConfig) => { const newKernelConfig = Object.assign({}, kernelConfig, {backendName: newBackendName}); registerKernel2(newKernelConfig); }); } function makeKey(kernelName, backendName) { return `${backendName}_${kernelName}`; } /** * @license * Copyright 2017 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ function createScalarValue(value, dtype) { if (dtype === "string") { return encodeString(value); } return toTypedArray([value], dtype); } function noConversionNeeded(a, dtype) { return a instanceof Float32Array && dtype === "float32" || a instanceof Int32Array && dtype === "int32" || a instanceof Uint8Array && dtype === "bool"; } function toTypedArray(a, dtype) { if (dtype === "string") { throw new Error("Cannot convert a string[] to a TypedArray"); } if (Array.isArray(a)) { a = flatten(a); } if (env3().getBool("DEBUG")) { checkConversionForErrors(a, dtype); } if (noConversionNeeded(a, dtype)) { return a; } if (dtype == null || dtype === "float32" || dtype === "complex64") { return new Float32Array(a); } else if (dtype === "int32") { return new Int32Array(a); } else if (dtype === "bool") { const bool = new Uint8Array(a.length); for (let i = 0; i < bool.length; ++i) { if (Math.round(a[i]) !== 0) { bool[i] = 1; } } return bool; } else { throw new Error(`Unknown data type ${dtype}`); } } function now2() { return env3().platform.now(); } function fetch$1(path, requestInits) { return env3().platform.fetch(path, requestInits); } function encodeString(s, encoding = "utf-8") { encoding = encoding || "utf-8"; return env3().platform.encode(s, encoding); } function decodeString(bytes, encoding = "utf-8") { encoding = encoding || "utf-8"; return env3().platform.decode(bytes, encoding); } var util27 = /* @__PURE__ */ Object.freeze({ __proto__: null, createScalarValue, toTypedArray, now: now2, fetch: fetch$1, encodeString, decodeString, shuffle, clamp, nearestLargerEven, sum: sum2, randUniform, distSquared, assert, assertShapesMatch, assertNonNull, flatten, sizeFromShape, isScalarShape, arraysEqual, isInt, tanh, sizeToSquarishShape, createShuffledIndices, rightPad, repeatedTry, inferFromImplicitShape, parseAxisParam, squeezeShape, getTypedArrayFromDType, getArrayFromDType, checkConversionForErrors, isValidDtype, hasEncodingLoss, isTypedArray, bytesPerElement, bytesFromStringArray, isString, isBoolean, isNumber, inferDtype, isFunction, nearestDivisor, computeStrides, toNestedArray, makeOnesTypedArray, makeZerosTypedArray, makeZerosNestedTypedArray, assertNonNegativeIntegerDimensions, locToIndex, indexToLoc, isPromise }); /** * @license * Copyright 2018 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ class Profiler { constructor(backendTimer, logger) { this.backendTimer = backendTimer; this.logger = logger; if (logger == null) { this.logger = new Logger(); } } profileKernel(kernelName, inputs, f) { let outputs; const holdResultWrapperFn = () => { outputs = f(); }; const timer = this.backendTimer.time(holdResultWrapperFn); for (let i = 0; i < outputs.length; i++) { const output = outputs[i]; output.data().then((tensorVals) => { checkComputationForErrors(tensorVals, output.dtype, kernelName); }); } const kernelProfile = { kernelName, outputs, inputs, timeMs: timer.then((timing) => timing.kernelMs), extraInfo: timer.then((timing) => timing.getExtraProfileInfo != null ? timing.getExtraProfileInfo() : "") }; return kernelProfile; } logKernelProfile(kernelProfile) { const {kernelName, outputs, timeMs, inputs, extraInfo} = kernelProfile; outputs.forEach((result) => { Promise.all([result.data(), timeMs, extraInfo]).then((valueContainer) => { this.logger.logKernelProfile(kernelName, result, valueContainer[0], valueContainer[1], inputs, valueContainer[2]); }); }); } } function checkComputationForErrors(vals, dtype, kernelName) { if (dtype !== "float32") { return false; } for (let i = 0; i < vals.length; i++) { const num = vals[i]; if (isNaN(num) || !isFinite(num)) { console.warn(`Found ${num} in the result of '${kernelName}'`); return true; } } return false; } class Logger { logKernelProfile(name, result, vals, timeMs, inputs, extraInfo) { const time2 = typeof timeMs === "number" ? rightPad(`${timeMs}ms`, 9) : timeMs["error"]; const paddedName = rightPad(name, 25); const rank = result.rank; const size = result.size; const shape = rightPad(result.shape.toString(), 14); let inputShapesDescription = ""; for (const name2 in inputs) { const input2 = inputs[name2]; if (input2 != null) { const inputShape = input2.shape || result.shape; const inputRank = inputShape.length; inputShapesDescription += `${name2}: ${inputRank}D ${inputRank > 0 ? inputShape : ""} `; } } console.log(`%c${paddedName} %c${time2} %c${rank}D ${shape} %c${size} %c${inputShapesDescription} %c${extraInfo}`, "font-weight:bold", "color:red", "color:blue", "color: orange", "color: green", "color: steelblue"); } } /** * @license * Copyright 2017 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ function getFilteredNodesXToY(tape, xs, y) { const tensorsFromX = {}; const nodesFromX = {}; for (let i = 0; i < xs.length; i++) { tensorsFromX[xs[i].id] = true; } for (let i = 0; i < tape.length; i++) { const node = tape[i]; const nodeInputs = node.inputs; for (const inputName in nodeInputs) { const input2 = nodeInputs[inputName]; let anyInputFromX = false; for (let j = 0; j < xs.length; j++) { if (tensorsFromX[input2.id]) { node.outputs.forEach((output) => tensorsFromX[output.id] = true); anyInputFromX = true; nodesFromX[node.id] = true; break; } } if (anyInputFromX) { break; } } } const tensorsLeadToY = {}; tensorsLeadToY[y.id] = true; const nodesToY = {}; for (let i = tape.length - 1; i >= 0; i--) { const node = tape[i]; const nodeInputs = node.inputs; for (let j = 0; j < node.outputs.length; j++) { if (tensorsLeadToY[node.outputs[j].id]) { for (const inputName in nodeInputs) { tensorsLeadToY[nodeInputs[inputName].id] = true; nodesToY[node.id] = true; } break; } } } const filteredTape = []; for (let i = 0; i < tape.length; i++) { const node = tape[i]; if (nodesFromX[node.id] && nodesToY[node.id]) { const prunedInputs = {}; for (const inputName in node.inputs) { const nodeInput = node.inputs[inputName]; if (tensorsFromX[nodeInput.id]) { prunedInputs[inputName] = nodeInput; } } const prunedNode = Object.assign({}, node); prunedNode.inputs = prunedInputs; prunedNode.outputs = node.outputs; filteredTape.push(prunedNode); } } return filteredTape; } function backpropagateGradients(tensorAccumulatedGradientMap, filteredTape, tidy2, add2) { for (let i = filteredTape.length - 1; i >= 0; i--) { const node = filteredTape[i]; const dys = []; node.outputs.forEach((o) => { const gradTensor = tensorAccumulatedGradientMap[o.id]; if (gradTensor != null) { dys.push(gradTensor); } else { dys.push(null); } }); if (node.gradient == null) { throw new Error(`Cannot compute gradient: gradient function not found for ${node.kernelName}.`); } const inputGradients = node.gradient(dys); for (const inputName in node.inputs) { if (!(inputName in inputGradients)) { throw new Error(`Cannot backprop through input ${inputName}. Available gradients found: ${Object.keys(inputGradients)}.`); } const dx = tidy2(() => inputGradients[inputName]()); if (dx.dtype !== "float32") { throw new Error(`Error in gradient for op ${node.kernelName}. The gradient of input ${inputName} must have 'float32' dtype, but has '${dx.dtype}'`); } const x = node.inputs[inputName]; if (!arraysEqual(dx.shape, x.shape)) { throw new Error(`Error in gradient for op ${node.kernelName}. The gradient of input '${inputName}' has shape '${dx.shape}', which does not match the shape of the input '${x.shape}'`); } if (tensorAccumulatedGradientMap[x.id] == null) { tensorAccumulatedGradientMap[x.id] = dx; } else { const curGradient = tensorAccumulatedGradientMap[x.id]; tensorAccumulatedGradientMap[x.id] = add2(curGradient, dx); curGradient.dispose(); } } } } /** * @license * Copyright 2018 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ const FORMAT_LIMIT_NUM_VALS = 20; const FORMAT_NUM_FIRST_LAST_VALS = 3; const FORMAT_NUM_SIG_DIGITS = 7; function tensorToString(vals, shape, dtype, verbose) { const strides = computeStrides(shape); const padPerCol = computeMaxSizePerColumn(vals, shape, dtype, strides); const rank = shape.length; const valsLines = subTensorToString(vals, shape, dtype, strides, padPerCol); const lines = ["Tensor"]; if (verbose) { lines.push(` dtype: ${dtype}`); lines.push(` rank: ${rank}`); lines.push(` shape: [${shape}]`); lines.push(` values:`); } lines.push(valsLines.map((l) => " " + l).join("\n")); return lines.join("\n"); } function computeMaxSizePerColumn(vals, shape, dtype, strides) { const n = sizeFromShape(shape); const numCols = strides[strides.length - 1]; const padPerCol = new Array(numCols).fill(0); const rank = shape.length; const valuesOrTuples = dtype === "complex64" ? createComplexTuples(vals) : vals; if (rank > 1) { for (let row = 0; row < n / numCols; row++) { const offset = row * numCols; for (let j = 0; j < numCols; j++) { padPerCol[j] = Math.max(padPerCol[j], valToString(valuesOrTuples[offset + j], 0, dtype).length); } } } return padPerCol; } function valToString(val, pad3, dtype) { let valStr; if (Array.isArray(val)) { valStr = `${parseFloat(val[0].toFixed(FORMAT_NUM_SIG_DIGITS))} + ${parseFloat(val[1].toFixed(FORMAT_NUM_SIG_DIGITS))}j`; } else if (isString(val)) { valStr = `'${val}'`; } else if (dtype === "bool") { valStr = boolNumToString(val); } else { valStr = parseFloat(val.toFixed(FORMAT_NUM_SIG_DIGITS)).toString(); } return rightPad(valStr, pad3); } function boolNumToString(v) { return v === 0 ? "false" : "true"; } function subTensorToString(vals, shape, dtype, strides, padPerCol, isLast = true) { const storagePerElement = dtype === "complex64" ? 2 : 1; const size = shape[0]; const rank = shape.length; if (rank === 0) { if (dtype === "complex64") { const complexTuple = createComplexTuples(vals); return [valToString(complexTuple[0], 0, dtype)]; } if (dtype === "bool") { return [boolNumToString(vals[0])]; } return [vals[0].toString()]; } if (rank === 1) { if (size > FORMAT_LIMIT_NUM_VALS) { const firstValsSize = FORMAT_NUM_FIRST_LAST_VALS * storagePerElement; let firstVals = Array.from(vals.slice(0, firstValsSize)); let lastVals = Array.from(vals.slice((size - FORMAT_NUM_FIRST_LAST_VALS) * storagePerElement, size * storagePerElement)); if (dtype === "complex64") { firstVals = createComplexTuples(firstVals); lastVals = createComplexTuples(lastVals); } return [ "[" + firstVals.map((x, i) => valToString(x, padPerCol[i], dtype)).join(", ") + ", ..., " + lastVals.map((x, i) => valToString(x, padPerCol[size - FORMAT_NUM_FIRST_LAST_VALS + i], dtype)).join(", ") + "]" ]; } const displayVals = dtype === "complex64" ? createComplexTuples(vals) : Array.from(vals); return [ "[" + displayVals.map((x, i) => valToString(x, padPerCol[i], dtype)).join(", ") + "]" ]; } const subshape = shape.slice(1); const substrides = strides.slice(1); const stride = strides[0] * storagePerElement; const lines = []; if (size > FORMAT_LIMIT_NUM_VALS) { for (let i = 0; i < FORMAT_NUM_FIRST_LAST_VALS; i++) { const start = i * stride; const end = start + stride; lines.push(...subTensorToString(vals.slice(start, end), subshape, dtype, substrides, padPerCol, false)); } lines.push("..."); for (let i = size - FORMAT_NUM_FIRST_LAST_VALS; i < size; i++) { const start = i * stride; const end = start + stride; lines.push(...subTensorToString(vals.slice(start, end), subshape, dtype, substrides, padPerCol, i === size - 1)); } } else { for (let i = 0; i < size; i++) { const start = i * stride; const end = start + stride; lines.push(...subTensorToString(vals.slice(start, end), subshape, dtype, substrides, padPerCol, i === size - 1)); } } const sep = rank === 2 ? "," : ""; lines[0] = "[" + lines[0] + sep; for (let i = 1; i < lines.length - 1; i++) { lines[i] = " " + lines[i] + sep; } let newLineSep = ",\n"; for (let i = 2; i < rank; i++) { newLineSep += "\n"; } lines[lines.length - 1] = " " + lines[lines.length - 1] + "]" + (isLast ? "" : newLineSep); return lines; } function createComplexTuples(vals) { const complexTuples = []; for (let i = 0; i < vals.length; i += 2) { complexTuples.push([vals[i], vals[i + 1]]); } return complexTuples; } /** * @license * Copyright 2017 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ class TensorBuffer { constructor(shape, dtype, values) { this.dtype = dtype; this.shape = shape.slice(); this.size = sizeFromShape(shape); if (values != null) { const n = values.length; assert(n === this.size, () => `Length of values '${n}' does not match the size inferred by the shape '${this.size}'.`); } if (dtype === "complex64") { throw new Error(`complex64 dtype TensorBuffers are not supported. Please create a TensorBuffer for the real and imaginary parts separately and call tf.complex(real, imag).`); } this.values = values || getArrayFromDType(dtype, this.size); this.strides = computeStrides(shape); } set(value, ...locs) { if (locs.length === 0) { locs = [0]; } assert(locs.length === this.rank, () => `The number of provided coordinates (${locs.length}) must match the rank (${this.rank})`); const index2 = this.locToIndex(locs); this.values[index2] = value; } get(...locs) { if (locs.length === 0) { locs = [0]; } let i = 0; for (const loc of locs) { if (loc < 0 || loc >= this.shape[i]) { const msg = `Requested out of range element at ${locs}. Buffer shape=${this.shape}`; throw new Error(msg); } i++; } let index2 = locs[locs.length - 1]; for (let i2 = 0; i2 < locs.length - 1; ++i2) { index2 += this.strides[i2] * locs[i2]; } return this.values[index2]; } locToIndex(locs) { if (this.rank === 0) { return 0; } else if (this.rank === 1) { return locs[0]; } let index2 = locs[locs.length - 1]; for (let i = 0; i < locs.length - 1; ++i) { index2 += this.strides[i] * locs[i]; } return index2; } indexToLoc(index2) { if (this.rank === 0) { return []; } else if (this.rank === 1) { return [index2]; } const locs = new Array(this.shape.length); for (let i = 0; i < locs.length - 1; ++i) { locs[i] = Math.floor(index2 / this.strides[i]); index2 -= locs[i] * this.strides[i]; } locs[locs.length - 1] = index2; return locs; } get rank() { return this.shape.length; } toTensor() { return trackerFn().makeTensor(this.values, this.shape, this.dtype); } } let trackerFn = null; let opHandler = null; let deprecationWarningFn = null; [deprecationWarningFn]; function setTensorTracker(fn) { trackerFn = fn; } function setOpHandler(handler) { opHandler = handler; } function setDeprecationWarningFn(fn) { deprecationWarningFn = fn; } class Tensor { constructor(shape, dtype, dataId, id) { this.kept = false; this.isDisposedInternal = false; this.shape = shape.slice(); this.dtype = dtype || "float32"; this.size = sizeFromShape(shape); this.strides = computeStrides(shape); this.dataId = dataId; this.id = id; this.rankType = this.rank < 5 ? this.rank.toString() : "higher"; } get rank() { return this.shape.length; } async buffer() { const vals = await this.data(); return opHandler.buffer(this.shape, this.dtype, vals); } bufferSync() { return opHandler.buffer(this.shape, this.dtype, this.dataSync()); } async array() { const vals = await this.data(); return toNestedArray(this.shape, vals); } arraySync() { return toNestedArray(this.shape, this.dataSync()); } async data() { this.throwIfDisposed(); const data2 = trackerFn().read(this.dataId); if (this.dtype === "string") { const bytes = await data2; try { return bytes.map((b) => decodeString(b)); } catch (_a) { throw new Error("Failed to decode the string bytes into utf-8. To get the original bytes, call tensor.bytes()."); } } return data2; } dataSync() { this.throwIfDisposed(); const data2 = trackerFn().readSync(this.dataId); if (this.dtype === "string") { try { return data2.map((b) => decodeString(b)); } catch (_a) { throw new Error("Failed to decode the string bytes into utf-8. To get the original bytes, call tensor.bytes()."); } } return data2; } async bytes() { this.throwIfDisposed(); const data2 = await trackerFn().read(this.dataId); if (this.dtype === "string") { return data2; } else { return new Uint8Array(data2.buffer); } } dispose() { if (this.isDisposed) { return; } trackerFn().disposeTensor(this); this.isDisposedInternal = true; } get isDisposed() { return this.isDisposedInternal; } throwIfDisposed() { if (this.isDisposed) { throw new Error(`Tensor is disposed.`); } } print(verbose = false) { return opHandler.print(this, verbose); } clone() { this.throwIfDisposed(); return opHandler.clone(this); } toString(verbose = false) { const vals = this.dataSync(); return tensorToString(vals, this.shape, this.dtype, verbose); } cast(dtype) { this.throwIfDisposed(); return opHandler.cast(this, dtype); } variable(trainable = true, name, dtype) { this.throwIfDisposed(); return trackerFn().makeVariable(this, trainable, name, dtype); } } Object.defineProperty(Tensor, Symbol.hasInstance, { value: (instance) => { return !!instance && instance.data != null && instance.dataSync != null && instance.throwIfDisposed != null; } }); class Variable extends Tensor { constructor(initialValue, trainable, name, tensorId) { super(initialValue.shape, initialValue.dtype, initialValue.dataId, tensorId); this.trainable = trainable; this.name = name; } assign(newValue) { if (newValue.dtype !== this.dtype) { throw new Error(`dtype of the new value (${newValue.dtype}) and previous value (${this.dtype}) must match`); } if (!arraysEqual(newValue.shape, this.shape)) { throw new Error(`shape of the new value (${newValue.shape}) and previous value (${this.shape}) must match`); } trackerFn().disposeTensor(this); this.dataId = newValue.dataId; trackerFn().incRef(this, null); } dispose() { trackerFn().disposeVariable(this); this.isDisposedInternal = true; } } Object.defineProperty(Variable, Symbol.hasInstance, { value: (instance) => { return instance instanceof Tensor && instance.assign != null && instance.assign instanceof Function; } }); /** * @license * Copyright 2017 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ (function(Rank) { Rank["R0"] = "R0"; Rank["R1"] = "R1"; Rank["R2"] = "R2"; Rank["R3"] = "R3"; Rank["R4"] = "R4"; Rank["R5"] = "R5"; Rank["R6"] = "R6"; })(exports3.Rank || (exports3.Rank = {})); var UpcastInt32AndMap; (function(UpcastInt32AndMap2) { UpcastInt32AndMap2["float32"] = "float32"; UpcastInt32AndMap2["int32"] = "int32"; UpcastInt32AndMap2["bool"] = "int32"; UpcastInt32AndMap2["complex64"] = "complex64"; })(UpcastInt32AndMap || (UpcastInt32AndMap = {})); var UpcastBoolAndMap; (function(UpcastBoolAndMap2) { UpcastBoolAndMap2["float32"] = "float32"; UpcastBoolAndMap2["int32"] = "int32"; UpcastBoolAndMap2["bool"] = "bool"; UpcastBoolAndMap2["complex64"] = "complex64"; })(UpcastBoolAndMap || (UpcastBoolAndMap = {})); var UpcastFloat32AndMap; (function(UpcastFloat32AndMap2) { UpcastFloat32AndMap2["float32"] = "float32"; UpcastFloat32AndMap2["int32"] = "float32"; UpcastFloat32AndMap2["bool"] = "float32"; UpcastFloat32AndMap2["complex64"] = "complex64"; })(UpcastFloat32AndMap || (UpcastFloat32AndMap = {})); var UpcastComplex64AndMap; (function(UpcastComplex64AndMap2) { UpcastComplex64AndMap2["float32"] = "complex64"; UpcastComplex64AndMap2["int32"] = "complex64"; UpcastComplex64AndMap2["bool"] = "complex64"; UpcastComplex64AndMap2["complex64"] = "complex64"; })(UpcastComplex64AndMap || (UpcastComplex64AndMap = {})); const upcastTypeMap = { float32: UpcastFloat32AndMap, int32: UpcastInt32AndMap, bool: UpcastBoolAndMap, complex64: UpcastComplex64AndMap }; function upcastType(typeA, typeB) { if (typeA === "string" || typeB === "string") { if (typeA === "string" && typeB === "string") { return "string"; } throw new Error(`Can not upcast ${typeA} with ${typeB}`); } return upcastTypeMap[typeA][typeB]; } function sumOutType(type) { return upcastType(type, "int32"); } /** * @license * Copyright 2018 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ function makeTypesMatch(a, b) { if (a.dtype === b.dtype) { return [a, b]; } const dtype = upcastType(a.dtype, b.dtype); return [a.cast(dtype), b.cast(dtype)]; } function assertTypesMatch(a, b) { assert(a.dtype === b.dtype, () => `The dtypes of the first(${a.dtype}) and second(${b.dtype}) input must match`); } function isTensorInList(tensor2, tensorList) { return tensorList.some((x) => x.id === tensor2.id); } function getTensorsInContainer(result) { const list = []; const seen = new Set(); walkTensorContainer(result, list, seen); return list; } function walkTensorContainer(container, list, seen) { if (container == null) { return; } if (container instanceof Tensor) { list.push(container); return; } if (!isIterable(container)) { return; } const iterable = container; for (const k in iterable) { const val = iterable[k]; if (!seen.has(val)) { seen.add(val); walkTensorContainer(val, list, seen); } } } function isIterable(obj) { return Array.isArray(obj) || typeof obj === "object"; } var tensor_util = /* @__PURE__ */ Object.freeze({ __proto__: null, makeTypesMatch, assertTypesMatch, isTensorInList, getTensorsInContainer }); /** * @license * Copyright 2018 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ class EngineState { constructor() { this.registeredVariables = {}; this.nextTapeNodeId = 0; this.numBytes = 0; this.numTensors = 0; this.numStringTensors = 0; this.numDataBuffers = 0; this.gradientDepth = 0; this.kernelDepth = 0; this.scopeStack = []; this.numDataMovesStack = []; this.nextScopeId = 0; this.tensorInfo = new WeakMap(); this.profiling = false; this.activeProfile = {newBytes: 0, newTensors: 0, peakBytes: 0, kernels: [], result: null}; } dispose() { for (const variableName in this.registeredVariables) { this.registeredVariables[variableName].dispose(); } } } class Engine { constructor(ENV3) { this.ENV = ENV3; this.registry = {}; this.registryFactory = {}; this.pendingBackendInitId = 0; this.state = new EngineState(); } async ready() { if (this.pendingBackendInit != null) { return this.pendingBackendInit.then(() => { }); } if (this.backendInstance != null) { return; } const sortedBackends = this.getSortedBackends(); for (let i = 0; i < sortedBackends.length; i++) { const backendName = sortedBackends[i]; const success = await this.initializeBackend(backendName).success; if (success) { await this.setBackend(backendName); return; } } throw new Error(`Could not initialize any backends, all backend initializations failed.`); } get backend() { if (this.pendingBackendInit != null) { throw new Error(`Backend '${this.backendName}' has not yet been initialized. Make sure to await tf.ready() or await tf.setBackend() before calling other methods`); } if (this.backendInstance == null) { const {name, asyncInit} = this.initializeBackendsAndReturnBest(); if (asyncInit) { throw new Error(`The highest priority backend '${name}' has not yet been initialized. Make sure to await tf.ready() or await tf.setBackend() before calling other methods`); } this.setBackend(name); } return this.backendInstance; } backendNames() { return Object.keys(this.registryFactory); } findBackend(backendName) { if (!(backendName in this.registry)) { if (backendName in this.registryFactory) { const {asyncInit} = this.initializeBackend(backendName); if (asyncInit) { return null; } } else { return null; } } return this.registry[backendName]; } findBackendFactory(backendName) { if (!(backendName in this.registryFactory)) { return null; } return this.registryFactory[backendName].factory; } registerBackend(backendName, factory, priority = 1) { if (backendName in this.registryFactory) { console.warn(`${backendName} backend was already registered. Reusing existing backend factory.`); return false; } this.registryFactory[backendName] = {factory, priority}; return true; } async setBackend(backendName) { if (this.registryFactory[backendName] == null) { throw new Error(`Backend name '${backendName}' not found in registry`); } this.backendName = backendName; if (this.registry[backendName] == null) { this.backendInstance = null; const {success, asyncInit} = this.initializeBackend(backendName); const result = asyncInit ? await success : success; if (!result) { return false; } } this.backendInstance = this.registry[backendName]; this.setupRegisteredKernels(); this.profiler = new Profiler(this.backendInstance); return true; } setupRegisteredKernels() { const kernels = getKernelsForBackend(this.backendName); kernels.forEach((kernel) => { if (kernel.setupFunc != null) { kernel.setupFunc(this.backendInstance); } }); } disposeRegisteredKernels(backendName) { const kernels = getKernelsForBackend(backendName); kernels.forEach((kernel) => { if (kernel.disposeFunc != null) { kernel.disposeFunc(this.registry[backendName]); } }); } initializeBackend(backendName) { const registryFactoryEntry = this.registryFactory[backendName]; if (registryFactoryEntry == null) { throw new Error(`Cannot initialize backend ${backendName}, no registration found.`); } try { const backend2 = registryFactoryEntry.factory(); if (backend2 && !(backend2 instanceof KernelBackend2) && typeof backend2.then === "function") { const promiseId = ++this.pendingBackendInitId; const success = backend2.then((backendInstance) => { if (promiseId < this.pendingBackendInitId) { return false; } this.registry[backendName] = backendInstance; this.pendingBackendInit = null; return true; }).catch((err) => { if (promiseId < this.pendingBackendInitId) { return false; } this.pendingBackendInit = null; console.warn(`Initialization of backend ${backendName} failed`); console.warn(err.stack || err.message); return false; }); this.pendingBackendInit = success; return {success, asyncInit: true}; } else { this.registry[backendName] = backend2; return {success: true, asyncInit: false}; } } catch (err) { console.warn(`Initialization of backend ${backendName} failed`); console.warn(err.stack || err.message); return {success: false, asyncInit: false}; } } removeBackend(backendName) { if (!(backendName in this.registryFactory)) { throw new Error(`${backendName} backend not found in registry`); } if (this.backendName === backendName && this.pendingBackendInit != null) { this.pendingBackendInitId++; } if (backendName in this.registry) { this.disposeRegisteredKernels(backendName); this.registry[backendName].dispose(); delete this.registry[backendName]; } delete this.registryFactory[backendName]; if (this.backendName === backendName) { this.pendingBackendInit = null; this.backendName = null; this.backendInstance = null; } } getSortedBackends() { if (Object.keys(this.registryFactory).length === 0) { throw new Error("No backend found in registry."); } return Object.keys(this.registryFactory).sort((a, b) => { return this.registryFactory[b].priority - this.registryFactory[a].priority; }); } initializeBackendsAndReturnBest() { const sortedBackends = this.getSortedBackends(); for (let i = 0; i < sortedBackends.length; i++) { const backendName = sortedBackends[i]; const {success, asyncInit} = this.initializeBackend(backendName); if (asyncInit || success) { return {name: backendName, asyncInit}; } } throw new Error(`Could not initialize any backends, all backend initializations failed.`); } moveData(backend2, dataId) { const info = this.state.tensorInfo.get(dataId); const srcBackend = info.backend; const values = this.readSync(dataId); srcBackend.disposeData(dataId); info.backend = backend2; backend2.move(dataId, values, info.shape, info.dtype); if (this.shouldCheckForMemLeaks()) { this.state.numDataMovesStack[this.state.numDataMovesStack.length - 1]++; } } tidy(nameOrFn, fn) { let name = null; if (fn == null) { if (typeof nameOrFn !== "function") { throw new Error("Please provide a function to tidy()"); } fn = nameOrFn; } else { if (typeof nameOrFn !== "string" && !(nameOrFn instanceof String)) { throw new Error("When calling with two arguments, the first argument to tidy() must be a string"); } if (typeof fn !== "function") { throw new Error("When calling with two arguments, the 2nd argument to tidy() must be a function"); } name = nameOrFn; } let result; return this.scopedRun(() => this.startScope(name), () => this.endScope(result), () => { result = fn(); if (result instanceof Promise) { console.error("Cannot return a Promise inside of tidy."); } return result; }); } scopedRun(start, end, f) { start(); try { const res = f(); end(); return res; } catch (ex) { end(); throw ex; } } nextTensorId() { return Engine.nextTensorId++; } nextVariableId() { return Engine.nextVariableId++; } clone(x) { const y = this.makeTensorFromDataId(x.dataId, x.shape, x.dtype); const inputs = {x}; const grad2 = (dy) => ({ x: () => { const dtype = "float32"; const gradInputs = {x: dy}; const attrs = {dtype}; return ENGINE.runKernelFunc((backend2) => backend2.cast(dy, dtype), gradInputs, null, Cast5, attrs); } }); const saved = []; this.addTapeNode(this.state.activeScope.name, inputs, [y], grad2, saved, {}); return y; } runKernel(kernelName, inputs, attrs, inputsToSave, outputsToSave) { const forwardFunc = null; const backwardsFunc = null; return this.runKernelFunc(forwardFunc, inputs, backwardsFunc, kernelName, attrs, inputsToSave, outputsToSave); } shouldCheckForMemLeaks() { return this.ENV.getBool("IS_TEST"); } checkKernelForMemLeak(kernelName, numDataIdsBefore, outInfos) { const numDataIdsAfter = this.backend.numDataIds(); let numOutputDataIds = 0; outInfos.forEach((info) => { numOutputDataIds += info.dtype === "complex64" ? 3 : 1; }); const numMoves = this.state.numDataMovesStack[this.state.numDataMovesStack.length - 1]; const dataIdsLeaked = numDataIdsAfter - numDataIdsBefore - numOutputDataIds - numMoves; if (dataIdsLeaked > 0) { throw new Error(`Backend '${this.backendName}' has an internal memory leak (${dataIdsLeaked} data ids) after running '${kernelName}'`); } } runKernelFunc(forwardFunc, inputs, backwardsFunc, kernelName, attrs, inputsToSave, outputsToSave) { let outputs; let saved = []; const isTapeOn = this.isTapeOn(); if (kernelName == null) { kernelName = this.state.activeScope != null ? this.state.activeScope.name : ""; } const startingBytecount = this.state.numBytes; const startingNumTensors = this.state.numTensors; if (this.shouldCheckForMemLeaks()) { this.state.numDataMovesStack.push(0); } let kernelFunc3; const kernel = getKernel(kernelName, this.backendName); let out; if (kernel != null) { kernelFunc3 = () => { const numDataIdsBefore = this.backend.numDataIds(); out = kernel.kernelFunc({inputs, attrs, backend: this.backend}); const outInfos = Array.isArray(out) ? out : [out]; if (this.shouldCheckForMemLeaks()) { this.checkKernelForMemLeak(kernelName, numDataIdsBefore, outInfos); } const outTensors = outInfos.map(({dataId, shape, dtype}) => this.makeTensorFromDataId(dataId, shape, dtype)); if (isTapeOn) { let tensorsToSave = this.getTensorsForGradient(kernelName, inputs, outTensors); if (tensorsToSave == null) { if (outputsToSave == null) { outputsToSave = []; } const outsToSave = outTensors.filter((_, i) => outputsToSave[i]); tensorsToSave = (inputsToSave || []).slice().concat(outsToSave); } saved = this.saveTensorsForBackwardMode(tensorsToSave); } return outTensors; }; } else { const saveFunc = (tensors) => { if (!isTapeOn) { return; } saved = tensors.map((tensor2) => this.keep(this.clone(tensor2))); }; kernelFunc3 = () => { const numDataIdsBefore = this.backend.numDataIds(); out = this.tidy(() => forwardFunc(this.backend, saveFunc)); const outs = Array.isArray(out) ? out : [out]; if (this.shouldCheckForMemLeaks()) { this.checkKernelForMemLeak(kernelName, numDataIdsBefore, outs); } return outs; }; } let kernelProfile; this.scopedRun(() => this.state.kernelDepth++, () => this.state.kernelDepth--, () => { if (!this.ENV.getBool("DEBUG") && !this.state.profiling) { outputs = kernelFunc3(); } else { kernelProfile = this.profiler.profileKernel(kernelName, inputs, () => kernelFunc3()); if (this.ENV.getBool("DEBUG")) { this.profiler.logKernelProfile(kernelProfile); } outputs = kernelProfile.outputs; } }); if (isTapeOn) { this.addTapeNode(kernelName, inputs, outputs, backwardsFunc, saved, attrs); } if (this.state.profiling) { this.state.activeProfile.kernels.push({ name: kernelName, bytesAdded: this.state.numBytes - startingBytecount, totalBytesSnapshot: this.state.numBytes, tensorsAdded: this.state.numTensors - startingNumTensors, totalTensorsSnapshot: this.state.numTensors, inputShapes: Object.keys(inputs).map((key) => inputs[key] != null ? inputs[key].shape : null), outputShapes: outputs.map((item) => item.shape), kernelTimeMs: kernelProfile.timeMs, extraInfo: kernelProfile.extraInfo }); } return Array.isArray(out) ? outputs : outputs[0]; } saveTensorsForBackwardMode(tensors) { const saved = tensors.map((tensor2) => this.keep(this.clone(tensor2))); return saved; } getTensorsForGradient(kernelName, inputs, outputs) { const gradConfig = getGradient(kernelName); if (gradConfig != null) { const inputsToSave = gradConfig.inputsToSave || []; const outputsToSave = gradConfig.outputsToSave || []; let inputTensorsToSave; if (gradConfig.saveAllInputs) { assert(Array.isArray(inputs), () => "saveAllInputs is true, expected inputs to be an array."); inputTensorsToSave = Object.keys(inputs).map((key) => inputs[key]); } else { inputTensorsToSave = inputsToSave.map((inputName) => inputs[inputName]); } const outputTensorsToSave = outputs.filter((_, i) => outputsToSave[i]); return inputTensorsToSave.concat(outputTensorsToSave); } return null; } makeTensor(values, shape, dtype, backend2) { if (values == null) { throw new Error("Values passed to engine.makeTensor() are null"); } dtype = dtype || "float32"; backend2 = backend2 || this.backend; let backendVals = values; if (dtype === "string" && isString(values[0])) { backendVals = values.map((d) => encodeString(d)); } const dataId = backend2.write(backendVals, shape, dtype); const t = new Tensor(shape, dtype, dataId, this.nextTensorId()); this.incRef(t, backend2); if (dtype === "string") { const info = this.state.tensorInfo.get(dataId); const newBytes = bytesFromStringArray(backendVals); this.state.numBytes += newBytes - info.bytes; info.bytes = newBytes; } return t; } makeTensorFromDataId(dataId, shape, dtype, backend2) { dtype = dtype || "float32"; const t = new Tensor(shape, dtype, dataId, this.nextTensorId()); this.incRef(t, backend2); return t; } makeVariable(initialValue, trainable = true, name, dtype) { name = name || this.nextVariableId().toString(); if (dtype != null && dtype !== initialValue.dtype) { initialValue = initialValue.cast(dtype); } const v = new Variable(initialValue, trainable, name, this.nextTensorId()); if (this.state.registeredVariables[v.name] != null) { throw new Error(`Variable with name ${v.name} was already registered`); } this.state.registeredVariables[v.name] = v; this.incRef(v, this.backend); return v; } incRef(a, backend2) { const refCount = this.state.tensorInfo.has(a.dataId) ? this.state.tensorInfo.get(a.dataId).refCount : 0; this.state.numTensors++; if (a.dtype === "string") { this.state.numStringTensors++; } if (refCount === 0) { this.state.numDataBuffers++; let bytes = 0; if (a.dtype !== "complex64" && a.dtype !== "string") { bytes = a.size * bytesPerElement(a.dtype); } this.state.tensorInfo.set(a.dataId, { backend: backend2 || this.backend, dtype: a.dtype, shape: a.shape, bytes, refCount: 0 }); this.state.numBytes += bytes; } this.state.tensorInfo.get(a.dataId).refCount++; if (!(a instanceof Variable)) { this.track(a); } } disposeTensor(a) { if (!this.state.tensorInfo.has(a.dataId)) { return; } this.state.numTensors--; if (a.dtype === "string") { this.state.numStringTensors--; } const info = this.state.tensorInfo.get(a.dataId); const refCount = info.refCount; if (refCount <= 1) { if (a.dtype !== "complex64") { this.state.numBytes -= info.bytes; } this.state.numDataBuffers--; info.backend.disposeData(a.dataId); this.state.tensorInfo.delete(a.dataId); } else { this.state.tensorInfo.get(a.dataId).refCount--; } } disposeVariables() { for (const varName in this.state.registeredVariables) { const v = this.state.registeredVariables[varName]; this.disposeVariable(v); } } disposeVariable(v) { this.disposeTensor(v); if (this.state.registeredVariables[v.name] != null) { delete this.state.registeredVariables[v.name]; } } memory() { const info = this.backend.memory(); info.numTensors = this.state.numTensors; info.numDataBuffers = this.state.numDataBuffers; info.numBytes = this.state.numBytes; if (this.state.numStringTensors > 0) { info.unreliable = true; if (info.reasons == null) { info.reasons = []; } info.reasons.push("Memory usage by string tensors is approximate (2 bytes per character)"); } return info; } async profile(query) { this.state.profiling = true; const startBytes = this.state.numBytes; const startNumTensors = this.state.numTensors; this.state.activeProfile.kernels = []; this.state.activeProfile.result = await query(); this.state.profiling = false; this.state.activeProfile.peakBytes = Math.max(...this.state.activeProfile.kernels.map((d) => d.totalBytesSnapshot)); this.state.activeProfile.newBytes = this.state.numBytes - startBytes; this.state.activeProfile.newTensors = this.state.numTensors - startNumTensors; for (const kernel of this.state.activeProfile.kernels) { kernel.kernelTimeMs = await kernel.kernelTimeMs; kernel.extraInfo = await kernel.extraInfo; } return this.state.activeProfile; } isTapeOn() { return this.state.gradientDepth > 0 && this.state.kernelDepth === 0; } addTapeNode(kernelName, inputs, outputs, gradientsFunc, saved, attrs) { const tapeNode = {id: this.state.nextTapeNodeId++, kernelName, inputs, outputs, saved}; const gradConfig = getGradient(kernelName); if (gradConfig != null) { gradientsFunc = gradConfig.gradFunc; } if (gradientsFunc != null) { tapeNode.gradient = (dys) => { dys = dys.map((dy, i) => { if (dy == null) { const output = outputs[i]; const vals = makeZerosTypedArray(output.size, output.dtype); return this.makeTensor(vals, output.shape, output.dtype); } return dy; }); return gradientsFunc(dys.length > 1 ? dys : dys[0], saved, attrs); }; } this.state.activeTape.push(tapeNode); } keep(result) { result.kept = true; return result; } startTape() { if (this.state.gradientDepth === 0) { this.state.activeTape = []; } this.state.gradientDepth++; } endTape() { this.state.gradientDepth--; } startScope(name) { const scopeInfo = { track: [], name: "unnamed scope", id: this.state.nextScopeId++ }; if (name) { scopeInfo.name = name; } this.state.scopeStack.push(scopeInfo); this.state.activeScope = scopeInfo; } endScope(result) { const tensorsToTrackInParent = getTensorsInContainer(result); const tensorsToTrackInParentSet = new Set(tensorsToTrackInParent.map((t) => t.id)); for (let i = 0; i < this.state.activeScope.track.length; i++) { const tensor2 = this.state.activeScope.track[i]; if (!tensor2.kept && !tensorsToTrackInParentSet.has(tensor2.id)) { tensor2.dispose(); } } const oldScope = this.state.scopeStack.pop(); this.state.activeScope = this.state.scopeStack.length === 0 ? null : this.state.scopeStack[this.state.scopeStack.length - 1]; tensorsToTrackInParent.forEach((tensor2) => { if (!tensor2.kept && tensor2.scopeId === oldScope.id) { this.track(tensor2); } }); } gradients(f, xs, dy, allowNoGradients = false) { assert(xs.length > 0, () => "gradients() received an empty list of xs."); if (dy != null && dy.dtype !== "float32") { throw new Error(`dy must have 'float32' dtype, but has '${dy.dtype}'`); } const y = this.scopedRun(() => this.startTape(), () => this.endTape(), () => this.tidy("forward", f)); assert(y instanceof Tensor, () => "The result y returned by f() must be a tensor."); const filteredTape = getFilteredNodesXToY(this.state.activeTape, xs, y); if (!allowNoGradients && filteredTape.length === 0 && xs.length > 0) { throw new Error("Cannot compute gradient of y=f(x) with respect to x. Make sure that the f you passed encloses all operations that lead from x to y."); } return this.tidy("backward", () => { const accumulatedGradientMap = {}; accumulatedGradientMap[y.id] = dy == null ? ones(y.shape) : dy; backpropagateGradients(accumulatedGradientMap, filteredTape, (f2) => this.tidy(f2), add); const grads2 = xs.map((x) => accumulatedGradientMap[x.id]); if (this.state.gradientDepth === 0) { this.state.activeTape.forEach((node) => { for (const tensor2 of node.saved) { tensor2.dispose(); } }); this.state.activeTape = null; } return {value: y, grads: grads2}; }); } customGrad(f) { assert(isFunction(f), () => "The f passed in customGrad(f) must be a function."); return (...inputs) => { assert(inputs.every((t) => t instanceof Tensor), () => "The args passed in customGrad(f)(x1, x2,...) must all be tensors"); let res; const inputMap = {}; inputs.forEach((input2, i) => { inputMap[i] = input2; }); return this.runKernelFunc((_, save) => { res = f(...[...inputs, save]); assert(res.value instanceof Tensor, () => "The function f passed in customGrad(f) must return an object where `obj.value` is a tensor"); assert(isFunction(res.gradFunc), () => "The function f passed in customGrad(f) must return an object where `obj.gradFunc` is a function."); return res.value; }, inputMap, (dy, saved) => { const gradRes = res.gradFunc(dy, saved); const grads2 = Array.isArray(gradRes) ? gradRes : [gradRes]; assert(grads2.length === inputs.length, () => "The function f passed in customGrad(f) must return an object where `obj.gradFunc` is a function that returns the same number of tensors as inputs passed to f(...)."); assert(grads2.every((t) => t instanceof Tensor), () => "The function f passed in customGrad(f) must return an object where `obj.gradFunc` is a function that returns a list of only tensors."); const gradMap = {}; grads2.forEach((grad2, i) => { gradMap[i] = () => grad2; }); return gradMap; }); }; } readSync(dataId) { const info = this.state.tensorInfo.get(dataId); return info.backend.readSync(dataId); } read(dataId) { const info = this.state.tensorInfo.get(dataId); return info.backend.read(dataId); } async time(query) { const start = now2(); const timingInfo = await this.backend.time(query); timingInfo.wallMs = now2() - start; return timingInfo; } track(result) { if (this.state.activeScope != null) { result.scopeId = this.state.activeScope.id; this.state.activeScope.track.push(result); } return result; } get registeredVariables() { return this.state.registeredVariables; } reset() { this.pendingBackendInitId++; this.state.dispose(); this.ENV.reset(); this.state = new EngineState(); for (const backendName in this.registry) { this.disposeRegisteredKernels(backendName); this.registry[backendName].dispose(); delete this.registry[backendName]; } this.backendName = null; this.backendInstance = null; this.pendingBackendInit = null; } } Engine.nextTensorId = 0; Engine.nextVariableId = 0; function ones(shape) { const values = makeOnesTypedArray(sizeFromShape(shape), "float32"); return ENGINE.makeTensor(values, shape, "float32"); } function getOrMakeEngine() { const ns = getGlobalNamespace(); if (ns._tfengine == null) { const environment = new Environment(ns); ns._tfengine = new Engine(environment); } setEnvironmentGlobal(ns._tfengine.ENV); setTensorTracker(() => ns._tfengine); return ns._tfengine; } const ENGINE = getOrMakeEngine(); function add(a, b) { const inputs = {a, b}; return ENGINE.runKernelFunc((backend2, save) => { const res = backend2.add(a, b); save([a, b]); return res; }, inputs, null, Add3); } /** * @license * Copyright 2017 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ function _isNavigatorDefined() { return typeof navigator !== "undefined" && navigator != null; } function isMobile() { if (_isNavigatorDefined()) { const a = navigator.userAgent || navigator.vendor || window.opera; return /(android|bb\d+|meego).+mobile|avantgo|bada\/|blackberry|blazer|compal|elaine|fennec|hiptop|iemobile|ip(hone|od)|iris|kindle|lge |maemo|midp|mmp|mobile.+firefox|netfront|opera m(ob|in)i|palm( os)?|phone|p(ixi|re)\/|plucker|pocket|psp|series(4|6)0|symbian|treo|up\.(browser|link)|vodafone|wap|windows ce|xda|xiino/i.test(a) || /1207|6310|6590|3gso|4thp|50[1-6]i|770s|802s|a wa|abac|ac(er|oo|s\-)|ai(ko|rn)|al(av|ca|co)|amoi|an(ex|ny|yw)|aptu|ar(ch|go)|as(te|us)|attw|au(di|\-m|r |s )|avan|be(ck|ll|nq)|bi(lb|rd)|bl(ac|az)|br(e|v)w|bumb|bw\-(n|u)|c55\/|capi|ccwa|cdm\-|cell|chtm|cldc|cmd\-|co(mp|nd)|craw|da(it|ll|ng)|dbte|dc\-s|devi|dica|dmob|do(c|p)o|ds(12|\-d)|el(49|ai)|em(l2|ul)|er(ic|k0)|esl8|ez([4-7]0|os|wa|ze)|fetc|fly(\-|_)|g1 u|g560|gene|gf\-5|g\-mo|go(\.w|od)|gr(ad|un)|haie|hcit|hd\-(m|p|t)|hei\-|hi(pt|ta)|hp( i|ip)|hs\-c|ht(c(\-| |_|a|g|p|s|t)|tp)|hu(aw|tc)|i\-(20|go|ma)|i230|iac( |\-|\/)|ibro|idea|ig01|ikom|im1k|inno|ipaq|iris|ja(t|v)a|jbro|jemu|jigs|kddi|keji|kgt( |\/)|klon|kpt |kwc\-|kyo(c|k)|le(no|xi)|lg( g|\/(k|l|u)|50|54|\-[a-w])|libw|lynx|m1\-w|m3ga|m50\/|ma(te|ui|xo)|mc(01|21|ca)|m\-cr|me(rc|ri)|mi(o8|oa|ts)|mmef|mo(01|02|bi|de|do|t(\-| |o|v)|zz)|mt(50|p1|v )|mwbp|mywa|n10[0-2]|n20[2-3]|n30(0|2)|n50(0|2|5)|n7(0(0|1)|10)|ne((c|m)\-|on|tf|wf|wg|wt)|nok(6|i)|nzph|o2im|op(ti|wv)|oran|owg1|p800|pan(a|d|t)|pdxg|pg(13|\-([1-8]|c))|phil|pire|pl(ay|uc)|pn\-2|po(ck|rt|se)|prox|psio|pt\-g|qa\-a|qc(07|12|21|32|60|\-[2-7]|i\-)|qtek|r380|r600|raks|rim9|ro(ve|zo)|s55\/|sa(ge|ma|mm|ms|ny|va)|sc(01|h\-|oo|p\-)|sdk\/|se(c(\-|0|1)|47|mc|nd|ri)|sgh\-|shar|sie(\-|m)|sk\-0|sl(45|id)|sm(al|ar|b3|it|t5)|so(ft|ny)|sp(01|h\-|v\-|v )|sy(01|mb)|t2(18|50)|t6(00|10|18)|ta(gt|lk)|tcl\-|tdg\-|tel(i|m)|tim\-|t\-mo|to(pl|sh)|ts(70|m\-|m3|m5)|tx\-9|up(\.b|g1|si)|utst|v400|v750|veri|vi(rg|te)|vk(40|5[0-3]|\-v)|vm40|voda|vulc|vx(52|53|60|61|70|80|81|83|85|98)|w3c(\-| )|webc|whit|wi(g |nc|nw)|wmlb|wonu|x700|yas\-|your|zeto|zte\-/i.test(a.substr(0, 4)); } return false; } function isBrowser() { return typeof window !== "undefined" && window.document != null || typeof WorkerGlobalScope !== "undefined"; } var device_util = /* @__PURE__ */ Object.freeze({ __proto__: null, isMobile, isBrowser }); /** * @license * Copyright 2019 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ const ENV2 = env3(); ENV2.registerFlag("DEBUG", () => false, (debugValue) => { if (debugValue) { console.warn("Debugging mode is ON. The output of every math call will be downloaded to CPU and checked for NaNs. This significantly impacts performance."); } }); ENV2.registerFlag("IS_BROWSER", () => isBrowser()); ENV2.registerFlag("IS_NODE", () => typeof process !== "undefined" && typeof process.versions !== "undefined" && typeof process.versions.node !== "undefined"); ENV2.registerFlag("IS_CHROME", () => typeof navigator !== "undefined" && navigator != null && navigator.userAgent != null && /Chrome/.test(navigator.userAgent) && /Google Inc/.test(navigator.vendor)); ENV2.registerFlag("PROD", () => false); ENV2.registerFlag("TENSORLIKE_CHECK_SHAPE_CONSISTENCY", () => ENV2.getBool("DEBUG")); ENV2.registerFlag("DEPRECATION_WARNINGS_ENABLED", () => true); ENV2.registerFlag("IS_TEST", () => false); /** * @license * Copyright 2018 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ function inferShape(val, dtype) { let firstElem = val; if (isTypedArray(val)) { return dtype === "string" ? [] : [val.length]; } if (!Array.isArray(val)) { return []; } const shape = []; while (Array.isArray(firstElem) || isTypedArray(firstElem) && dtype !== "string") { shape.push(firstElem.length); firstElem = firstElem[0]; } if (Array.isArray(val) && env3().getBool("TENSORLIKE_CHECK_SHAPE_CONSISTENCY")) { deepAssertShapeConsistency(val, shape, []); } return shape; } function deepAssertShapeConsistency(val, shape, indices) { indices = indices || []; if (!Array.isArray(val) && !isTypedArray(val)) { assert(shape.length === 0, () => `Element arr[${indices.join("][")}] is a primitive, but should be an array/TypedArray of ${shape[0]} elements`); return; } assert(shape.length > 0, () => `Element arr[${indices.join("][")}] should be a primitive, but is an array of ${val.length} elements`); assert(val.length === shape[0], () => `Element arr[${indices.join("][")}] should have ${shape[0]} elements, but has ${val.length} elements`); const subShape = shape.slice(1); for (let i = 0; i < val.length; ++i) { deepAssertShapeConsistency(val[i], subShape, indices.concat(i)); } } function assertDtype(expectedDtype, actualDType, argName, functionName) { if (expectedDtype == null) { return; } if (expectedDtype !== "numeric" && expectedDtype !== actualDType || expectedDtype === "numeric" && actualDType === "string") { throw new Error(`Argument '${argName}' passed to '${functionName}' must be ${expectedDtype} tensor, but got ${actualDType} tensor`); } } function convertToTensor(x, argName, functionName, parseAsDtype = "numeric") { if (x instanceof Tensor) { assertDtype(parseAsDtype, x.dtype, argName, functionName); return x; } let inferredDtype = inferDtype(x); if (inferredDtype !== "string" && ["bool", "int32", "float32"].indexOf(parseAsDtype) >= 0) { inferredDtype = parseAsDtype; } assertDtype(parseAsDtype, inferredDtype, argName, functionName); if (x == null || !isTypedArray(x) && !Array.isArray(x) && typeof x !== "number" && typeof x !== "boolean" && typeof x !== "string") { const type = x == null ? "null" : x.constructor.name; throw new Error(`Argument '${argName}' passed to '${functionName}' must be a Tensor or TensorLike, but got '${type}'`); } const inferredShape = inferShape(x, inferredDtype); if (!isTypedArray(x) && !Array.isArray(x)) { x = [x]; } const skipTypedArray = true; const values = inferredDtype !== "string" ? toTypedArray(x, inferredDtype) : flatten(x, [], skipTypedArray); return ENGINE.makeTensor(values, inferredShape, inferredDtype); } function convertToTensorArray(arg, argName, functionName, parseAsDtype = "numeric") { if (!Array.isArray(arg)) { throw new Error(`Argument ${argName} passed to ${functionName} must be a \`Tensor[]\` or \`TensorLike[]\``); } const tensors = arg; return tensors.map((t, i) => convertToTensor(t, `${argName}[${i}]`, functionName), parseAsDtype); } /** * @license * Copyright 2018 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ const OP_SCOPE_SUFFIX = "__op"; function op(f) { const keys = Object.keys(f); if (keys.length !== 1) { throw new Error(`Please provide an object with a single key (operation name) mapping to a function. Got an object with ${keys.length} keys.`); } let opName = keys[0]; const fn = f[opName]; if (opName.endsWith("_")) { opName = opName.substring(0, opName.length - 1); } opName = opName + OP_SCOPE_SUFFIX; const f2 = (...args) => { ENGINE.startScope(opName); try { const result = fn(...args); if (isPromise(result)) { console.error("Cannot return a Promise inside of tidy."); } ENGINE.endScope(result); return result; } catch (ex) { ENGINE.endScope(null); throw ex; } }; Object.defineProperty(f2, "name", {value: opName, configurable: true}); return f2; } /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ function complex_(real2, imag2) { const $real = convertToTensor(real2, "real", "complex"); const $imag = convertToTensor(imag2, "imag", "complex"); assertShapesMatch($real.shape, $imag.shape, `real and imag shapes, ${$real.shape} and ${$imag.shape}, must match in call to tf.complex().`); const forward = (backend2) => { return backend2.complex($real, $imag); }; const inputs = {real: $real, imag: $imag}; return ENGINE.runKernelFunc(forward, inputs, null, Complex); } const complex = op({complex_}); /** * @license * Copyright 2018 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ function makeTensor(values, shape, inferredShape, dtype) { if (dtype == null) { dtype = inferDtype(values); } if (dtype === "complex64") { throw new Error(`Cannot construct a complex64 tensor directly. Please use tf.complex(real, imag).`); } if (!isTypedArray(values) && !Array.isArray(values) && typeof values !== "number" && typeof values !== "boolean" && typeof values !== "string") { throw new Error("values passed to tensor(values) must be a number/boolean/string or an array of numbers/booleans/strings, or a TypedArray"); } if (shape != null) { assertNonNegativeIntegerDimensions(shape); const providedSize = sizeFromShape(shape); const inferredSize = sizeFromShape(inferredShape); assert(providedSize === inferredSize, () => `Based on the provided shape, [${shape}], the tensor should have ${providedSize} values but has ${inferredSize}`); for (let i = 0; i < inferredShape.length; ++i) { const inferred = inferredShape[i]; const flatDimsDontMatch = i === inferredShape.length - 1 ? inferred !== sizeFromShape(shape.slice(i)) : true; assert(inferredShape[i] === shape[i] || !flatDimsDontMatch, () => `Error creating a new Tensor. Inferred shape (${inferredShape}) does not match the provided shape (${shape}). `); } } if (!isTypedArray(values) && !Array.isArray(values)) { values = [values]; } shape = shape || inferredShape; values = dtype !== "string" ? toTypedArray(values, dtype) : flatten(values, [], true); return ENGINE.makeTensor(values, shape, dtype); } /** * @license * Copyright 2018 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ function tensor(values, shape, dtype) { const inferredShape = inferShape(values, dtype); return makeTensor(values, shape, inferredShape, dtype); } /** * @license * Copyright 2018 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ const DTYPE_VALUE_SIZE_MAP = { float32: 4, float16: 2, int32: 4, uint16: 2, uint8: 1, bool: 1, complex64: 8 }; /** * @license * Copyright 2018 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ const NUM_BYTES_STRING_LENGTH = 4; async function encodeWeights(tensors, group) { const specs = []; const dataPromises = []; const names = Array.isArray(tensors) ? tensors.map((tensor2) => tensor2.name) : Object.keys(tensors); for (let i = 0; i < names.length; ++i) { const name = names[i]; const t = Array.isArray(tensors) ? tensors[i].tensor : tensors[name]; if (t.dtype !== "float32" && t.dtype !== "int32" && t.dtype !== "bool" && t.dtype !== "string" && t.dtype !== "complex64") { throw new Error(`Unsupported dtype in weight '${name}': ${t.dtype}`); } const spec = {name, shape: t.shape, dtype: t.dtype}; if (t.dtype === "string") { const utf8bytes = new Promise(async (resolve) => { const vals = await t.bytes(); const totalNumBytes = vals.reduce((p2, c) => p2 + c.length, 0) + NUM_BYTES_STRING_LENGTH * vals.length; const bytes = new Uint8Array(totalNumBytes); let offset = 0; for (let i2 = 0; i2 < vals.length; i2++) { const val = vals[i2]; const bytesOfLength = new Uint8Array(new Uint32Array([val.length]).buffer); bytes.set(bytesOfLength, offset); offset += NUM_BYTES_STRING_LENGTH; bytes.set(val, offset); offset += val.length; } resolve(bytes); }); dataPromises.push(utf8bytes); } else { dataPromises.push(t.data()); } if (group != null) { spec.group = group; } specs.push(spec); } const tensorValues = await Promise.all(dataPromises); return {data: concatenateTypedArrays(tensorValues), specs}; } function decodeWeights(buffer3, specs) { const out = {}; let float16Decode; let offset = 0; for (const spec of specs) { const name = spec.name; const dtype = spec.dtype; const shape = spec.shape; const size = sizeFromShape(shape); let values; if ("quantization" in spec) { const quantization = spec.quantization; if (quantization.dtype === "uint8" || quantization.dtype === "uint16") { if (!("min" in quantization && "scale" in quantization)) { throw new Error(`Weight ${spec.name} with quantization ${quantization.dtype} doesn't have corresponding metadata min and scale.`); } } else if (quantization.dtype === "float16") { if (dtype !== "float32") { throw new Error(`Weight ${spec.name} is quantized with ${quantization.dtype} which only supports weights of type float32 not ${dtype}.`); } } else { throw new Error(`Weight ${spec.name} has unknown quantization dtype ${quantization.dtype}. Supported quantization dtypes are: 'uint8', 'uint16', and 'float16'.`); } const quantizationSizeFactor = DTYPE_VALUE_SIZE_MAP[quantization.dtype]; const byteBuffer = buffer3.slice(offset, offset + size * quantizationSizeFactor); const quantizedArray = quantization.dtype === "uint8" ? new Uint8Array(byteBuffer) : new Uint16Array(byteBuffer); if (dtype === "float32") { if (quantization.dtype === "uint8" || quantization.dtype === "uint16") { values = new Float32Array(quantizedArray.length); for (let i = 0; i < quantizedArray.length; i++) { const v = quantizedArray[i]; values[i] = v * quantization.scale + quantization.min; } } else if (quantization.dtype === "float16") { if (float16Decode === void 0) { float16Decode = getFloat16Decoder(); } values = float16Decode(quantizedArray); } else { throw new Error(`Unsupported quantization type ${quantization.dtype} for weight type float32.`); } } else if (dtype === "int32") { if (quantization.dtype !== "uint8" && quantization.dtype !== "uint16") { throw new Error(`Unsupported quantization type ${quantization.dtype} for weight type int32.`); } values = new Int32Array(quantizedArray.length); for (let i = 0; i < quantizedArray.length; i++) { const v = quantizedArray[i]; values[i] = Math.round(v * quantization.scale + quantization.min); } } else { throw new Error(`Unsupported dtype in weight '${name}': ${dtype}`); } offset += size * quantizationSizeFactor; } else if (dtype === "string") { const size2 = sizeFromShape(spec.shape); values = []; for (let i = 0; i < size2; i++) { const byteLength = new Uint32Array(buffer3.slice(offset, offset + NUM_BYTES_STRING_LENGTH))[0]; offset += NUM_BYTES_STRING_LENGTH; const bytes = new Uint8Array(buffer3.slice(offset, offset + byteLength)); values.push(bytes); offset += byteLength; } } else { const dtypeFactor = DTYPE_VALUE_SIZE_MAP[dtype]; const byteBuffer = buffer3.slice(offset, offset + size * dtypeFactor); if (dtype === "float32") { values = new Float32Array(byteBuffer); } else if (dtype === "int32") { values = new Int32Array(byteBuffer); } else if (dtype === "bool") { values = new Uint8Array(byteBuffer); } else if (dtype === "complex64") { values = new Float32Array(byteBuffer); const real2 = new Float32Array(values.length / 2); const image3 = new Float32Array(values.length / 2); for (let i = 0; i < real2.length; i++) { real2[i] = values[i * 2]; image3[i] = values[i * 2 + 1]; } const realTensor = tensor(real2, shape, "float32"); const imageTensor = tensor(image3, shape, "float32"); out[name] = complex(realTensor, imageTensor); realTensor.dispose(); imageTensor.dispose(); } else { throw new Error(`Unsupported dtype in weight '${name}': ${dtype}`); } offset += size * dtypeFactor; } if (dtype !== "complex64") { out[name] = tensor(values, shape, dtype); } } return out; } function concatenateTypedArrays(xs) { if (xs === null) { throw new Error(`Invalid input value: ${JSON.stringify(xs)}`); } let totalByteLength = 0; const normalizedXs = []; xs.forEach((x) => { totalByteLength += x.byteLength; normalizedXs.push(x.byteLength === x.buffer.byteLength ? x : new x.constructor(x)); if (!(x instanceof Float32Array || x instanceof Int32Array || x instanceof Uint8Array)) { throw new Error(`Unsupported TypedArray subtype: ${x.constructor.name}`); } }); const y = new Uint8Array(totalByteLength); let offset = 0; normalizedXs.forEach((x) => { y.set(new Uint8Array(x.buffer), offset); offset += x.byteLength; }); return y.buffer; } const useNodeBuffer = typeof Buffer !== "undefined" && (typeof Blob === "undefined" || typeof atob === "undefined" || typeof btoa === "undefined"); function stringByteLength(str) { if (useNodeBuffer) { return Buffer.byteLength(str); } return new Blob([str]).size; } function arrayBufferToBase64String(buffer3) { if (useNodeBuffer) { return Buffer.from(buffer3).toString("base64"); } const buf = new Uint8Array(buffer3); let s = ""; for (let i = 0, l = buf.length; i < l; i++) { s += String.fromCharCode(buf[i]); } return btoa(s); } function base64StringToArrayBuffer(str) { if (useNodeBuffer) { const buf = Buffer.from(str, "base64"); return buf.buffer.slice(buf.byteOffset, buf.byteOffset + buf.byteLength); } const s = atob(str); const buffer3 = new Uint8Array(s.length); for (let i = 0; i < s.length; ++i) { buffer3.set([s.charCodeAt(i)], i); } return buffer3.buffer; } function concatenateArrayBuffers(buffers) { if (buffers.length === 1) { return buffers[0]; } let totalByteLength = 0; buffers.forEach((buffer3) => { totalByteLength += buffer3.byteLength; }); const temp = new Uint8Array(totalByteLength); let offset = 0; buffers.forEach((buffer3) => { temp.set(new Uint8Array(buffer3), offset); offset += buffer3.byteLength; }); return temp.buffer; } function basename(path) { const SEPARATOR = "/"; path = path.trim(); while (path.endsWith(SEPARATOR)) { path = path.slice(0, path.length - 1); } const items = path.split(SEPARATOR); return items[items.length - 1]; } function getModelArtifactsInfoForJSON(modelArtifacts) { if (modelArtifacts.modelTopology instanceof ArrayBuffer) { throw new Error("Expected JSON model topology, received ArrayBuffer."); } return { dateSaved: new Date(), modelTopologyType: "JSON", modelTopologyBytes: modelArtifacts.modelTopology == null ? 0 : stringByteLength(JSON.stringify(modelArtifacts.modelTopology)), weightSpecsBytes: modelArtifacts.weightSpecs == null ? 0 : stringByteLength(JSON.stringify(modelArtifacts.weightSpecs)), weightDataBytes: modelArtifacts.weightData == null ? 0 : modelArtifacts.weightData.byteLength }; } function computeFloat16MantisaTable() { const convertMantissa = (i) => { let m = i << 13; let e = 0; while ((m & 8388608) === 0) { e -= 8388608; m <<= 1; } m &= ~8388608; e += 947912704; return m | e; }; const mantisaTable = new Uint32Array(2048); mantisaTable[0] = 0; for (let i = 1; i < 1024; i++) { mantisaTable[i] = convertMantissa(i); } for (let i = 1024; i < 2048; i++) { mantisaTable[i] = 939524096 + (i - 1024 << 13); } return mantisaTable; } function computeFloat16ExponentTable() { const exponentTable = new Uint32Array(64); exponentTable[0] = 0; exponentTable[31] = 1199570944; exponentTable[32] = 2147483648; exponentTable[63] = 3347054592; for (let i = 1; i < 31; i++) { exponentTable[i] = i << 23; } for (let i = 33; i < 63; i++) { exponentTable[i] = 2147483648 + (i - 32 << 23); } return exponentTable; } function computeFloat16OffsetTable() { const offsetTable = new Uint32Array(64); for (let i = 0; i < 64; i++) { offsetTable[i] = 1024; } offsetTable[0] = offsetTable[32] = 0; return offsetTable; } function getFloat16Decoder() { const mantisaTable = computeFloat16MantisaTable(); const exponentTable = computeFloat16ExponentTable(); const offsetTable = computeFloat16OffsetTable(); return (quantizedArray) => { const buffer3 = new ArrayBuffer(4 * quantizedArray.length); const bufferUint32View = new Uint32Array(buffer3); for (let index2 = 0; index2 < quantizedArray.length; index2++) { const float16Bits = quantizedArray[index2]; const float32Bits = mantisaTable[offsetTable[float16Bits >> 10] + (float16Bits & 1023)] + exponentTable[float16Bits >> 10]; bufferUint32View[index2] = float32Bits; } return new Float32Array(buffer3); }; } /** * @license * Copyright 2018 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ class IORouterRegistry { constructor() { this.saveRouters = []; this.loadRouters = []; } static getInstance() { if (IORouterRegistry.instance == null) { IORouterRegistry.instance = new IORouterRegistry(); } return IORouterRegistry.instance; } static registerSaveRouter(saveRouter) { IORouterRegistry.getInstance().saveRouters.push(saveRouter); } static registerLoadRouter(loadRouter) { IORouterRegistry.getInstance().loadRouters.push(loadRouter); } static getSaveHandlers(url) { return IORouterRegistry.getHandlers(url, "save"); } static getLoadHandlers(url, loadOptions) { return IORouterRegistry.getHandlers(url, "load", loadOptions); } static getHandlers(url, handlerType, loadOptions) { const validHandlers = []; const routers = handlerType === "load" ? IORouterRegistry.getInstance().loadRouters : IORouterRegistry.getInstance().saveRouters; routers.forEach((router) => { const handler = router(url, loadOptions); if (handler !== null) { validHandlers.push(handler); } }); return validHandlers; } } const registerSaveRouter = (loudRouter) => IORouterRegistry.registerSaveRouter(loudRouter); const registerLoadRouter = (loudRouter) => IORouterRegistry.registerLoadRouter(loudRouter); const getSaveHandlers = (url) => IORouterRegistry.getSaveHandlers(url); const getLoadHandlers = (url, loadOptions) => IORouterRegistry.getLoadHandlers(url, loadOptions); /** * @license * Copyright 2018 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ const DATABASE_NAME = "tensorflowjs"; const DATABASE_VERSION = 1; const MODEL_STORE_NAME = "models_store"; const INFO_STORE_NAME = "model_info_store"; async function deleteDatabase() { const idbFactory = getIndexedDBFactory(); return new Promise((resolve, reject) => { const deleteRequest = idbFactory.deleteDatabase(DATABASE_NAME); deleteRequest.onsuccess = () => resolve(); deleteRequest.onerror = (error) => reject(error); }); } function getIndexedDBFactory() { if (!env3().getBool("IS_BROWSER")) { throw new Error("Failed to obtain IndexedDB factory because the current environmentis not a web browser."); } const theWindow = typeof window === "undefined" ? self : window; const factory = theWindow.indexedDB || theWindow.mozIndexedDB || theWindow.webkitIndexedDB || theWindow.msIndexedDB || theWindow.shimIndexedDB; if (factory == null) { throw new Error("The current browser does not appear to support IndexedDB."); } return factory; } function setUpDatabase(openRequest) { const db = openRequest.result; db.createObjectStore(MODEL_STORE_NAME, {keyPath: "modelPath"}); db.createObjectStore(INFO_STORE_NAME, {keyPath: "modelPath"}); } class BrowserIndexedDB { constructor(modelPath) { this.indexedDB = getIndexedDBFactory(); if (modelPath == null || !modelPath) { throw new Error("For IndexedDB, modelPath must not be null, undefined or empty."); } this.modelPath = modelPath; } async save(modelArtifacts) { if (modelArtifacts.modelTopology instanceof ArrayBuffer) { throw new Error("BrowserLocalStorage.save() does not support saving model topology in binary formats yet."); } return this.databaseAction(this.modelPath, modelArtifacts); } async load() { return this.databaseAction(this.modelPath); } databaseAction(modelPath, modelArtifacts) { return new Promise((resolve, reject) => { const openRequest = this.indexedDB.open(DATABASE_NAME, DATABASE_VERSION); openRequest.onupgradeneeded = () => setUpDatabase(openRequest); openRequest.onsuccess = () => { const db = openRequest.result; if (modelArtifacts == null) { const modelTx = db.transaction(MODEL_STORE_NAME, "readonly"); const modelStore = modelTx.objectStore(MODEL_STORE_NAME); const getRequest = modelStore.get(this.modelPath); getRequest.onsuccess = () => { if (getRequest.result == null) { db.close(); return reject(new Error(`Cannot find model with path '${this.modelPath}' in IndexedDB.`)); } else { resolve(getRequest.result.modelArtifacts); } }; getRequest.onerror = (error) => { db.close(); return reject(getRequest.error); }; modelTx.oncomplete = () => db.close(); } else { const modelArtifactsInfo = getModelArtifactsInfoForJSON(modelArtifacts); const infoTx = db.transaction(INFO_STORE_NAME, "readwrite"); let infoStore = infoTx.objectStore(INFO_STORE_NAME); const putInfoRequest = infoStore.put({modelPath: this.modelPath, modelArtifactsInfo}); let modelTx; putInfoRequest.onsuccess = () => { modelTx = db.transaction(MODEL_STORE_NAME, "readwrite"); const modelStore = modelTx.objectStore(MODEL_STORE_NAME); const putModelRequest = modelStore.put({ modelPath: this.modelPath, modelArtifacts, modelArtifactsInfo }); putModelRequest.onsuccess = () => resolve({modelArtifactsInfo}); putModelRequest.onerror = (error) => { infoStore = infoTx.objectStore(INFO_STORE_NAME); const deleteInfoRequest = infoStore.delete(this.modelPath); deleteInfoRequest.onsuccess = () => { db.close(); return reject(putModelRequest.error); }; deleteInfoRequest.onerror = (error2) => { db.close(); return reject(putModelRequest.error); }; }; }; putInfoRequest.onerror = (error) => { db.close(); return reject(putInfoRequest.error); }; infoTx.oncomplete = () => { if (modelTx == null) { db.close(); } else { modelTx.oncomplete = () => db.close(); } }; } }; openRequest.onerror = (error) => reject(openRequest.error); }); } } BrowserIndexedDB.URL_SCHEME = "indexeddb://"; const indexedDBRouter = (url) => { if (!env3().getBool("IS_BROWSER")) { return null; } else { if (!Array.isArray(url) && url.startsWith(BrowserIndexedDB.URL_SCHEME)) { return browserIndexedDB(url.slice(BrowserIndexedDB.URL_SCHEME.length)); } else { return null; } } }; IORouterRegistry.registerSaveRouter(indexedDBRouter); IORouterRegistry.registerLoadRouter(indexedDBRouter); function browserIndexedDB(modelPath) { return new BrowserIndexedDB(modelPath); } function maybeStripScheme(key) { return key.startsWith(BrowserIndexedDB.URL_SCHEME) ? key.slice(BrowserIndexedDB.URL_SCHEME.length) : key; } class BrowserIndexedDBManager { constructor() { this.indexedDB = getIndexedDBFactory(); } async listModels() { return new Promise((resolve, reject) => { const openRequest = this.indexedDB.open(DATABASE_NAME, DATABASE_VERSION); openRequest.onupgradeneeded = () => setUpDatabase(openRequest); openRequest.onsuccess = () => { const db = openRequest.result; const tx = db.transaction(INFO_STORE_NAME, "readonly"); const store = tx.objectStore(INFO_STORE_NAME); const getAllInfoRequest = store.getAll(); getAllInfoRequest.onsuccess = () => { const out = {}; for (const item of getAllInfoRequest.result) { out[item.modelPath] = item.modelArtifactsInfo; } resolve(out); }; getAllInfoRequest.onerror = (error) => { db.close(); return reject(getAllInfoRequest.error); }; tx.oncomplete = () => db.close(); }; openRequest.onerror = (error) => reject(openRequest.error); }); } async removeModel(path) { path = maybeStripScheme(path); return new Promise((resolve, reject) => { const openRequest = this.indexedDB.open(DATABASE_NAME, DATABASE_VERSION); openRequest.onupgradeneeded = () => setUpDatabase(openRequest); openRequest.onsuccess = () => { const db = openRequest.result; const infoTx = db.transaction(INFO_STORE_NAME, "readwrite"); const infoStore = infoTx.objectStore(INFO_STORE_NAME); const getInfoRequest = infoStore.get(path); let modelTx; getInfoRequest.onsuccess = () => { if (getInfoRequest.result == null) { db.close(); return reject(new Error(`Cannot find model with path '${path}' in IndexedDB.`)); } else { const deleteInfoRequest = infoStore.delete(path); const deleteModelData = () => { modelTx = db.transaction(MODEL_STORE_NAME, "readwrite"); const modelStore = modelTx.objectStore(MODEL_STORE_NAME); const deleteModelRequest = modelStore.delete(path); deleteModelRequest.onsuccess = () => resolve(getInfoRequest.result.modelArtifactsInfo); deleteModelRequest.onerror = (error) => reject(getInfoRequest.error); }; deleteInfoRequest.onsuccess = deleteModelData; deleteInfoRequest.onerror = (error) => { deleteModelData(); db.close(); return reject(getInfoRequest.error); }; } }; getInfoRequest.onerror = (error) => { db.close(); return reject(getInfoRequest.error); }; infoTx.oncomplete = () => { if (modelTx == null) { db.close(); } else { modelTx.oncomplete = () => db.close(); } }; }; openRequest.onerror = (error) => reject(openRequest.error); }); } } /** * @license * Copyright 2018 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ const PATH_SEPARATOR = "/"; const PATH_PREFIX = "tensorflowjs_models"; const INFO_SUFFIX = "info"; const MODEL_TOPOLOGY_SUFFIX = "model_topology"; const WEIGHT_SPECS_SUFFIX = "weight_specs"; const WEIGHT_DATA_SUFFIX = "weight_data"; const MODEL_METADATA_SUFFIX = "model_metadata"; function purgeLocalStorageArtifacts() { if (!env3().getBool("IS_BROWSER") || typeof window === "undefined" || typeof window.localStorage === "undefined") { throw new Error("purgeLocalStorageModels() cannot proceed because local storage is unavailable in the current environment."); } const LS = window.localStorage; const purgedModelPaths = []; for (let i = 0; i < LS.length; ++i) { const key = LS.key(i); const prefix = PATH_PREFIX + PATH_SEPARATOR; if (key.startsWith(prefix) && key.length > prefix.length) { LS.removeItem(key); const modelName = getModelPathFromKey(key); if (purgedModelPaths.indexOf(modelName) === -1) { purgedModelPaths.push(modelName); } } } return purgedModelPaths; } function getModelKeys(path) { return { info: [PATH_PREFIX, path, INFO_SUFFIX].join(PATH_SEPARATOR), topology: [PATH_PREFIX, path, MODEL_TOPOLOGY_SUFFIX].join(PATH_SEPARATOR), weightSpecs: [PATH_PREFIX, path, WEIGHT_SPECS_SUFFIX].join(PATH_SEPARATOR), weightData: [PATH_PREFIX, path, WEIGHT_DATA_SUFFIX].join(PATH_SEPARATOR), modelMetadata: [PATH_PREFIX, path, MODEL_METADATA_SUFFIX].join(PATH_SEPARATOR) }; } function getModelPathFromKey(key) { const items = key.split(PATH_SEPARATOR); if (items.length < 3) { throw new Error(`Invalid key format: ${key}`); } return items.slice(1, items.length - 1).join(PATH_SEPARATOR); } function maybeStripScheme$1(key) { return key.startsWith(BrowserLocalStorage.URL_SCHEME) ? key.slice(BrowserLocalStorage.URL_SCHEME.length) : key; } class BrowserLocalStorage { constructor(modelPath) { if (!env3().getBool("IS_BROWSER") || typeof window === "undefined" || typeof window.localStorage === "undefined") { throw new Error("The current environment does not support local storage."); } this.LS = window.localStorage; if (modelPath == null || !modelPath) { throw new Error("For local storage, modelPath must not be null, undefined or empty."); } this.modelPath = modelPath; this.keys = getModelKeys(this.modelPath); } async save(modelArtifacts) { if (modelArtifacts.modelTopology instanceof ArrayBuffer) { throw new Error("BrowserLocalStorage.save() does not support saving model topology in binary formats yet."); } else { const topology = JSON.stringify(modelArtifacts.modelTopology); const weightSpecs = JSON.stringify(modelArtifacts.weightSpecs); const modelArtifactsInfo = getModelArtifactsInfoForJSON(modelArtifacts); try { this.LS.setItem(this.keys.info, JSON.stringify(modelArtifactsInfo)); this.LS.setItem(this.keys.topology, topology); this.LS.setItem(this.keys.weightSpecs, weightSpecs); this.LS.setItem(this.keys.weightData, arrayBufferToBase64String(modelArtifacts.weightData)); this.LS.setItem(this.keys.modelMetadata, JSON.stringify({ format: modelArtifacts.format, generatedBy: modelArtifacts.generatedBy, convertedBy: modelArtifacts.convertedBy, userDefinedMetadata: modelArtifacts.userDefinedMetadata })); return {modelArtifactsInfo}; } catch (err) { this.LS.removeItem(this.keys.info); this.LS.removeItem(this.keys.topology); this.LS.removeItem(this.keys.weightSpecs); this.LS.removeItem(this.keys.weightData); this.LS.removeItem(this.keys.modelMetadata); throw new Error(`Failed to save model '${this.modelPath}' to local storage: size quota being exceeded is a possible cause of this failure: modelTopologyBytes=${modelArtifactsInfo.modelTopologyBytes}, weightSpecsBytes=${modelArtifactsInfo.weightSpecsBytes}, weightDataBytes=${modelArtifactsInfo.weightDataBytes}.`); } } } async load() { const info = JSON.parse(this.LS.getItem(this.keys.info)); if (info == null) { throw new Error(`In local storage, there is no model with name '${this.modelPath}'`); } if (info.modelTopologyType !== "JSON") { throw new Error("BrowserLocalStorage does not support loading non-JSON model topology yet."); } const out = {}; const topology = JSON.parse(this.LS.getItem(this.keys.topology)); if (topology == null) { throw new Error(`In local storage, the topology of model '${this.modelPath}' is missing.`); } out.modelTopology = topology; const weightSpecs = JSON.parse(this.LS.getItem(this.keys.weightSpecs)); if (weightSpecs == null) { throw new Error(`In local storage, the weight specs of model '${this.modelPath}' are missing.`); } out.weightSpecs = weightSpecs; const metadataString = this.LS.getItem(this.keys.modelMetadata); if (metadataString != null) { const metadata = JSON.parse(metadataString); out.format = metadata["format"]; out.generatedBy = metadata["generatedBy"]; out.convertedBy = metadata["convertedBy"]; out.userDefinedMetadata = metadata["userDefinedMetadata"]; } const weightDataBase64 = this.LS.getItem(this.keys.weightData); if (weightDataBase64 == null) { throw new Error(`In local storage, the binary weight values of model '${this.modelPath}' are missing.`); } out.weightData = base64StringToArrayBuffer(weightDataBase64); return out; } } BrowserLocalStorage.URL_SCHEME = "localstorage://"; const localStorageRouter = (url) => { if (!env3().getBool("IS_BROWSER")) { return null; } else { if (!Array.isArray(url) && url.startsWith(BrowserLocalStorage.URL_SCHEME)) { return browserLocalStorage(url.slice(BrowserLocalStorage.URL_SCHEME.length)); } else { return null; } } }; IORouterRegistry.registerSaveRouter(localStorageRouter); IORouterRegistry.registerLoadRouter(localStorageRouter); function browserLocalStorage(modelPath) { return new BrowserLocalStorage(modelPath); } class BrowserLocalStorageManager { constructor() { assert(env3().getBool("IS_BROWSER"), () => "Current environment is not a web browser"); assert(typeof window === "undefined" || typeof window.localStorage !== "undefined", () => "Current browser does not appear to support localStorage"); this.LS = window.localStorage; } async listModels() { const out = {}; const prefix = PATH_PREFIX + PATH_SEPARATOR; const suffix = PATH_SEPARATOR + INFO_SUFFIX; for (let i = 0; i < this.LS.length; ++i) { const key = this.LS.key(i); if (key.startsWith(prefix) && key.endsWith(suffix)) { const modelPath = getModelPathFromKey(key); out[modelPath] = JSON.parse(this.LS.getItem(key)); } } return out; } async removeModel(path) { path = maybeStripScheme$1(path); const keys = getModelKeys(path); if (this.LS.getItem(keys.info) == null) { throw new Error(`Cannot find model at path '${path}'`); } const info = JSON.parse(this.LS.getItem(keys.info)); this.LS.removeItem(keys.info); this.LS.removeItem(keys.topology); this.LS.removeItem(keys.weightSpecs); this.LS.removeItem(keys.weightData); return info; } } /** * @license * Copyright 2018 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ const URL_SCHEME_SUFFIX = "://"; class ModelStoreManagerRegistry { constructor() { this.managers = {}; } static getInstance() { if (ModelStoreManagerRegistry.instance == null) { ModelStoreManagerRegistry.instance = new ModelStoreManagerRegistry(); } return ModelStoreManagerRegistry.instance; } static registerManager(scheme, manager) { assert(scheme != null, () => "scheme must not be undefined or null."); if (scheme.endsWith(URL_SCHEME_SUFFIX)) { scheme = scheme.slice(0, scheme.indexOf(URL_SCHEME_SUFFIX)); } assert(scheme.length > 0, () => "scheme must not be an empty string."); const registry = ModelStoreManagerRegistry.getInstance(); assert(registry.managers[scheme] == null, () => `A model store manager is already registered for scheme '${scheme}'.`); registry.managers[scheme] = manager; } static getManager(scheme) { const manager = this.getInstance().managers[scheme]; if (manager == null) { throw new Error(`Cannot find model manager for scheme '${scheme}'`); } return manager; } static getSchemes() { return Object.keys(this.getInstance().managers); } } function parseURL(url) { if (url.indexOf(URL_SCHEME_SUFFIX) === -1) { throw new Error(`The url string provided does not contain a scheme. Supported schemes are: ${ModelStoreManagerRegistry.getSchemes().join(",")}`); } return { scheme: url.split(URL_SCHEME_SUFFIX)[0], path: url.split(URL_SCHEME_SUFFIX)[1] }; } async function cloneModelInternal(sourceURL, destURL, deleteSource = false) { assert(sourceURL !== destURL, () => `Old path and new path are the same: '${sourceURL}'`); const loadHandlers = IORouterRegistry.getLoadHandlers(sourceURL); assert(loadHandlers.length > 0, () => `Copying failed because no load handler is found for source URL ${sourceURL}.`); assert(loadHandlers.length < 2, () => `Copying failed because more than one (${loadHandlers.length}) load handlers for source URL ${sourceURL}.`); const loadHandler = loadHandlers[0]; const saveHandlers = IORouterRegistry.getSaveHandlers(destURL); assert(saveHandlers.length > 0, () => `Copying failed because no save handler is found for destination URL ${destURL}.`); assert(saveHandlers.length < 2, () => `Copying failed because more than one (${loadHandlers.length}) save handlers for destination URL ${destURL}.`); const saveHandler = saveHandlers[0]; const sourceScheme = parseURL(sourceURL).scheme; const sourcePath = parseURL(sourceURL).path; const sameMedium = sourceScheme === parseURL(sourceURL).scheme; const modelArtifacts = await loadHandler.load(); if (deleteSource && sameMedium) { await ModelStoreManagerRegistry.getManager(sourceScheme).removeModel(sourcePath); } const saveResult = await saveHandler.save(modelArtifacts); if (deleteSource && !sameMedium) { await ModelStoreManagerRegistry.getManager(sourceScheme).removeModel(sourcePath); } return saveResult.modelArtifactsInfo; } async function listModels() { const schemes = ModelStoreManagerRegistry.getSchemes(); const out = {}; for (const scheme of schemes) { const schemeOut = await ModelStoreManagerRegistry.getManager(scheme).listModels(); for (const path in schemeOut) { const url = scheme + URL_SCHEME_SUFFIX + path; out[url] = schemeOut[path]; } } return out; } async function removeModel(url) { const schemeAndPath = parseURL(url); const manager = ModelStoreManagerRegistry.getManager(schemeAndPath.scheme); return manager.removeModel(schemeAndPath.path); } async function copyModel(sourceURL, destURL) { const deleteSource = false; return cloneModelInternal(sourceURL, destURL, deleteSource); } async function moveModel(sourceURL, destURL) { const deleteSource = true; return cloneModelInternal(sourceURL, destURL, deleteSource); } /** * @license * Copyright 2019 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ class PlatformBrowser { fetch(path, init2) { return fetch(path, init2); } now() { return performance.now(); } encode(text, encoding) { if (encoding !== "utf-8" && encoding !== "utf8") { throw new Error(`Browser's encoder only supports utf-8, but got ${encoding}`); } if (this.textEncoder == null) { this.textEncoder = new TextEncoder(); } return this.textEncoder.encode(text); } decode(bytes, encoding) { return new TextDecoder(encoding).decode(bytes); } } if (env3().get("IS_BROWSER")) { env3().setPlatform("browser", new PlatformBrowser()); try { ModelStoreManagerRegistry.registerManager(BrowserLocalStorage.URL_SCHEME, new BrowserLocalStorageManager()); } catch (err) { } try { ModelStoreManagerRegistry.registerManager(BrowserIndexedDB.URL_SCHEME, new BrowserIndexedDBManager()); } catch (err) { } } /** * @license * Copyright 2019 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ const getNodeFetch = { importFetch: () => require_lib() }; let systemFetch; function resetSystemFetch() { systemFetch = null; } function setSystemFetch(fetchFn) { systemFetch = fetchFn; } function getSystemFetch() { return systemFetch; } class PlatformNode { constructor() { this.util = require("util"); this.textEncoder = new this.util.TextEncoder(); } fetch(path, requestInits) { if (env3().global.fetch != null) { return env3().global.fetch(path, requestInits); } if (systemFetch == null) { systemFetch = getNodeFetch.importFetch(); } return systemFetch(path, requestInits); } now() { const time2 = process.hrtime(); return time2[0] * 1e3 + time2[1] / 1e6; } encode(text, encoding) { if (encoding !== "utf-8" && encoding !== "utf8") { throw new Error(`Node built-in encoder only supports utf-8, but got ${encoding}`); } return this.textEncoder.encode(text); } decode(bytes, encoding) { if (bytes.length === 0) { return ""; } return new this.util.TextDecoder(encoding).decode(bytes); } } if (env3().get("IS_NODE")) { env3().setPlatform("node", new PlatformNode()); } /** * @license * Copyright 2020 Google Inc. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ function buffer2(shape, dtype = "float32", values) { dtype = dtype || "float32"; assertNonNegativeIntegerDimensions(shape); return new TensorBuffer(shape, dtype, values); } /** * @license * Copyright 2020 Google Inc. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ function cast_(x, dtype) { const $x = convertToTensor(x, "x", "cast"); if (!isValidDtype(dtype)) { throw new Error(`Failed to cast to unknown dtype ${dtype}`); } if (dtype === "string" && $x.dtype !== "string" || dtype !== "string" && $x.dtype === "string") { throw new Error("Only strings can be casted to strings"); } const inputs = {x: $x}; const attrs = {dtype}; return ENGINE.runKernelFunc((backend2) => backend2.cast($x, dtype), inputs, null, Cast5, attrs); } const cast2 = op({cast_}); /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ function clone_(x) { const $x = convertToTensor(x, "x", "clone", null); const forward = () => ENGINE.makeTensorFromDataId($x.dataId, $x.shape, $x.dtype); const inputs = {x: $x}; return ENGINE.runKernelFunc(forward, inputs, null, Identity5); } const clone = op({clone_}); /** * @license * Copyright 2020 Google Inc. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ function print2(x, verbose = false) { console.log(x.toString(verbose)); } /** * @license * Copyright 2020 Google Inc. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ getOrMakeEngine(); const opHandler$1 = { buffer: buffer2, cast: cast2, clone, print: print2 }; setOpHandler(opHandler$1); /** * @license * Copyright 2018 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ const DEFAULT_FILE_NAME_PREFIX = "model"; const DEFAULT_JSON_EXTENSION_NAME = ".json"; const DEFAULT_WEIGHT_DATA_EXTENSION_NAME = ".weights.bin"; function defer(f) { return new Promise((resolve) => setTimeout(resolve)).then(f); } class BrowserDownloads { constructor(fileNamePrefix) { if (!env3().getBool("IS_BROWSER")) { throw new Error("browserDownloads() cannot proceed because the current environment is not a browser."); } if (fileNamePrefix.startsWith(BrowserDownloads.URL_SCHEME)) { fileNamePrefix = fileNamePrefix.slice(BrowserDownloads.URL_SCHEME.length); } if (fileNamePrefix == null || fileNamePrefix.length === 0) { fileNamePrefix = DEFAULT_FILE_NAME_PREFIX; } this.modelTopologyFileName = fileNamePrefix + DEFAULT_JSON_EXTENSION_NAME; this.weightDataFileName = fileNamePrefix + DEFAULT_WEIGHT_DATA_EXTENSION_NAME; } async save(modelArtifacts) { if (typeof document === "undefined") { throw new Error("Browser downloads are not supported in this environment since `document` is not present"); } const weightsURL = window.URL.createObjectURL(new Blob([modelArtifacts.weightData], {type: "application/octet-stream"})); if (modelArtifacts.modelTopology instanceof ArrayBuffer) { throw new Error("BrowserDownloads.save() does not support saving model topology in binary formats yet."); } else { const weightsManifest = [{ paths: ["./" + this.weightDataFileName], weights: modelArtifacts.weightSpecs }]; const modelTopologyAndWeightManifest = { modelTopology: modelArtifacts.modelTopology, format: modelArtifacts.format, generatedBy: modelArtifacts.generatedBy, convertedBy: modelArtifacts.convertedBy, weightsManifest }; const modelTopologyAndWeightManifestURL = window.URL.createObjectURL(new Blob([JSON.stringify(modelTopologyAndWeightManifest)], {type: "application/json"})); const jsonAnchor = this.jsonAnchor == null ? document.createElement("a") : this.jsonAnchor; jsonAnchor.download = this.modelTopologyFileName; jsonAnchor.href = modelTopologyAndWeightManifestURL; await defer(() => jsonAnchor.dispatchEvent(new MouseEvent("click"))); if (modelArtifacts.weightData != null) { const weightDataAnchor = this.weightDataAnchor == null ? document.createElement("a") : this.weightDataAnchor; weightDataAnchor.download = this.weightDataFileName; weightDataAnchor.href = weightsURL; await defer(() => weightDataAnchor.dispatchEvent(new MouseEvent("click"))); } return {modelArtifactsInfo: getModelArtifactsInfoForJSON(modelArtifacts)}; } } } BrowserDownloads.URL_SCHEME = "downloads://"; class BrowserFiles { constructor(files) { if (files == null || files.length < 1) { throw new Error(`When calling browserFiles, at least 1 file is required, but received ${files}`); } this.files = files; } async load() { const jsonFile = this.files[0]; const weightFiles = this.files.slice(1); return new Promise((resolve, reject) => { const jsonReader = new FileReader(); jsonReader.onload = (event) => { const modelJSON = JSON.parse(event.target.result); const modelTopology = modelJSON.modelTopology; if (modelTopology == null) { reject(new Error(`modelTopology field is missing from file ${jsonFile.name}`)); return; } if (weightFiles.length === 0) { resolve({modelTopology}); } const weightsManifest = modelJSON.weightsManifest; if (weightsManifest == null) { reject(new Error(`weightManifest field is missing from file ${jsonFile.name}`)); return; } let pathToFile; try { pathToFile = this.checkManifestAndWeightFiles(weightsManifest, weightFiles); } catch (err) { reject(err); return; } const weightSpecs = []; const paths = []; const perFileBuffers = []; weightsManifest.forEach((weightsGroup) => { weightsGroup.paths.forEach((path) => { paths.push(path); perFileBuffers.push(null); }); weightSpecs.push(...weightsGroup.weights); }); weightsManifest.forEach((weightsGroup) => { weightsGroup.paths.forEach((path) => { const weightFileReader = new FileReader(); weightFileReader.onload = (event2) => { const weightData = event2.target.result; const index2 = paths.indexOf(path); perFileBuffers[index2] = weightData; if (perFileBuffers.indexOf(null) === -1) { resolve({ modelTopology, weightSpecs, weightData: concatenateArrayBuffers(perFileBuffers), format: modelJSON.format, generatedBy: modelJSON.generatedBy, convertedBy: modelJSON.convertedBy, userDefinedMetadata: modelJSON.userDefinedMetadata }); } }; weightFileReader.onerror = (error) => reject(`Failed to weights data from file of path '${path}'.`); weightFileReader.readAsArrayBuffer(pathToFile[path]); }); }); }; jsonReader.onerror = (error) => reject(`Failed to read model topology and weights manifest JSON from file '${jsonFile.name}'. BrowserFiles supports loading Keras-style tf.Model artifacts only.`); jsonReader.readAsText(jsonFile); }); } checkManifestAndWeightFiles(manifest, files) { const basenames = []; const fileNames = files.map((file) => basename(file.name)); const pathToFile = {}; for (const group of manifest) { group.paths.forEach((path) => { const pathBasename = basename(path); if (basenames.indexOf(pathBasename) !== -1) { throw new Error(`Duplicate file basename found in weights manifest: '${pathBasename}'`); } basenames.push(pathBasename); if (fileNames.indexOf(pathBasename) === -1) { throw new Error(`Weight file with basename '${pathBasename}' is not provided.`); } else { pathToFile[path] = files[fileNames.indexOf(pathBasename)]; } }); } if (basenames.length !== files.length) { throw new Error(`Mismatch in the number of files in weights manifest (${basenames.length}) and the number of weight files provided (${files.length}).`); } return pathToFile; } } const browserDownloadsRouter = (url) => { if (!env3().getBool("IS_BROWSER")) { return null; } else { if (!Array.isArray(url) && url.startsWith(BrowserDownloads.URL_SCHEME)) { return browserDownloads(url.slice(BrowserDownloads.URL_SCHEME.length)); } else { return null; } } }; IORouterRegistry.registerSaveRouter(browserDownloadsRouter); function browserDownloads(fileNamePrefix = "model") { return new BrowserDownloads(fileNamePrefix); } function browserFiles(files) { return new BrowserFiles(files); } /** * @license * Copyright 2019 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ function monitorPromisesProgress(promises, onProgress, startFraction, endFraction) { checkPromises(promises); startFraction = startFraction == null ? 0 : startFraction; endFraction = endFraction == null ? 1 : endFraction; checkFraction(startFraction, endFraction); let resolvedPromise = 0; const registerMonitor = (promise) => { promise.then((value) => { const fraction = startFraction + ++resolvedPromise / promises.length * (endFraction - startFraction); onProgress(fraction); return value; }); return promise; }; function checkPromises(promises2) { assert(promises2 != null && Array.isArray(promises2) && promises2.length > 0, () => "promises must be a none empty array"); } function checkFraction(startFraction2, endFraction2) { assert(startFraction2 >= 0 && startFraction2 <= 1, () => `Progress fraction must be in range [0, 1], but got startFraction ${startFraction2}`); assert(endFraction2 >= 0 && endFraction2 <= 1, () => `Progress fraction must be in range [0, 1], but got endFraction ${endFraction2}`); assert(endFraction2 >= startFraction2, () => `startFraction must be no more than endFraction, but got startFraction ${startFraction2} and endFraction ${endFraction2}`); } return Promise.all(promises.map(registerMonitor)); } /** * @license * Copyright 2018 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ async function loadWeightsAsArrayBuffer(fetchURLs, loadOptions) { if (loadOptions == null) { loadOptions = {}; } const fetchFunc = loadOptions.fetchFunc == null ? env3().platform.fetch : loadOptions.fetchFunc; const requests = fetchURLs.map((fetchURL) => fetchFunc(fetchURL, loadOptions.requestInit, {isBinary: true})); const fetchStartFraction = 0; const fetchEndFraction = 0.5; const responses = loadOptions.onProgress == null ? await Promise.all(requests) : await monitorPromisesProgress(requests, loadOptions.onProgress, fetchStartFraction, fetchEndFraction); const bufferPromises = responses.map((response) => response.arrayBuffer()); const bufferStartFraction = 0.5; const bufferEndFraction = 1; const buffers = loadOptions.onProgress == null ? await Promise.all(bufferPromises) : await monitorPromisesProgress(bufferPromises, loadOptions.onProgress, bufferStartFraction, bufferEndFraction); return buffers; } async function loadWeights(manifest, filePathPrefix = "", weightNames, requestInit) { const fetchWeights = (fetchUrls) => loadWeightsAsArrayBuffer(fetchUrls, {requestInit}); const loadWeights2 = weightsLoaderFactory(fetchWeights); return loadWeights2(manifest, filePathPrefix, weightNames); } function weightsLoaderFactory(fetchWeightsFunction) { return async (manifest, filePathPrefix = "", weightNames) => { const groupIndicesToFetchMap = manifest.map(() => false); const groupWeightsToFetch = {}; const weightsFound = weightNames != null ? weightNames.map(() => false) : []; const allManifestWeightNames = []; manifest.forEach((manifestGroupConfig, groupIndex) => { let groupOffset = 0; manifestGroupConfig.weights.forEach((weightsEntry) => { const rawDtype = "quantization" in weightsEntry ? weightsEntry.quantization.dtype : weightsEntry.dtype; const weightsBytes = DTYPE_VALUE_SIZE_MAP[rawDtype] * sizeFromShape(weightsEntry.shape); const enqueueWeightsForFetchingFn = () => { groupIndicesToFetchMap[groupIndex] = true; if (groupWeightsToFetch[groupIndex] == null) { groupWeightsToFetch[groupIndex] = []; } groupWeightsToFetch[groupIndex].push({ manifestEntry: weightsEntry, groupOffset, sizeBytes: weightsBytes }); }; if (weightNames != null) { weightNames.forEach((weightName, weightIndex) => { if (weightName === weightsEntry.name) { enqueueWeightsForFetchingFn(); weightsFound[weightIndex] = true; } }); } else { enqueueWeightsForFetchingFn(); } allManifestWeightNames.push(weightsEntry.name); groupOffset += weightsBytes; }); }); if (!weightsFound.every((found) => found)) { const weightsNotFound = weightNames.filter((_, i) => !weightsFound[i]); throw new Error(`Could not find weights in manifest with names: ${weightsNotFound.join(", ")}. Manifest JSON has weights with names: ${allManifestWeightNames.join(", ")}.`); } const groupIndicesToFetch = groupIndicesToFetchMap.reduce((accumulator, shouldFetch, i) => { if (shouldFetch) { accumulator.push(i); } return accumulator; }, []); const fetchUrls = []; groupIndicesToFetch.forEach((i) => { manifest[i].paths.forEach((filepath) => { const fetchUrl = filePathPrefix + (!filePathPrefix.endsWith("/") ? "/" : "") + filepath; fetchUrls.push(fetchUrl); }); }); const buffers = await fetchWeightsFunction(fetchUrls); const weightsTensorMap = {}; let bufferIndexOffset = 0; groupIndicesToFetch.forEach((i) => { const numBuffers = manifest[i].paths.length; let groupBytes = 0; for (let i2 = 0; i2 < numBuffers; i2++) { groupBytes += buffers[bufferIndexOffset + i2].byteLength; } const groupBuffer = new ArrayBuffer(groupBytes); const groupByteBuffer = new Uint8Array(groupBuffer); let groupBufferOffset = 0; for (let i2 = 0; i2 < numBuffers; i2++) { const buffer3 = new Uint8Array(buffers[bufferIndexOffset + i2]); groupByteBuffer.set(buffer3, groupBufferOffset); groupBufferOffset += buffer3.byteLength; } const weightsEntries = groupWeightsToFetch[i]; weightsEntries.forEach((weightsEntry) => { const byteBuffer = groupBuffer.slice(weightsEntry.groupOffset, weightsEntry.groupOffset + weightsEntry.sizeBytes); const nameToTensorMap = decodeWeights(byteBuffer, [weightsEntry.manifestEntry]); for (const name in nameToTensorMap) { weightsTensorMap[name] = nameToTensorMap[name]; } }); bufferIndexOffset += numBuffers; }); return weightsTensorMap; }; } /** * @license * Copyright 2018 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ const OCTET_STREAM_MIME_TYPE = "application/octet-stream"; const JSON_TYPE = "application/json"; class HTTPRequest { constructor(path, loadOptions) { this.DEFAULT_METHOD = "POST"; if (loadOptions == null) { loadOptions = {}; } this.weightPathPrefix = loadOptions.weightPathPrefix; this.onProgress = loadOptions.onProgress; this.weightUrlConverter = loadOptions.weightUrlConverter; if (loadOptions.fetchFunc != null) { assert(typeof loadOptions.fetchFunc === "function", () => "Must pass a function that matches the signature of `fetch` (see https://developer.mozilla.org/en-US/docs/Web/API/Fetch_API)"); this.fetch = loadOptions.fetchFunc; } else { this.fetch = env3().platform.fetch; } assert(path != null && path.length > 0, () => "URL path for http must not be null, undefined or empty."); if (Array.isArray(path)) { assert(path.length === 2, () => `URL paths for http must have a length of 2, (actual length is ${path.length}).`); } this.path = path; if (loadOptions.requestInit != null && loadOptions.requestInit.body != null) { throw new Error("requestInit is expected to have no pre-existing body, but has one."); } this.requestInit = loadOptions.requestInit || {}; } async save(modelArtifacts) { if (modelArtifacts.modelTopology instanceof ArrayBuffer) { throw new Error("BrowserHTTPRequest.save() does not support saving model topology in binary formats yet."); } const init2 = Object.assign({method: this.DEFAULT_METHOD}, this.requestInit); init2.body = new FormData(); const weightsManifest = [{ paths: ["./model.weights.bin"], weights: modelArtifacts.weightSpecs }]; const modelTopologyAndWeightManifest = { modelTopology: modelArtifacts.modelTopology, format: modelArtifacts.format, generatedBy: modelArtifacts.generatedBy, convertedBy: modelArtifacts.convertedBy, userDefinedMetadata: modelArtifacts.userDefinedMetadata, weightsManifest }; init2.body.append("model.json", new Blob([JSON.stringify(modelTopologyAndWeightManifest)], {type: JSON_TYPE}), "model.json"); if (modelArtifacts.weightData != null) { init2.body.append("model.weights.bin", new Blob([modelArtifacts.weightData], {type: OCTET_STREAM_MIME_TYPE}), "model.weights.bin"); } const response = await this.fetch(this.path, init2); if (response.ok) { return { modelArtifactsInfo: getModelArtifactsInfoForJSON(modelArtifacts), responses: [response] }; } else { throw new Error(`BrowserHTTPRequest.save() failed due to HTTP response status ${response.status}.`); } } async load() { const modelConfigRequest = await this.fetch(this.path, this.requestInit); if (!modelConfigRequest.ok) { throw new Error(`Request to ${this.path} failed with status code ${modelConfigRequest.status}. Please verify this URL points to the model JSON of the model to load.`); } let modelConfig; try { modelConfig = await modelConfigRequest.json(); } catch (e) { let message = `Failed to parse model JSON of response from ${this.path}.`; if (this.path.endsWith(".pb")) { message += " Your path contains a .pb file extension. Support for .pb models have been removed in TensorFlow.js 1.0 in favor of .json models. You can re-convert your Python TensorFlow model using the TensorFlow.js 1.0 conversion scripts or you can convert your.pb models with the 'pb2json'NPM script in the tensorflow/tfjs-converter repository."; } else { message += " Please make sure the server is serving valid JSON for this request."; } throw new Error(message); } const modelTopology = modelConfig.modelTopology; const weightsManifest = modelConfig.weightsManifest; const generatedBy = modelConfig.generatedBy; const convertedBy = modelConfig.convertedBy; const format = modelConfig.format; const userDefinedMetadata = modelConfig.userDefinedMetadata; if (modelTopology == null && weightsManifest == null) { throw new Error(`The JSON from HTTP path ${this.path} contains neither model topology or manifest for weights.`); } let weightSpecs; let weightData; if (weightsManifest != null) { const results = await this.loadWeights(weightsManifest); [weightSpecs, weightData] = results; } const artifacts = { modelTopology, weightSpecs, weightData, userDefinedMetadata, generatedBy, convertedBy, format }; const initializer = modelConfig.modelInitializer; if (initializer) { artifacts.modelInitializer = initializer; } return artifacts; } async loadWeights(weightsManifest) { const weightPath = Array.isArray(this.path) ? this.path[1] : this.path; const [prefix, suffix] = parseUrl(weightPath); const pathPrefix = this.weightPathPrefix || prefix; const weightSpecs = []; for (const entry of weightsManifest) { weightSpecs.push(...entry.weights); } const fetchURLs = []; const urlPromises = []; for (const weightsGroup of weightsManifest) { for (const path of weightsGroup.paths) { if (this.weightUrlConverter != null) { urlPromises.push(this.weightUrlConverter(path)); } else { fetchURLs.push(pathPrefix + path + suffix); } } } if (this.weightUrlConverter) { fetchURLs.push(...await Promise.all(urlPromises)); } const buffers = await loadWeightsAsArrayBuffer(fetchURLs, { requestInit: this.requestInit, fetchFunc: this.fetch, onProgress: this.onProgress }); return [weightSpecs, concatenateArrayBuffers(buffers)]; } } HTTPRequest.URL_SCHEME_REGEX = /^https?:\/\//; function parseUrl(url) { const lastSlash = url.lastIndexOf("/"); const lastSearchParam = url.lastIndexOf("?"); const prefix = url.substring(0, lastSlash); const suffix = lastSearchParam > lastSlash ? url.substring(lastSearchParam) : ""; return [prefix + "/", suffix]; } function isHTTPScheme(url) { return url.match(HTTPRequest.URL_SCHEME_REGEX) != null; } const httpRouter = (url, loadOptions) => { if (typeof fetch === "undefined" && (loadOptions == null || loadOptions.fetchFunc == null)) { return null; } else { let isHTTP = true; if (Array.isArray(url)) { isHTTP = url.every((urlItem) => isHTTPScheme(urlItem)); } else { isHTTP = isHTTPScheme(url); } if (isHTTP) { return http(url, loadOptions); } } return null; }; IORouterRegistry.registerSaveRouter(httpRouter); IORouterRegistry.registerLoadRouter(httpRouter); function http(path, loadOptions) { return new HTTPRequest(path, loadOptions); } function browserHTTPRequest(path, loadOptions) { return http(path, loadOptions); } /** * @license * Copyright 2018 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ class PassthroughLoader { constructor(modelArtifacts) { this.modelArtifacts = modelArtifacts; } async load() { return this.modelArtifacts; } } class PassthroughSaver { constructor(saveHandler) { this.saveHandler = saveHandler; } async save(modelArtifacts) { return this.saveHandler(modelArtifacts); } } function fromMemory(modelArtifacts, weightSpecs, weightData, trainingConfig) { if (arguments.length === 1) { const isModelArtifacts = modelArtifacts.modelTopology != null || modelArtifacts.weightSpecs != null; if (isModelArtifacts) { return new PassthroughLoader(modelArtifacts); } else { console.warn("Please call tf.io.fromMemory() with only one argument. The argument should be of type ModelArtifacts. The multi-argument signature of tf.io.fromMemory() has been deprecated and will be removed in a future release."); return new PassthroughLoader({modelTopology: modelArtifacts}); } } else { console.warn("Please call tf.io.fromMemory() with only one argument. The argument should be of type ModelArtifacts. The multi-argument signature of tf.io.fromMemory() has been deprecated and will be removed in a future release."); return new PassthroughLoader({ modelTopology: modelArtifacts, weightSpecs, weightData, trainingConfig }); } } function withSaveHandler(saveHandler) { return new PassthroughSaver(saveHandler); } /** * @license * Copyright 2018 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ var io = /* @__PURE__ */ Object.freeze({ __proto__: null, browserFiles, browserHTTPRequest, concatenateArrayBuffers, decodeWeights, encodeWeights, fromMemory, getLoadHandlers, getModelArtifactsInfoForJSON, getSaveHandlers, http, isHTTPScheme, loadWeights, registerLoadRouter, registerSaveRouter, weightsLoaderFactory, withSaveHandler, copyModel, listModels, moveModel, removeModel }); /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ function reshape_(x, shape) { const $x = convertToTensor(x, "x", "reshape", null); const inputs = {x: $x}; const attrs = {shape}; const forward = (backend2, save) => { shape = inferFromImplicitShape(shape, $x.size); assert($x.size === sizeFromShape(shape), () => "new shape and old shape must have the same number of elements."); save([$x]); return backend2.reshape($x, shape); }; return ENGINE.runKernelFunc(forward, inputs, null, Reshape6, attrs); } const reshape2 = op({reshape_}); /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ function matMul_(a, b, transposeA = false, transposeB = false) { let $a = convertToTensor(a, "a", "matMul"); let $b = convertToTensor(b, "b", "matMul"); [$a, $b] = makeTypesMatch($a, $b); const forward = (backend2, save) => { save([$a, $b]); const innerShapeA = transposeA ? $a.shape[$a.rank - 2] : $a.shape[$a.rank - 1]; const innerShapeB = transposeB ? $b.shape[$b.rank - 1] : $b.shape[$b.rank - 2]; const outerShapeA = transposeA ? $a.shape[$a.rank - 1] : $a.shape[$a.rank - 2]; const outerShapeB = transposeB ? $b.shape[$b.rank - 2] : $b.shape[$b.rank - 1]; const outerDimsA = $a.shape.slice(0, -2); const outerDimsB = $b.shape.slice(0, -2); const batchDimA = sizeFromShape(outerDimsA); const batchDimB = sizeFromShape(outerDimsB); const batchDimsCompatible = batchDimA === batchDimB || batchDimA === 1 || batchDimB === 1; assert($a.rank >= 2 && $b.rank >= 2 && batchDimsCompatible, () => `Error in matMul: the input batch dimensions must either be the same or at least one input batch dimension must be 1. Got input batch dimensions of (${outerDimsA}) and (${outerDimsB}).`); assert(innerShapeA === innerShapeB, () => `Error in matMul: inner shapes (${innerShapeA}) and (${innerShapeB}) of Tensors with shapes ${$a.shape} and ${$b.shape} and transposeA=${transposeA} and transposeB=${transposeB} must match.`); const outShapeOuterDims = batchDimA > batchDimB ? outerDimsA : outerDimsB; const outShape = outShapeOuterDims.concat([outerShapeA, outerShapeB]); const a3D = transposeA ? reshape2($a, [batchDimA, innerShapeA, outerShapeA]) : reshape2($a, [batchDimA, outerShapeA, innerShapeA]); const b3D = transposeB ? reshape2($b, [batchDimB, outerShapeB, innerShapeB]) : reshape2($b, [batchDimB, innerShapeB, outerShapeB]); const res3d = backend2.batchMatMul(a3D, b3D, transposeA, transposeB); return reshape2(res3d, outShape); }; const inputs = {a: $a, b: $b}; const attrs = {transposeA, transposeB}; return ENGINE.runKernelFunc(forward, inputs, null, BatchMatMul3, attrs); } const matMul = op({matMul_}); /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ function oneHot_(indices, depth, onValue = 1, offValue = 0) { if (depth < 2) { throw new Error(`Error in oneHot: depth must be >=2, but it is ${depth}`); } const $indices = convertToTensor(indices, "indices", "oneHot", "int32"); const outShape = [...$indices.shape, depth]; const forward = (backend2, save) => { save([$indices]); return reshape2(backend2.oneHot(reshape2($indices, [$indices.size]), depth, onValue, offValue), outShape); }; const inputs = {indices: $indices}; const attrs = {depth, onValue, offValue}; return ENGINE.runKernelFunc(forward, inputs, null, OneHot3, attrs); } const oneHot2 = op({oneHot_}); /** * @license * Copyright 2018 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ function transpose_(x, perm) { const $x = convertToTensor(x, "x", "transpose"); if (perm == null) { perm = $x.shape.map((s, i) => i).reverse(); } assert($x.rank === perm.length, () => `Error in transpose: rank of input ${$x.rank} must match length of perm ${perm}.`); perm.forEach((axis) => { assert(axis >= 0 && axis < $x.rank, () => `All entries in 'perm' must be between 0 and ${$x.rank - 1} but got ${perm}`); }); if ($x.rank <= 1) { return $x.clone(); } const inputs = {x: $x}; const attrs = {perm}; return ENGINE.runKernelFunc((backend2) => backend2.transpose($x, perm), inputs, null, Transpose5, attrs); } const transpose2 = op({transpose_}); /** * @license * Copyright 2018 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ function confusionMatrix_(labels, predictions, numClasses) { const $labels = convertToTensor(labels, "labels", "confusionMatrix"); const $predictions = convertToTensor(predictions, "predictions", "confusionMatrix"); assert(numClasses == null || numClasses > 0 && Number.isInteger(numClasses), () => `If provided, numClasses must be a positive integer, but got ${numClasses}`); assert($labels.rank === 1, () => `Expected the rank of labels to be 1, but got ${$labels.rank}`); assert($predictions.rank === 1, () => `Expected the rank of predictions to be 1, but got ${$predictions.rank}`); assert($labels.shape[0] === $predictions.shape[0], () => `Mismatch in the number of examples: ${$labels.shape[0]} vs. ${$predictions.shape[0]}. Labels and predictions should have the same number of elements.`); assert(numClasses > 0 && Number.isInteger(numClasses), () => `numClasses is required to be a positive integer, but got ${numClasses}`); const oneHotLabels = oneHot2(cast2($labels, "int32"), numClasses); const oneHotPredictions = oneHot2(cast2($predictions, "int32"), numClasses); const oneHotLabelsT = transpose2(oneHotLabels); const product = matMul(oneHotLabelsT, oneHotPredictions); return cast2(product, "int32"); } const confusionMatrix = op({confusionMatrix_}); /** * @license * Copyright 2018 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ var math = /* @__PURE__ */ Object.freeze({ __proto__: null, confusionMatrix }); /** * @license * Copyright 2018 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ function tensor3d(values, shape, dtype) { assertNonNull(values); if (shape != null && shape.length !== 3) { throw new Error("tensor3d() requires shape to have three numbers"); } const inferredShape = inferShape(values, dtype); if (inferredShape.length !== 3 && inferredShape.length !== 1) { throw new Error("tensor3d() requires values to be number[][][] or flat/TypedArray"); } if (inferredShape.length === 1 && shape == null) { throw new Error("tensor3d() requires shape to be provided when `values` are a flat array"); } return makeTensor(values, shape, inferredShape, dtype); } /** * @license * Copyright 2019 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ let fromPixels2DContext; function fromPixels_(pixels, numChannels = 3) { if (numChannels > 4) { throw new Error("Cannot construct Tensor with more than 4 channels from pixels."); } if (pixels == null) { throw new Error("pixels passed to tf.browser.fromPixels() can not be null"); } let isPixelData = false; let isImageData = false; let isVideo = false; let isImage = false; let isCanvasLike = false; if (pixels.data instanceof Uint8Array) { isPixelData = true; } else if (typeof ImageData !== "undefined" && pixels instanceof ImageData) { isImageData = true; } else if (typeof HTMLVideoElement !== "undefined" && pixels instanceof HTMLVideoElement) { isVideo = true; } else if (typeof HTMLImageElement !== "undefined" && pixels instanceof HTMLImageElement) { isImage = true; } else if (pixels.getContext != null) { isCanvasLike = true; } else { throw new Error(`pixels passed to tf.browser.fromPixels() must be either an HTMLVideoElement, HTMLImageElement, HTMLCanvasElement, ImageData in browser, or OffscreenCanvas, ImageData in webworker or {data: Uint32Array, width: number, height: number}, but was ${pixels.constructor.name}`); } if (isVideo) { const HAVE_CURRENT_DATA_READY_STATE = 2; if (isVideo && pixels.readyState < HAVE_CURRENT_DATA_READY_STATE) { throw new Error("The video element has not loaded data yet. Please wait for `loadeddata` event on the