diff --git a/assets/quicksettings.js b/assets/quicksettings.js deleted file mode 100644 index 43cc3381..00000000 --- a/assets/quicksettings.js +++ /dev/null @@ -1 +0,0 @@ -!function(){function a(a,b){var d=c("div",null,"qs_label",b);return d.innerHTML=a,d}function b(a,b,d,e){var f=c("input",b,d,e);return f.type=a,f}function c(a,b,c,d){var e=document.createElement(a);if(e)return e.id=b,c&&(e.className=c),d&&d.appendChild(e),e}function d(){return navigator.userAgent.indexOf("rv:11")!=-1||navigator.userAgent.indexOf("MSIE")!=-1}function e(){var a=navigator.userAgent.toLowerCase();return!(a.indexOf("chrome")>-1||a.indexOf("firefox")>-1||a.indexOf("epiphany")>-1)&&a.indexOf("safari/")>-1}function f(){var a=navigator.userAgent.toLowerCase();return a.indexOf("edge")>-1}function g(){var a=document.createElement("style");a.innerText=i,document.head.appendChild(a),h=!0}var h=!1,i=".qs_main{background-color:#dddddd;text-align:left;position:absolute;width:200px;font:12px sans-serif;box-shadow:5px 5px 8px rgba(0,0,0,0.35);user-select:none;-webkit-user-select:none;color:#000000;border:none}.qs_content{background-color:#cccccc;overflow-y:auto}.qs_title_bar{background-color:#eeeeee;user-select:none;-webkit-user-select:none;cursor:pointer;padding:5px;font-weight:bold;border:none;color:#000000}.qs_container{margin:5px;padding:5px;background-color:#eeeeee;border:none;position:relative}.qs_container_selected{border:none;background-color:#ffffff}.qs_range{-webkit-appearance:none;-moz-appearance:none;width:100%;height:17px;padding:0;margin:0;background-color:transparent;border:none;-webkit-box-sizing:border-box;-moz-box-sizing:border-box;box-sizing:border-box}.qs_range:focus{outline:none;border:none}.qs_range::-webkit-slider-runnable-track{width:100%;height:15px;cursor:pointer;background:#cccccc;-webkit-border-radius:0;-moz-border-radius:0;border-radius:0}.qs_range:focus::-webkit-slider-runnable-track{background:#cccccc}.qs_range::-webkit-slider-thumb{-webkit-appearance:none;height:15px;width:15px;-webkit-border-radius:0;-moz-border-radius:0;border-radius:0;background:#999999;cursor:pointer;margin-top:0}.qs_range::-moz-range-track{width:100%;height:15px;cursor:pointer;background:#cccccc;-webkit-border-radius:0;-moz-border-radius:0;border-radius:0}.qs_range::-moz-range-thumb{height:15px;width:15px;border:none;-webkit-border-radius:0;-moz-border-radius:0;border-radius:0;background:#999999;cursor:pointer}.qs_range::-ms-track{width:100%;height:15px;cursor:pointer;visibility:hidden;background:transparent}.qs_range::-ms-thumb{height:15px;width:15px;-webkit-border-radius:0;-moz-border-radius:0;border-radius:0;background:#999999;cursor:pointer;border:none}.qs_range::-ms-fill-lower{background:#cccccc;-webkit-border-radius:0;-moz-border-radius:0;border-radius:0}.qs_range:focus::-ms-fill-lower{background:#cccccc}.qs_range::-ms-fill-upper{background:#cccccc;-webkit-border-radius:0;-moz-border-radius:0;border-radius:0}.qs_range:focus::-ms-fill-upper{background:#cccccc}.qs_button{background-color:#f6f6f6;color:#000000;height:30px;border:1px solid #aaaaaa;font:12px sans-serif}.qs_button:active{background-color:#ffffff;border:1px solid #aaaaaa}.qs_button:focus{border:1px solid #aaaaaa;outline:none}.qs_checkbox{cursor:pointer}.qs_checkbox input{position:absolute;left:-99999px}.qs_checkbox span{height:16px;width:100%;display:block;text-indent:20px;background:url('') no-repeat}.qs_checkbox input:checked+span{background:url('') no-repeat}.qs_checkbox_label{position:absolute;top:7px;left:30px}.qs_label{margin-bottom:3px;user-select:none;-webkit-user-select:none;cursor:default;font:12px sans-serif}.qs_text_input{-webkit-box-sizing:border-box;-moz-box-sizing:border-box;box-sizing:border-box;width:100%;padding:0 0 0 5px;height:24px;border:1px inset #ffffff;background-color:#ffffff;color:#000000;font-size:12px}.qs_text_input:focus{outline:none;background:#ffffff;border:1px inset #ffffff}.qs_select{background:url('') no-repeat right #f6f6f6;-webkit-appearance:none;-moz-appearance:none;appearance:none;color:#000000;width:100%;height:24px;border:1px solid #aaaaaa;-webkit-border-radius:0;-moz-border-radius:0;border-radius:0;padding:0 5px;-moz-outline:none;font-size:14px}.qs_select option{font-size:14px}.qs_select::-ms-expand{display:none}.qs_select:focus{outline:none}.qs_number{height:24px}.qs_image{width:100%}.qs_progress{width:100%;height:15px;background-color:#cccccc;border:none;-webkit-box-sizing:border-box;-moz-box-sizing:border-box;box-sizing:border-box}.qs_progress_value{height:100%;background-color:#999999}.qs_textarea{-webkit-box-sizing:border-box;-moz-box-sizing:border-box;box-sizing:border-box;resize:vertical;width:100%;padding:3px 5px;border:1px inset #ffffff;background-color:#ffffff;color:#000000;font-size:12px}.qs_textarea:focus{outline:none;background:#ffffff;border:1px inset #ffffff}.qs_color{position:absolute;left:-999999px}.qs_color_label{width:100%;height:20px;display:block;border:1px solid #aaaaaa;cursor:pointer;padding:0 0 0 5px;-webkit-box-sizing:border-box;-moz-box-sizing:border-box;box-sizing:border-box}.qs_file_chooser{position:absolute;left:-999999px}.qs_file_chooser_label{background-color:#f6f6f6;color:#000000;height:30px;border:1px solid #aaaaaa;font:12px sans-serif;width:100%;display:block;cursor:pointer;padding:7px;-webkit-box-sizing:border-box;-moz-box-sizing:border-box;box-sizing:border-box;white-space:nowrap;overflow:hidden;text-overflow:ellipsis}",j={_version:"2.1",_topZ:1,_panel:null,_titleBar:null,_content:null,_startX:0,_startY:0,_hidden:!1,_collapsed:!1,_controls:null,_keyCode:-1,_draggable:!0,_collapsible:!0,_globalChangeHandler:null,useExtStyleSheet:function(){h=!0},create:function(a,b,c,d){var e=Object.create(this);return e._init(a,b,c,d),e},destroy:function(){this._panel.parentElement&&this._panel.parentElement.removeChild(this._panel);for(var a in this)this[a]=null},_init:function(a,b,c,d){h||g(),this._bindHandlers(),this._createPanel(a,b,d),this._createTitleBar(c||"QuickSettings"),this._createContent()},_bindHandlers:function(){this._startDrag=this._startDrag.bind(this),this._drag=this._drag.bind(this),this._endDrag=this._endDrag.bind(this),this._doubleClickTitle=this._doubleClickTitle.bind(this),this._onKeyUp=this._onKeyUp.bind(this)},getValuesAsJSON:function(a){var b={};for(var c in this._controls)this._controls[c].getValue&&(b[c]=this._controls[c].getValue());return a&&(b=JSON.stringify(b)),b},setValuesFromJSON:function(a){"string"==typeof a&&(a=JSON.parse(a));for(var b in a)this._controls[b].setValue&&this._controls[b].setValue(a[b]);return this},saveInLocalStorage:function(a){return this._localStorageName=a,this._readFromLocalStorage(a),this},clearLocalStorage:function(a){return localStorage.removeItem(a),this},_saveInLocalStorage:function(a){localStorage.setItem(a,this.getValuesAsJSON(!0))},_readFromLocalStorage:function(a){var b=localStorage.getItem(a);b&&this.setValuesFromJSON(b)},_createPanel:function(a,b,d){this._panel=c("div",null,"qs_main",d||document.body),this._panel.style.zIndex=++j._topZ,this.setPosition(a||0,b||0),this._controls={}},_createTitleBar:function(a){this._titleBar=c("div",null,"qs_title_bar",this._panel),this._titleBar.textContent=a,this._titleBar.addEventListener("mousedown",this._startDrag),this._titleBar.addEventListener("dblclick",this._doubleClickTitle)},_createContent:function(){this._content=c("div",null,"qs_content",this._panel)},_createContainer:function(){var a=c("div",null,"qs_container");return a.addEventListener("focus",function(){this.className+=" qs_container_selected"},!0),a.addEventListener("blur",function(){var a=this.className.indexOf(" qs_container_selected");a>-1&&(this.className=this.className.substr(0,a))},!0),this._content.appendChild(a),a},setPosition:function(a,b){return this._panel.style.left=a+"px",this._panel.style.top=Math.max(b,0)+"px",this},setSize:function(a,b){return this._panel.style.width=a+"px",this._content.style.width=a+"px",this._content.style.height=b-this._titleBar.offsetHeight+"px",this},setWidth:function(a){return this._panel.style.width=a+"px",this._content.style.width=a+"px",this},setHeight:function(a){return this._content.style.height=a-this._titleBar.offsetHeight+"px",this},setDraggable:function(a){return this._draggable=a,this._draggable||this._collapsible?this._titleBar.style.cursor="pointer":this._titleBar.style.cursor="default",this},_startDrag:function(a){this._draggable&&(this._panel.style.zIndex=++j._topZ,document.addEventListener("mousemove",this._drag),document.addEventListener("mouseup",this._endDrag),this._startX=a.clientX,this._startY=a.clientY),a.preventDefault()},_drag:function(a){var b=parseInt(this._panel.style.left),c=parseInt(this._panel.style.top),d=a.clientX,e=a.clientY;this.setPosition(b+d-this._startX,c+e-this._startY),this._startX=d,this._startY=e,a.preventDefault()},_endDrag:function(a){document.removeEventListener("mousemove",this._drag),document.removeEventListener("mouseup",this._endDrag),a.preventDefault()},setGlobalChangeHandler:function(a){return this._globalChangeHandler=a,this},_callGCH:function(){this._localStorageName&&this._saveInLocalStorage(this._localStorageName),this._globalChangeHandler&&this._globalChangeHandler()},hide:function(){return this._panel.style.visibility="hidden",this._hidden=!0,this},show:function(){return this._panel.style.visibility="visible",this._panel.style.zIndex=++j._topZ,this._hidden=!1,this},toggleVisibility:function(){return this._hidden?this.show():this.hide(),this},setCollapsible:function(a){return this._collapsible=a,this._draggable||this._collapsible?this._titleBar.style.cursor="pointer":this._titleBar.style.cursor="default",this},collapse:function(){return this._panel.removeChild(this._content),this._collapsed=!0,this},expand:function(){return this._panel.appendChild(this._content),this._collapsed=!1,this},toggleCollapsed:function(){return this._collapsed?this.expand():this.collapse(),this},setKey:function(a){return this._keyCode=a.toUpperCase().charCodeAt(0),document.body.addEventListener("keyup",this.onKeyUp),this},_onKeyUp:function(a){a.keyCode===this._keyCode&&this.toggleVisibility()},_doubleClickTitle:function(){this._collapsible&&this.toggleCollapsed()},removeControl:function(a){if(this._controls[a])var b=this._controls[a].container;return b&&b.parentElement&&b.parentElement.removeChild(b),this._controls[a]=null,this},enableControl:function(a){return this._controls[a]&&(this._controls[a].control.disabled=!1),this},disableControl:function(a){return this._controls[a]&&(this._controls[a].control.disabled=!0),this},hideControl:function(a){return this._controls[a]&&(this._controls[a].container.style.display="none"),this},showControl:function(a){return this._controls[a]&&(this._controls[a].container.style.display="block"),this},overrideStyle:function(a,b,c){return this._controls[a]&&(this._controls[a].control.style[b]=c),this},hideTitle:function(a){var b=this._controls[a].label;return b&&(b.style.display="none"),this},showTitle:function(a){var b=this._controls[a].label;return b&&(b.style.display="block"),this},hideAllTitles:function(){for(var a in this._controls){var b=this._controls[a].label;b&&(b.style.display="none")}return this},showAllTitles:function(){for(var a in this._controls){var b=this._controls[a].label;b&&(b.style.display="block")}return this},getValue:function(a){return this._controls[a].getValue()},setValue:function(a,b){return this._controls[a].setValue(b),this._callGCH(),this},addBoolean:function(a,d,e){var f=this._createContainer(),g=c("label",null,"qs_checkbox_label",f);g.textContent=a,g.setAttribute("for",a);var h=c("label",null,"qs_checkbox",f);h.setAttribute("for",a);var i=b("checkbox",a,null,h);i.checked=d;c("span",null,null,h);this._controls[a]={container:f,control:i,getValue:function(){return this.control.checked},setValue:function(a){this.control.checked=a,e&&e(a)}};var j=this;return i.addEventListener("change",function(){e&&e(i.checked),j._callGCH()}),this},bindBoolean:function(a,b,c){return this.addBoolean(a,b,function(b){c[a]=b})},addButton:function(a,c){var d=this._createContainer(),e=b("button",a,"qs_button",d);e.value=a,this._controls[a]={container:d,control:e};var f=this;return e.addEventListener("click",function(){c&&c(e),f._callGCH()}),this},addColor:function(g,h,i){if(e()||f()||d())return this.addText(g,h,i);var j=this._createContainer(),k=a(""+g+": "+h,j),l=b("color",g,"qs_color",j);l.value=h||"#ff0000";var m=c("label",null,"qs_color_label",j);m.setAttribute("for",g),m.style.backgroundColor=l.value,this._controls[g]={container:j,control:l,colorLabel:m,label:k,title:g,getValue:function(){return this.control.value},setValue:function(a){this.control.value=a,this.colorLabel.style.backgroundColor=l.value,this.label.innerHTML=""+this.title+": "+this.control.value,i&&i(a)}};var n=this;return l.addEventListener("input",function(){k.innerHTML=""+g+": "+l.value,m.style.backgroundColor=l.value,i&&i(l.value),n._callGCH()}),this},bindColor:function(a,b,c){return this.addColor(a,b,function(b){c[a]=b})},addDate:function(c,e,f){var g;if(e instanceof Date){var h=e.getFullYear(),i=e.getMonth()+1;i<10&&(i="0"+i);var j=e.getDate();g=h+"-"+i+"-"+j}else g=e;if(d())return this.addText(c,g,f);var k=this._createContainer(),l=a(""+c+"",k),m=b("date",c,"qs_text_input",k);m.value=g||"",this._controls[c]={container:k,control:m,label:l,getValue:function(){return this.control.value},setValue:function(a){var b;if(a instanceof Date){var c=a.getFullYear(),d=a.getMonth()+1;d<10&&(d="0"+d);var e=a.getDate();e<10&&(e="0"+e),b=c+"-"+d+"-"+e}else b=a;this.control.value=b||"",f&&f(b)}};var n=this;return m.addEventListener("input",function(){f&&f(m.value),n._callGCH()}),this},bindDate:function(a,b,c){return this.addDate(a,b,function(b){c[a]=b})},addDropDown:function(b,d,e){for(var f=this._createContainer(),g=a(""+b+"",f),h=c("select",null,"qs_select",f),i=0;i"+b+"",d);return d.appendChild(c),this._controls[b]={container:d,label:e},this},addFileChooser:function(d,e,f,g){var h=this._createContainer(),i=a(""+d+"",h),j=b("file",d,"qs_file_chooser",h);f&&(j.accept=f);var k=c("label",null,"qs_file_chooser_label",h);k.setAttribute("for",d),k.textContent=e||"Choose a file...",this._controls[d]={container:h,control:j,label:i,getValue:function(){return this.control.files[0]}};var l=this;return j.addEventListener("change",function(){j.files&&j.files.length&&(k.textContent=j.files[0].name,g&&g(j.files[0]),l._callGCH())}),this},addHTML:function(b,d){var e=this._createContainer(),f=a(""+b+": ",e),g=c("div",null,null,e);return g.innerHTML=d,this._controls[b]={label:f,control:g,getValue:function(){return this.control.innerHTML},setValue:function(a){this.control.innerHTML=a}},this},addImage:function(b,d){var e=this._createContainer(),f=a(""+b+"",e);return img=c("img",null,"qs_image",e),img.src=d,this._controls[b]={container:e,control:img,label:f,getValue:function(){return this.control.src},setValue:function(a){this.control.src=a}},this},addRange:function(a,b,c,d,e,f){return this._addNumber("range",a,b,c,d,e,f)},addNumber:function(a,b,c,d,e,f){return this._addNumber("number",a,b,c,d,e,f)},_addNumber:function(c,e,f,g,h,i,j){var k=this._createContainer(),l=a("",k),m="range"===c?"qs_range":"qs_text_input qs_number",n=b(c,e,m,k);n.min=f||0,n.max=g||100,n.step=i||1,n.value=h||0,l.innerHTML=""+e+": "+n.value,this._controls[e]={container:k,control:n,label:l,title:e,callback:j,getValue:function(){return parseFloat(this.control.value)},setValue:function(a){this.control.value=a,this.label.innerHTML=""+this.title+": "+this.control.value,j&&j(parseFloat(a))}};var o="input";"range"===c&&d()&&(o="change");var p=this;return n.addEventListener(o,function(){l.innerHTML=""+e+": "+n.value,j&&j(parseFloat(n.value)),p._callGCH()}),this},bindRange:function(a,b,c,d,e,f){return this.addRange(a,b,c,d,e,function(b){f[a]=b})},bindNumber:function(a,b,c,d,e,f){return this.addNumber(a,b,c,d,e,function(b){f[a]=b})},setRangeParameters:function(a,b,c,d){return this.setNumberParameters(a,b,c,d)},setNumberParameters:function(a,b,c,d){var e=this._controls[a],f=e.control.value;return e.control.min=b,e.control.max=c,e.control.step=d,e.control.value!==f&&e.callback&&e.callback(e.control.value),this},addPassword:function(a,b,c){return this._addText("password",a,b,c)},bindPassword:function(a,b,c){return this.addPassword(a,b,function(b){c[a]=b})},addProgressBar:function(b,d,e,f){var g=this._createContainer(),h=a("",g),i=c("div",null,"qs_progress",g),j=c("div",null,"qs_progress_value",i);return j.style.width=e/d*100+"%","numbers"===f?h.innerHTML=""+b+": "+e+" / "+d:"percent"===f?h.innerHTML=""+b+": "+Math.round(e/d*100)+"%":h.innerHTML=""+b+"",this._controls[b]={container:g,control:i,valueDiv:j,valueDisplay:f,label:h,value:e,max:d,title:b,getValue:function(){return this.value},setValue:function(a){this.value=Math.max(0,Math.min(a,this.max)),this.valueDiv.style.width=this.value/this.max*100+"%","numbers"===this.valueDisplay?this.label.innerHTML=""+this.title+": "+this.value+" / "+this.max:"percent"===this.valueDisplay&&(this.label.innerHTML=""+this.title+": "+Math.round(this.value/this.max*100)+"%")}},this},setProgressMax:function(a,b){var c=this._controls[a];return c.max=b,c.value=Math.min(c.value,c.max),c.valueDiv.style.width=c.value/c.max*100+"%","numbers"===c.valueDisplay?c.label.innerHTML=""+c.title+": "+c.value+" / "+c.max:"percent"===c.valueDisplay?c.label.innerHTML=""+c.title+": "+Math.round(c.value/c.max*100)+"%":c.label.innerHTML=""+c.title+"",this},addText:function(a,b,c){return this._addText("text",a,b,c)},_addText:function(d,e,f,g){var h,i=this._createContainer(),j=a(""+e+"",i);"textarea"===d?(h=c("textarea",e,"qs_textarea",i),h.rows=5):h=b(d,e,"qs_text_input",i),h.value=f||"",this._controls[e]={container:i,control:h,label:j,getValue:function(){return this.control.value},setValue:function(a){this.control.value=a,g&&g(a)}};var k=this;return h.addEventListener("input",function(){g&&g(h.value),k._callGCH()}),this},bindText:function(a,b,c){return this.addText(a,b,function(b){c[a]=b})},addTextArea:function(a,b,c){return this._addText("textarea",a,b,c)},setTextAreaRows:function(a,b){return this._controls[a].control.rows=b,this},bindTextArea:function(a,b,c){return this.addTextArea(a,b,function(b){c[a]=b})},addTime:function(c,e,f){var g;if(e instanceof Date){var h=e.getHours();h<10&&(h="0"+h);var i=e.getMinutes();i<10&&(i="0"+i);var j=e.getSeconds();j<10&&(j="0"+j),g=h+":"+i+":"+j}else g=e;if(d())return this.addText(c,g,f);var k=this._createContainer(),l=a(""+c+"",k),m=b("time",c,"qs_text_input",k);m.value=g||"",this._controls[c]={container:k,control:m,label:l,getValue:function(){return this.control.value},setValue:function(a){var b;if(a instanceof Date){var c=a.getHours();c<10&&(c="0"+c);var d=a.getMinutes();d<10&&(d="0"+d);var e=a.getSeconds();e<10&&(e="0"+e),b=c+":"+d+":"+e}else b=a;this.control.value=b||"",f&&f(b)}};var n=this;return m.addEventListener("input",function(){f&&f(m.value),n._callGCH()}),this},bindTime:function(a,b,c){return this.addTime(a,b,function(b){c[a]=b})}};"object"==typeof exports&&"object"==typeof module?module.exports=j:"function"==typeof define&&define.amd?define(j):window.QuickSettings=j}(); \ No newline at end of file diff --git a/assets/tf.esnext.js b/assets/tf.esnext.js deleted file mode 100644 index 1226e06f..00000000 --- a/assets/tf.esnext.js +++ /dev/null @@ -1,82571 +0,0 @@ -/** - * @license - * Copyright 2020 Google LLC. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ -(function (global, factory) { - typeof exports === 'object' && typeof module !== 'undefined' ? factory(exports) : - typeof define === 'function' && define.amd ? define(['exports'], factory) : - (global = global || self, factory(global.tf = global.tf || {})); -}(this, (function (exports) { 'use strict'; - - /** - * @license - * Copyright 2020 Google LLC. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - const EPSILON_FLOAT32 = 1e-7; - const EPSILON_FLOAT16 = 1e-4; - /** Convenient class for storing tensor-related data. */ - class DataStorage { - constructor(backend, dataMover) { - this.backend = backend; - this.dataMover = dataMover; - this.data = new WeakMap(); - this.dataIdsCount = 0; - } - get(dataId) { - if (!this.data.has(dataId)) { - this.dataMover.moveData(this.backend, dataId); - } - return this.data.get(dataId); - } - set(dataId, value) { - this.dataIdsCount++; - this.data.set(dataId, value); - } - has(dataId) { - return this.data.has(dataId); - } - delete(dataId) { - this.dataIdsCount--; - return this.data.delete(dataId); - } - numDataIds() { - return this.dataIdsCount; - } - } - /** - * The interface that defines the kernels that should be implemented when - * adding a new backend. New backends don't need to implement every one of the - * methods, this can be done gradually (throw an error for unimplemented - * methods). - */ - class KernelBackend { - time(f) { - return notYetImplemented('time'); - } - read(dataId) { - return notYetImplemented('read'); - } - readSync(dataId) { - return notYetImplemented('readSync'); - } - numDataIds() { - return notYetImplemented('numDataIds'); - } - disposeData(dataId) { - return notYetImplemented('disposeData'); - } - write(values, shape, dtype) { - return notYetImplemented('write'); - } - move(dataId, values, shape, dtype) { - return notYetImplemented('move'); - } - memory() { - return notYetImplemented('memory'); - } - /** Returns the highest precision for floats in bits (e.g. 16 or 32) */ - floatPrecision() { - return notYetImplemented('floatPrecision'); - } - /** Returns the smallest representable number. */ - epsilon() { - return this.floatPrecision() === 32 ? EPSILON_FLOAT32 : EPSILON_FLOAT16; - } - batchMatMul(a, b, transposeA, transposeB) { - return notYetImplemented('batchMatMul'); - } - fusedBatchMatMul({ a, b, transposeA, transposeB, bias, activation, preluActivationWeights }) { - return notYetImplemented('fusedBatchMatMul'); - } - slice(x, begin, size) { - return notYetImplemented('slice'); - } - stridedSlice(x, begin, end, strides) { - return notYetImplemented('stridedSlice'); - } - unstack(x, axis) { - return notYetImplemented('unstack'); - } - reverse(a, axis) { - return notYetImplemented('reverse'); - } - concat(tensors, axis) { - return notYetImplemented('concat'); - } - neg(a) { - return notYetImplemented('neg'); - } - add(a, b) { - return notYetImplemented('add'); - } - addN(tensors) { - return notYetImplemented('addN'); - } - subtract(a, b) { - return notYetImplemented('subtract'); - } - multiply(a, b) { - return notYetImplemented('multiply'); - } - realDivide(a, b) { - return notYetImplemented('realDivide'); - } - floorDiv(a, b) { - return notYetImplemented('floorDiv'); - } - sum(x, axes) { - return notYetImplemented('sum'); - } - prod(x, axes) { - return notYetImplemented('prod'); - } - unsortedSegmentSum(x, segmentIds, numSegments) { - return notYetImplemented('unsortedSegmentSum'); - } - argMin(x, axis) { - return notYetImplemented('argMin'); - } - argMax(x, axis) { - return notYetImplemented('argMax'); - } - equal(a, b) { - return notYetImplemented('equal'); - } - notEqual(a, b) { - return notYetImplemented('notEqual'); - } - less(a, b) { - return notYetImplemented('less'); - } - lessEqual(a, b) { - return notYetImplemented('lessEqual'); - } - greater(a, b) { - return notYetImplemented('greater'); - } - greaterEqual(a, b) { - return notYetImplemented('greaterEqual'); - } - logicalNot(a) { - return notYetImplemented('logicalNot'); - } - logicalAnd(a, b) { - return notYetImplemented('logicalAnd'); - } - logicalOr(a, b) { - return notYetImplemented('logicalOr'); - } - where(condition) { - return notYetImplemented('where'); - } - select(condition, a, b) { - return notYetImplemented('select'); - } - topk(x, k, sorted) { - return notYetImplemented('topk'); - } - min(x, axes) { - return notYetImplemented('min'); - } - minimum(a, b) { - return notYetImplemented('minimum'); - } - mod(a, b) { - return notYetImplemented('mod'); - } - max(x, axes) { - return notYetImplemented('max'); - } - maximum(a, b) { - return notYetImplemented('maximum'); - } - all(x, axes) { - return notYetImplemented('all'); - } - any(x, axes) { - return notYetImplemented('any'); - } - squaredDifference(a, b) { - return notYetImplemented('squaredDifference'); - } - ceil(x) { - return notYetImplemented('ceil'); - } - floor(x) { - return notYetImplemented('floor'); - } - round(x) { - return notYetImplemented('round'); - } - sign(x) { - return notYetImplemented('sign'); - } - isNaN(x) { - return notYetImplemented('isNaN'); - } - isInf(x) { - return notYetImplemented('isInf'); - } - isFinite(x) { - return notYetImplemented('isFinite'); - } - pow(a, b) { - return notYetImplemented('pow'); - } - exp(x) { - return notYetImplemented('exp'); - } - expm1(x) { - return notYetImplemented('expm1'); - } - softmax(x, dim) { - return notYetImplemented('softmax'); - } - log(x) { - return notYetImplemented('log'); - } - log1p(x) { - return notYetImplemented('log1p'); - } - sqrt(x) { - return notYetImplemented('sqrt'); - } - rsqrt(x) { - return notYetImplemented('rsqrt'); - } - square(x) { - return notYetImplemented('square'); - } - reciprocal(x) { - return notYetImplemented('reciprocal'); - } - relu(x) { - return notYetImplemented('relu'); - } - relu6(x) { - return notYetImplemented('relu6'); - } - prelu(x, a) { - return notYetImplemented('prelu'); - } - elu(x) { - return notYetImplemented('elu'); - } - eluDer(dy, y) { - return notYetImplemented('eluDer'); - } - selu(x) { - return notYetImplemented('selu'); - } - int(x) { - return notYetImplemented('int'); - } - clip(x, min, max) { - return notYetImplemented('clip'); - } - abs(x) { - return notYetImplemented('abs'); - } - complexAbs(x) { - return notYetImplemented('complexAbs'); - } - sigmoid(x) { - return notYetImplemented('sigmoid'); - } - softplus(x) { - return notYetImplemented('softplus'); - } - sin(x) { - return notYetImplemented('sin'); - } - cos(x) { - return notYetImplemented('cos'); - } - tan(x) { - return notYetImplemented('tan'); - } - asin(x) { - return notYetImplemented('asin'); - } - acos(x) { - return notYetImplemented('acos'); - } - atan(x) { - return notYetImplemented('atan'); - } - atan2(a, b) { - return notYetImplemented('atan2'); - } - sinh(x) { - return notYetImplemented('sinh'); - } - cosh(x) { - return notYetImplemented('cosh'); - } - tanh(x) { - return notYetImplemented('tanh'); - } - asinh(x) { - return notYetImplemented('asinh'); - } - acosh(x) { - return notYetImplemented('acosh'); - } - atanh(x) { - return notYetImplemented('atanh'); - } - erf(x) { - return notYetImplemented('erf'); - } - step(x, alpha) { - return notYetImplemented('step'); - } - fusedConv2d({ input, filter, convInfo, bias, activation, preluActivationWeights }) { - return notYetImplemented('fusedConv2d'); - } - conv2d(x, filter, convInfo) { - return notYetImplemented('conv2d'); - } - conv2dDerInput(dy, filter, convInfo) { - return notYetImplemented('conv2dDerInput'); - } - conv2dDerFilter(x, dY, convInfo) { - return notYetImplemented('conv2dDerFilter'); - } - fusedDepthwiseConv2D({ input, filter, convInfo, bias, activation, preluActivationWeights }) { - return notYetImplemented('fusedDepthwiseConv2D'); - } - depthwiseConv2D(input, filter, convInfo) { - return notYetImplemented('depthwiseConv2D'); - } - depthwiseConv2DDerInput(dy, filter, convInfo) { - return notYetImplemented('depthwiseConv2DDerInput'); - } - depthwiseConv2DDerFilter(x, dY, convInfo) { - return notYetImplemented('depthwiseConv2DDerFilter'); - } - conv3d(x, filter, convInfo) { - return notYetImplemented('conv3d'); - } - conv3dDerInput(dy, filter, convInfo) { - return notYetImplemented('conv3dDerInput'); - } - conv3dDerFilter(x, dY, convInfo) { - return notYetImplemented('conv3dDerFilter'); - } - maxPool(x, convInfo) { - return notYetImplemented('maxPool'); - } - maxPoolBackprop(dy, x, y, convInfo) { - return notYetImplemented('maxPoolBackprop'); - } - avgPool(x, convInfo) { - return notYetImplemented('avgPool'); - } - avgPoolBackprop(dy, x, convInfo) { - return notYetImplemented('avgPoolBackprop'); - } - avgPool3d(x, convInfo) { - return notYetImplemented('avgPool3d'); - } - avgPool3dBackprop(dy, x, convInfo) { - return notYetImplemented('avgPool3dBackprop'); - } - maxPool3d(x, convInfo) { - return notYetImplemented('maxPool3d'); - } - maxPool3dBackprop(dy, x, y, convInfo) { - return notYetImplemented('maxPool3dBackprop'); - } - reshape(x, shape) { - return notYetImplemented('reshape'); - } - cast(x, dtype) { - return notYetImplemented('cast'); - } - tile(x, reps) { - return notYetImplemented('tile'); - } - pad(x, paddings, constantValue) { - return notYetImplemented('pad'); - } - transpose(x, perm) { - return notYetImplemented('transpose'); - } - gather(x, indices, axis) { - return notYetImplemented('gather'); - } - gatherND(x, indices) { - return notYetImplemented('gatherND'); - } - scatterND(indices, updates, shape) { - return notYetImplemented('scatterND'); - } - batchToSpaceND(x, blockShape, crops) { - return notYetImplemented('batchToSpaceND'); - } - spaceToBatchND(x, blockShape, paddings) { - return notYetImplemented('spaceToBatchND'); - } - resizeBilinear(x, newHeight, newWidth, alignCorners) { - return notYetImplemented('resizeBilinear'); - } - resizeBilinearBackprop(dy, x, alignCorners) { - return notYetImplemented('resizeBilinearBackprop'); - } - resizeNearestNeighbor(x, newHEight, newWidth, alignCorners) { - return notYetImplemented('resizeNearestNeighbor'); - } - resizeNearestNeighborBackprop(dy, x, alignCorners) { - return notYetImplemented('resizeNearestNeighborBackprop'); - } - batchNorm(x, mean, variance, offset, scale, varianceEpsilon) { - return notYetImplemented('batchNorm'); - } - localResponseNormalization4D(x, radius, bias, alpha, beta) { - return notYetImplemented('localResponseNormalization4D'); - } - LRNGrad(dy, inputImage, outputImage, radius, bias, alpha, beta) { - return notYetImplemented('LRNGrad'); - } - multinomial(logits, normalized, numSamples, seed) { - return notYetImplemented('multinomial'); - } - oneHot(indices, depth, onValue, offValue) { - return notYetImplemented('oneHot'); - } - cumsum(x, axis, exclusive, reverse) { - return notYetImplemented('cumsum'); - } - nonMaxSuppression(boxes, scores, maxOutputSize, iouThreshold, scoreThreshold) { - return notYetImplemented('nonMaxSuppression'); - } - fft(x) { - return notYetImplemented('fft'); - } - ifft(x) { - return notYetImplemented('ifft'); - } - complex(real, imag) { - return notYetImplemented('complex'); - } - real(input) { - return notYetImplemented('real'); - } - imag(input) { - return notYetImplemented('imag'); - } - cropAndResize(image, boxes, boxIndex, cropSize, method, extrapolationValue) { - return notYetImplemented('cropAndResize'); - } - depthToSpace(x, blockSize, dataFormat) { - return notYetImplemented('depthToSpace'); - } - // Aligns with the "SplitV" kernel in TensorFlow. - split(value, sizeSplits, axis) { - return notYetImplemented('split'); - } - sparseToDense(sparseIndices, sparseValues, outputShape, defaultValue) { - return notYetImplemented('sparseToDense'); - } - diag(x) { - return notYetImplemented('diag'); - } - fill(shape, value, dtype) { - return notYetImplemented('fill'); - } - onesLike(x) { - return notYetImplemented('onesLike'); - } - zerosLike(x) { - return notYetImplemented('zerosLike'); - } - linspace(start, stop, num) { - return notYetImplemented('linspace'); - } - dispose() { - return notYetImplemented('dispose'); - } - } - function notYetImplemented(kernelName) { - throw new Error(`'${kernelName}' not yet implemented or not found in the registry. ` + - `This kernel may not be supported by the tfjs backend you have chosen`); - } - - /** - * @license - * Copyright 2017 Google LLC. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - // Expects flags from URL in the format ?tfjsflags=FLAG1:1,FLAG2:true. - const TENSORFLOWJS_FLAGS_PREFIX = 'tfjsflags'; - /** - * The environment contains evaluated flags as well as the registered platform. - * This is always used as a global singleton and can be retrieved with - * `tf.env()`. - * - * @doc {heading: 'Environment'} - */ - class Environment { - // tslint:disable-next-line: no-any - constructor(global) { - this.global = global; - this.flags = {}; - this.flagRegistry = {}; - this.urlFlags = {}; - this.populateURLFlags(); - } - setPlatform(platformName, platform) { - if (this.platform != null) { - console.warn(`Platform ${this.platformName} has already been set. ` + - `Overwriting the platform with ${platform}.`); - } - this.platformName = platformName; - this.platform = platform; - } - registerFlag(flagName, evaluationFn, setHook) { - this.flagRegistry[flagName] = { evaluationFn, setHook }; - // Override the flag value from the URL. This has to happen here because the - // environment is initialized before flags get registered. - if (this.urlFlags[flagName] != null) { - const flagValue = this.urlFlags[flagName]; - console.warn(`Setting feature override from URL ${flagName}: ${flagValue}.`); - this.set(flagName, flagValue); - } - } - async getAsync(flagName) { - if (flagName in this.flags) { - return this.flags[flagName]; - } - this.flags[flagName] = await this.evaluateFlag(flagName); - return this.flags[flagName]; - } - get(flagName) { - if (flagName in this.flags) { - return this.flags[flagName]; - } - const flagValue = this.evaluateFlag(flagName); - if (flagValue instanceof Promise) { - throw new Error(`Flag ${flagName} cannot be synchronously evaluated. ` + - `Please use getAsync() instead.`); - } - this.flags[flagName] = flagValue; - return this.flags[flagName]; - } - getNumber(flagName) { - return this.get(flagName); - } - getBool(flagName) { - return this.get(flagName); - } - getFlags() { - return this.flags; - } - // For backwards compatibility. - get features() { - return this.flags; - } - set(flagName, value) { - if (this.flagRegistry[flagName] == null) { - throw new Error(`Cannot set flag ${flagName} as it has not been registered.`); - } - this.flags[flagName] = value; - if (this.flagRegistry[flagName].setHook != null) { - this.flagRegistry[flagName].setHook(value); - } - } - evaluateFlag(flagName) { - if (this.flagRegistry[flagName] == null) { - throw new Error(`Cannot evaluate flag '${flagName}': no evaluation function found.`); - } - return this.flagRegistry[flagName].evaluationFn(); - } - setFlags(flags) { - this.flags = Object.assign({}, flags); - } - reset() { - this.flags = {}; - this.urlFlags = {}; - this.populateURLFlags(); - } - populateURLFlags() { - if (typeof this.global === 'undefined' || - typeof this.global.location === 'undefined' || - typeof this.global.location.search === 'undefined') { - return; - } - const urlParams = getQueryParams(this.global.location.search); - if (TENSORFLOWJS_FLAGS_PREFIX in urlParams) { - const keyValues = urlParams[TENSORFLOWJS_FLAGS_PREFIX].split(','); - keyValues.forEach(keyValue => { - const [key, value] = keyValue.split(':'); - this.urlFlags[key] = parseValue(key, value); - }); - } - } - } - function getQueryParams(queryString) { - const params = {}; - queryString.replace(/[?&]([^=?&]+)(?:=([^&]*))?/g, (s, ...t) => { - decodeParam(params, t[0], t[1]); - return t.join('='); - }); - return params; - } - function decodeParam(params, name, value) { - params[decodeURIComponent(name)] = decodeURIComponent(value || ''); - } - function parseValue(flagName, value) { - value = value.toLowerCase(); - if (value === 'true' || value === 'false') { - return value === 'true'; - } - else if (`${+value}` === value) { - return +value; - } - throw new Error(`Could not parse value flag value ${value} for flag ${flagName}.`); - } - /** - * Returns the current environment (a global singleton). - * - * The environment object contains the evaluated feature values as well as the - * active platform. - * - * @doc {heading: 'Environment'} - */ - function env() { - return exports.ENV; - } - exports.ENV = null; - function setEnvironmentGlobal(environment) { - exports.ENV = environment; - } - - /** - * @license - * Copyright 2020 Google LLC. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - // Note that the identifier globalNameSpace is scoped to this module, but will - // always resolve to the same global object regardless of how the module is - // resolved. - // tslint:disable-next-line:no-any - let globalNameSpace; - // tslint:disable-next-line:no-any - function getGlobalNamespace() { - if (globalNameSpace == null) { - // tslint:disable-next-line:no-any - let ns; - if (typeof (window) !== 'undefined') { - ns = window; - } - else if (typeof (global) !== 'undefined') { - ns = global; - } - else if (typeof (process) !== 'undefined') { - ns = process; - } - else if (typeof (self) !== 'undefined') { - ns = self; - } - else { - throw new Error('Could not find a global object'); - } - globalNameSpace = ns; - } - return globalNameSpace; - } - // tslint:disable-next-line:no-any - function getGlobalMap() { - const ns = getGlobalNamespace(); - if (ns._tfGlobals == null) { - ns._tfGlobals = new Map(); - } - return ns._tfGlobals; - } - /** - * Returns a globally accessible 'singleton' object. - * - * @param key the name of the object - * @param init a function to initialize to initialize this object - * the first time it is fetched. - */ - function getGlobal(key, init) { - const globalMap = getGlobalMap(); - if (globalMap.has(key)) { - return globalMap.get(key); - } - else { - const singleton = init(); - globalMap.set(key, singleton); - return globalMap.get(key); - } - } - - const Abs = 'Abs'; - const Acos = 'Acos'; - const Acosh = 'Acosh'; - const Add = 'Add'; - const AddN = 'AddN'; - const All = 'All'; - const Any = 'Any'; - const ArgMax = 'ArgMax'; - const ArgMin = 'ArgMin'; - const Asin = 'Asin'; - const Asinh = 'Asinh'; - const Atan = 'Atan'; - const Atanh = 'Atanh'; - const Atan2 = 'Atan2'; - const AvgPool = 'AvgPool'; - const AvgPoolBackprop = 'AvgPoolBackprop'; - const AvgPool3D = 'AvgPool3D'; - const AvgPool3DBackprop = 'AvgPool3DBackprop'; - const BatchMatMul = 'BatchMatMul'; - const BatchToSpaceND = 'BatchToSpaceND'; - const BroadcastTo = 'BroadcastTo'; - const Cast = 'Cast'; - const Ceil = 'Ceil'; - const ClipByValue = 'ClipByValue'; - const Complex = 'Complex'; - const Concat = 'Concat'; - const Conv2D = 'Conv2D'; - const Conv2DBackpropFilter = 'Conv2DBackpropFilter'; - const Conv2DBackpropInput = 'Conv2DBackpropInput'; - const Conv3D = 'Conv3D'; - const Conv3DBackpropFilterV2 = 'Conv3DBackpropFilterV2'; - const Conv3DBackpropInputV2 = 'Conv3DBackpropInputV2'; - const Cos = 'Cos'; - const Cosh = 'Cosh'; - const Cumsum = 'Cumsum'; - const CropAndResize = 'CropAndResize'; - const DepthToSpace = 'DepthToSpace'; - const DepthwiseConv2dNative = 'DepthwiseConv2dNative'; - const DepthwiseConv2dNativeBackpropFilter = 'DepthwiseConv2dNativeBackpropFilter'; - const DepthwiseConv2dNativeBackpropInput = 'DepthwiseConv2dNativeBackpropInput'; - const Diag = 'Diag'; - const Dilation2D = 'Dilation2D'; - const Dilation2DBackpropInput = 'Dilation2DBackpropInput'; - const Dilation2DBackpropFilter = 'Dilation2DBackpropFilter'; - const Div = 'Div'; - const Elu = 'Elu'; - const EluGrad = 'EluGrad'; - const Erf = 'Erf'; - const Equal = 'Equal'; - const Exp = 'Exp'; - const Expm1 = 'Expm1'; - const FFT = 'FFT'; - const Fill = 'Fill'; - const FlipLeftRight = 'FlipLeftRight'; - const Floor = 'Floor'; - const FloorDiv = 'FloorDiv'; - const FusedBatchNorm = 'FusedBatchNorm'; - const GatherV2 = 'GatherV2'; - const GatherNd = 'GatherNd'; - const Greater = 'Greater'; - const GreaterEqual = 'GreaterEqual'; - const Identity = 'Identity'; - const IFFT = 'IFFT'; - const Imag = 'Imag'; - const IsFinite = 'IsFinite'; - const IsInf = 'IsInf'; - const IsNan = 'IsNan'; - const Less = 'Less'; - const LessEqual = 'LessEqual'; - const LinSpace = 'LinSpace'; - const Log = 'Log'; - const Log1p = 'Log1p'; - const LogicalAnd = 'LogicalAnd'; - const LogicalNot = 'LogicalNot'; - const LogicalOr = 'LogicalOr'; - const LogSoftmax = 'LogSoftmax'; - const LRN = 'LRN'; - const LRNBackprop = 'LRNBackprop'; - const Max = 'Max'; - const Maximum = 'Maximum'; - const MaxPool = 'MaxPool'; - const MaxPoolBackprop = 'MaxPoolBackprop'; - const MaxPool3D = 'MaxPool3D'; - const MaxPool3DBackprop = 'MaxPool3DBackprop'; - const MaxPoolWithArgmax = 'MaxPoolWithArgmax'; - const Mean = 'Mean'; - const Min = 'Min'; - const Minimum = 'Minimum'; - const MirrorPad = 'MirrorPad'; - const Mod = 'Mod'; - const Multiply = 'Multiply'; - const Negate = 'Negate'; - const NotEqual = 'NotEqual'; - const NonMaxSuppressionV3 = 'NonMaxSuppressionV3'; - const NonMaxSuppressionV4 = 'NonMaxSuppressionV4'; - const NonMaxSuppressionV5 = 'NonMaxSuppressionV5'; - const OnesLike = 'OnesLike'; - const OneHot = 'OneHot'; - const PadV2 = 'PadV2'; - const Pool = 'Pool'; - const Pow = 'Pow'; - const Prelu = 'Prelu'; - const Prod = 'Prod'; - const Range = 'Range'; - const Real = 'Real'; - const Reciprocal = 'Reciprocal'; - const Relu = 'Relu'; - const Reshape = 'Reshape'; - const ResizeNearestNeighbor = 'ResizeNearestNeighbor'; - const ResizeNearestNeighborGrad = 'ResizeNearestNeighborGrad'; - const ResizeBilinear = 'ResizeBilinear'; - const ResizeBilinearGrad = 'ResizeBilinearGrad'; - const Relu6 = 'Relu6'; - const Reverse = 'Reverse'; - const Round = 'Round'; - const Rsqrt = 'Rsqrt'; - const ScatterNd = 'ScatterNd'; - const SelectV2 = 'SelectV2'; - const Selu = 'Selu'; - const Slice = 'Slice'; - const Sin = 'Sin'; - const Sinh = 'Sinh'; - const Sign = 'Sign'; - const Sigmoid = 'Sigmoid'; - const Softplus = 'Softplus'; - const Sqrt = 'Sqrt'; - const Sum = 'Sum'; - const SpaceToBatchND = 'SpaceToBatchND'; - const SplitV = 'SplitV'; - const Softmax = 'Softmax'; - const SquaredDifference = 'SquaredDifference'; - const Square = 'Square'; - const Sub = 'Sub'; - const SparseToDense = 'SparseToDense'; - const StridedSlice = 'StridedSlice'; - const Tan = 'Tan'; - const Tanh = 'Tanh'; - const Tile = 'Tile'; - const TopK = 'TopK'; - const Transpose = 'Transpose'; - const Unique = 'Unique'; - const Unpack = 'Unpack'; - const UnsortedSegmentSum = 'UnsortedSegmentSum'; - const ZerosLike = 'ZerosLike'; - /** - * TensorFlow.js-only kernels - */ - const Step = 'Step'; - const FromPixels = 'FromPixels'; - const RotateWithOffset = 'RotateWithOffset'; - const _FusedMatMul = '_FusedMatMul'; - const FusedConv2D = 'FusedConv2D'; - const FusedDepthwiseConv2D = 'FusedDepthwiseConv2D'; - - /** - * @license - * Copyright 2019 Google LLC. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - const kernelRegistry = getGlobal('kernelRegistry', () => new Map()); - const gradRegistry = getGlobal('gradRegistry', () => new Map()); - /** - * Returns the kernel function (code) associated with the provided names. - * - * @param kernelName The official name of the kernel. - * @param backendName The official name of the backend. - */ - function getKernel(kernelName, backendName) { - const key = makeKey(kernelName, backendName); - return kernelRegistry.get(key); - } - /** - * Returns the registered gradient info associated with the provided kernel. - * @param kernelName The official TF kernel name. - */ - function getGradient(kernelName) { - return gradRegistry.get(kernelName); - } - function getKernelsForBackend(backendName) { - const it = kernelRegistry.entries(); - const result = []; - while (true) { - const { done, value } = it.next(); - if (done) { - break; - } - const [key, config] = value; - const [backend,] = key.split('_'); - if (backend === backendName) { - result.push(config); - } - } - return result; - } - /** - * Registers the function (forward pass) for the kernel in a global registry. - * - * @param config A config object with the following properties: - * - `kernelName` The official name of the kernel. - * - `backendName` The official name of the backend. - * - `kernelFunc` The function to run during the forward pass of the kernel. - * - `setupFunc` Optional. Gets called once, after the backend initializes. - * - `disposeFunc` Optional. Gets called once, right before the backend is - * disposed. - */ - function registerKernel(config) { - const { kernelName, backendName } = config; - const key = makeKey(kernelName, backendName); - if (kernelRegistry.has(key)) { - console.warn(`The kernel '${kernelName}' for backend ` + - `'${backendName}' is already registered`); - } - kernelRegistry.set(key, config); - } - /** - * Registers a gradient function for a given kernel in the global registry, - * to be used during the back-propagation of that kernel. - * - * @param config An object with the following properties: - * - `kernelName` The name of the kernel that the gradient function is for. - * - `gradFunc` The function to run during back-propagation. - */ - function registerGradient(config) { - const { kernelName } = config; - if (gradRegistry.has(kernelName)) { - // TODO (yassogba) after 3.0 assess whether we need to keep this gated - // to debug mode. - if (env().getBool('DEBUG')) { - console.warn(`Overriding the gradient for '${kernelName}'`); - } - } - gradRegistry.set(kernelName, config); - } - /** - * Removes the kernel function from the registry. - * - * @param kernelName The official name of the kernel. - * @param backendName The official name of the backend. - * - */ - function unregisterKernel(kernelName, backendName) { - const key = makeKey(kernelName, backendName); - if (!kernelRegistry.has(key)) { - throw new Error(`The kernel '${kernelName}' for backend ` + - `'${backendName}' is not registered`); - } - kernelRegistry.delete(key); - } - /** Removes the registered gradient from the global registry. */ - function unregisterGradient(kernelName) { - if (!gradRegistry.has(kernelName)) { - throw new Error(`The gradient '${kernelName}' for backend is not registered`); - } - gradRegistry.delete(kernelName); - } - /** - * Finds kernels that have already been registered to a backend and re-registers - * them for a new backend. Useful for registering custom backends. - * @param registeredBackendName Already registered backend. - * @param newBackendName New backend. - */ - function copyRegisteredKernels(registeredBackendName, newBackendName) { - const kernels = getKernelsForBackend(registeredBackendName); - kernels.forEach(kernelConfig => { - const newKernelConfig = Object.assign({}, kernelConfig, { backendName: newBackendName }); - registerKernel(newKernelConfig); - }); - } - function makeKey(kernelName, backendName) { - return `${backendName}_${kernelName}`; - } - - /** - * @license - * Copyright 2017 Google LLC. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - /** - * Shuffles the array in-place using Fisher-Yates algorithm. - * - * ```js - * const a = [1, 2, 3, 4, 5]; - * tf.util.shuffle(a); - * console.log(a); - * ``` - * - * @param array The array to shuffle in-place. - * - * @doc {heading: 'Util', namespace: 'util'} - */ - // tslint:disable-next-line:no-any - function shuffle(array) { - let counter = array.length; - let temp = 0; - let index = 0; - // While there are elements in the array - while (counter > 0) { - // Pick a random index - index = (Math.random() * counter) | 0; - // Decrease counter by 1 - counter--; - // And swap the last element with it - temp = array[counter]; - array[counter] = array[index]; - array[index] = temp; - } - } - /** Clamps a value to a specified range. */ - function clamp(min, x, max) { - return Math.max(min, Math.min(x, max)); - } - function nearestLargerEven(val) { - return val % 2 === 0 ? val : val + 1; - } - function sum(arr) { - let sum = 0; - for (let i = 0; i < arr.length; i++) { - sum += arr[i]; - } - return sum; - } - /** - * Returns a sample from a uniform [a, b) distribution. - * - * @param a The minimum support (inclusive). - * @param b The maximum support (exclusive). - * @return A pseudorandom number on the half-open interval [a,b). - */ - function randUniform(a, b) { - const r = Math.random(); - return (b * r) + (1 - r) * a; - } - /** Returns the squared Euclidean distance between two vectors. */ - function distSquared(a, b) { - let result = 0; - for (let i = 0; i < a.length; i++) { - const diff = Number(a[i]) - Number(b[i]); - result += diff * diff; - } - return result; - } - /** - * Asserts that the expression is true. Otherwise throws an error with the - * provided message. - * - * ```js - * const x = 2; - * tf.util.assert(x === 2, 'x is not 2'); - * ``` - * - * @param expr The expression to assert (as a boolean). - * @param msg A function that returns the message to report when throwing an - * error. We use a function for performance reasons. - * - * @doc {heading: 'Util', namespace: 'util'} - */ - function assert(expr, msg) { - if (!expr) { - throw new Error(typeof msg === 'string' ? msg : msg()); - } - } - function assertShapesMatch(shapeA, shapeB, errorMessagePrefix = '') { - assert(arraysEqual(shapeA, shapeB), () => errorMessagePrefix + ` Shapes ${shapeA} and ${shapeB} must match`); - } - function assertNonNull(a) { - assert(a != null, () => `The input to the tensor constructor must be a non-null value.`); - } - // NOTE: We explicitly type out what T extends instead of any so that - // util.flatten on a nested array of number doesn't try to infer T as a - // number[][], causing us to explicitly type util.flatten(). - /** - * Flattens an arbitrarily nested array. - * - * ```js - * const a = [[1, 2], [3, 4], [5, [6, [7]]]]; - * const flat = tf.util.flatten(a); - * console.log(flat); - * ``` - * - * @param arr The nested array to flatten. - * @param result The destination array which holds the elements. - * @param skipTypedArray If true, avoids flattening the typed arrays. Defaults - * to false. - * - * @doc {heading: 'Util', namespace: 'util'} - */ - function flatten(arr, result = [], skipTypedArray = false) { - if (result == null) { - result = []; - } - if (Array.isArray(arr) || isTypedArray(arr) && !skipTypedArray) { - for (let i = 0; i < arr.length; ++i) { - flatten(arr[i], result, skipTypedArray); - } - } - else { - result.push(arr); - } - return result; - } - /** - * Returns the size (number of elements) of the tensor given its shape. - * - * ```js - * const shape = [3, 4, 2]; - * const size = tf.util.sizeFromShape(shape); - * console.log(size); - * ``` - * - * @doc {heading: 'Util', namespace: 'util'} - */ - function sizeFromShape(shape) { - if (shape.length === 0) { - // Scalar. - return 1; - } - let size = shape[0]; - for (let i = 1; i < shape.length; i++) { - size *= shape[i]; - } - return size; - } - function isScalarShape(shape) { - return shape.length === 0; - } - function arraysEqual(n1, n2) { - if (n1 === n2) { - return true; - } - if (n1 == null || n2 == null) { - return false; - } - if (n1.length !== n2.length) { - return false; - } - for (let i = 0; i < n1.length; i++) { - if (n1[i] !== n2[i]) { - return false; - } - } - return true; - } - function isInt(a) { - return a % 1 === 0; - } - function tanh(x) { - // tslint:disable-next-line:no-any - if (Math.tanh != null) { - // tslint:disable-next-line:no-any - return Math.tanh(x); - } - if (x === Infinity) { - return 1; - } - else if (x === -Infinity) { - return -1; - } - else { - const e2x = Math.exp(2 * x); - return (e2x - 1) / (e2x + 1); - } - } - function sizeToSquarishShape(size) { - const width = Math.ceil(Math.sqrt(size)); - return [width, Math.ceil(size / width)]; - } - /** - * Creates a new array with randomized indicies to a given quantity. - * - * ```js - * const randomTen = tf.util.createShuffledIndices(10); - * console.log(randomTen); - * ``` - * - * @param number Quantity of how many shuffled indicies to create. - * - * @doc {heading: 'Util', namespace: 'util'} - */ - function createShuffledIndices(n) { - const shuffledIndices = new Uint32Array(n); - for (let i = 0; i < n; ++i) { - shuffledIndices[i] = i; - } - shuffle(shuffledIndices); - return shuffledIndices; - } - function rightPad(a, size) { - if (size <= a.length) { - return a; - } - return a + ' '.repeat(size - a.length); - } - function repeatedTry(checkFn, delayFn = (counter) => 0, maxCounter) { - return new Promise((resolve, reject) => { - let tryCount = 0; - const tryFn = () => { - if (checkFn()) { - resolve(); - return; - } - tryCount++; - const nextBackoff = delayFn(tryCount); - if (maxCounter != null && tryCount >= maxCounter) { - reject(); - return; - } - setTimeout(tryFn, nextBackoff); - }; - tryFn(); - }); - } - /** - * Given the full size of the array and a shape that may contain -1 as the - * implicit dimension, returns the inferred shape where -1 is replaced. - * E.g. For shape=[2, -1, 3] and size=24, it will return [2, 4, 3]. - * - * @param shape The shape, which may contain -1 in some dimension. - * @param size The full size (number of elements) of the array. - * @return The inferred shape where -1 is replaced with the inferred size. - */ - function inferFromImplicitShape(shape, size) { - let shapeProd = 1; - let implicitIdx = -1; - for (let i = 0; i < shape.length; ++i) { - if (shape[i] >= 0) { - shapeProd *= shape[i]; - } - else if (shape[i] === -1) { - if (implicitIdx !== -1) { - throw Error(`Shapes can only have 1 implicit size. ` + - `Found -1 at dim ${implicitIdx} and dim ${i}`); - } - implicitIdx = i; - } - else if (shape[i] < 0) { - throw Error(`Shapes can not be < 0. Found ${shape[i]} at dim ${i}`); - } - } - if (implicitIdx === -1) { - if (size > 0 && size !== shapeProd) { - throw Error(`Size(${size}) must match the product of shape ${shape}`); - } - return shape; - } - if (shapeProd === 0) { - throw Error(`Cannot infer the missing size in [${shape}] when ` + - `there are 0 elements`); - } - if (size % shapeProd !== 0) { - throw Error(`The implicit shape can't be a fractional number. ` + - `Got ${size} / ${shapeProd}`); - } - const newShape = shape.slice(); - newShape[implicitIdx] = size / shapeProd; - return newShape; - } - function parseAxisParam(axis, shape) { - const rank = shape.length; - // Normalize input - axis = axis == null ? shape.map((s, i) => i) : [].concat(axis); - // Check for valid range - assert(axis.every(ax => ax >= -rank && ax < rank), () => `All values in axis param must be in range [-${rank}, ${rank}) but ` + - `got axis ${axis}`); - // Check for only integers - assert(axis.every(ax => isInt(ax)), () => `All values in axis param must be integers but ` + - `got axis ${axis}`); - // Handle negative axis. - return axis.map(a => a < 0 ? rank + a : a); - } - /** Reduces the shape by removing all dimensions of shape 1. */ - function squeezeShape(shape, axis) { - const newShape = []; - const keptDims = []; - const isEmptyArray = axis != null && Array.isArray(axis) && axis.length === 0; - const axes = (axis == null || isEmptyArray) ? - null : - parseAxisParam(axis, shape).sort(); - let j = 0; - for (let i = 0; i < shape.length; ++i) { - if (axes != null) { - if (axes[j] === i && shape[i] !== 1) { - throw new Error(`Can't squeeze axis ${i} since its dim '${shape[i]}' is not 1`); - } - if ((axes[j] == null || axes[j] > i) && shape[i] === 1) { - newShape.push(shape[i]); - keptDims.push(i); - } - if (axes[j] <= i) { - j++; - } - } - if (shape[i] !== 1) { - newShape.push(shape[i]); - keptDims.push(i); - } - } - return { newShape, keptDims }; - } - function getTypedArrayFromDType(dtype, size) { - let values = null; - if (dtype == null || dtype === 'float32') { - values = new Float32Array(size); - } - else if (dtype === 'int32') { - values = new Int32Array(size); - } - else if (dtype === 'bool') { - values = new Uint8Array(size); - } - else { - throw new Error(`Unknown data type ${dtype}`); - } - return values; - } - function getArrayFromDType(dtype, size) { - let values = null; - if (dtype == null || dtype === 'float32') { - values = new Float32Array(size); - } - else if (dtype === 'int32') { - values = new Int32Array(size); - } - else if (dtype === 'bool') { - values = new Uint8Array(size); - } - else if (dtype === 'string') { - values = new Array(size); - } - else { - throw new Error(`Unknown data type ${dtype}`); - } - return values; - } - function checkConversionForErrors(vals, dtype) { - for (let i = 0; i < vals.length; i++) { - const num = vals[i]; - if (isNaN(num) || !isFinite(num)) { - throw Error(`A tensor of type ${dtype} being uploaded contains ${num}.`); - } - } - } - /** Returns true if the dtype is valid. */ - function isValidDtype(dtype) { - return dtype === 'bool' || dtype === 'complex64' || dtype === 'float32' || - dtype === 'int32' || dtype === 'string'; - } - /** - * Returns true if the new type can't encode the old type without loss of - * precision. - */ - function hasEncodingLoss(oldType, newType) { - if (newType === 'complex64') { - return false; - } - if (newType === 'float32' && oldType !== 'complex64') { - return false; - } - if (newType === 'int32' && oldType !== 'float32' && oldType !== 'complex64') { - return false; - } - if (newType === 'bool' && oldType === 'bool') { - return false; - } - return true; - } - function isTypedArray(a) { - return a instanceof Float32Array || a instanceof Int32Array || - a instanceof Uint8Array; - } - function bytesPerElement(dtype) { - if (dtype === 'float32' || dtype === 'int32') { - return 4; - } - else if (dtype === 'complex64') { - return 8; - } - else if (dtype === 'bool') { - return 1; - } - else { - throw new Error(`Unknown dtype ${dtype}`); - } - } - /** - * Returns the approximate number of bytes allocated in the string array - 2 - * bytes per character. Computing the exact bytes for a native string in JS is - * not possible since it depends on the encoding of the html page that serves - * the website. - */ - function bytesFromStringArray(arr) { - if (arr == null) { - return 0; - } - let bytes = 0; - arr.forEach(x => bytes += x.length); - return bytes; - } - /** Returns true if the value is a string. */ - function isString(value) { - return typeof value === 'string' || value instanceof String; - } - function isBoolean(value) { - return typeof value === 'boolean'; - } - function isNumber(value) { - return typeof value === 'number'; - } - function inferDtype(values) { - if (Array.isArray(values)) { - return inferDtype(values[0]); - } - if (values instanceof Float32Array) { - return 'float32'; - } - else if (values instanceof Int32Array || values instanceof Uint8Array) { - return 'int32'; - } - else if (isNumber(values)) { - return 'float32'; - } - else if (isString(values)) { - return 'string'; - } - else if (isBoolean(values)) { - return 'bool'; - } - return 'float32'; - } - function isFunction(f) { - return !!(f && f.constructor && f.call && f.apply); - } - function nearestDivisor(size, start) { - for (let i = start; i < size; ++i) { - if (size % i === 0) { - return i; - } - } - return size; - } - function computeStrides(shape) { - const rank = shape.length; - if (rank < 2) { - return []; - } - // Last dimension has implicit stride of 1, thus having D-1 (instead of D) - // strides. - const strides = new Array(rank - 1); - strides[rank - 2] = shape[rank - 1]; - for (let i = rank - 3; i >= 0; --i) { - strides[i] = strides[i + 1] * shape[i + 1]; - } - return strides; - } - /** - * Create typed array for scalar value. Used for storing in `DataStorage`. - */ - function createScalarValue(value, dtype) { - if (dtype === 'string') { - return encodeString(value); - } - return toTypedArray([value], dtype); - } - function toTypedArray(a, dtype) { - if (dtype === 'string') { - throw new Error('Cannot convert a string[] to a TypedArray'); - } - if (Array.isArray(a)) { - a = flatten(a); - } - if (env().getBool('DEBUG')) { - checkConversionForErrors(a, dtype); - } - if (noConversionNeeded(a, dtype)) { - return a; - } - if (dtype == null || dtype === 'float32' || dtype === 'complex64') { - return new Float32Array(a); - } - else if (dtype === 'int32') { - return new Int32Array(a); - } - else if (dtype === 'bool') { - const bool = new Uint8Array(a.length); - for (let i = 0; i < bool.length; ++i) { - if (Math.round(a[i]) !== 0) { - bool[i] = 1; - } - } - return bool; - } - else { - throw new Error(`Unknown data type ${dtype}`); - } - } - function createNestedArray(offset, shape, a) { - const ret = new Array(); - if (shape.length === 1) { - const d = shape[0]; - for (let i = 0; i < d; i++) { - ret[i] = a[offset + i]; - } - } - else { - const d = shape[0]; - const rest = shape.slice(1); - const len = rest.reduce((acc, c) => acc * c); - for (let i = 0; i < d; i++) { - ret[i] = createNestedArray(offset + i * len, rest, a); - } - } - return ret; - } - // Provide a nested array of TypedArray in given shape. - function toNestedArray(shape, a) { - if (shape.length === 0) { - // Scalar type should return a single number. - return a[0]; - } - const size = shape.reduce((acc, c) => acc * c); - if (size === 0) { - // A tensor with shape zero should be turned into empty list. - return []; - } - if (size !== a.length) { - throw new Error(`[${shape}] does not match the input size ${a.length}.`); - } - return createNestedArray(0, shape, a); - } - function noConversionNeeded(a, dtype) { - return (a instanceof Float32Array && dtype === 'float32') || - (a instanceof Int32Array && dtype === 'int32') || - (a instanceof Uint8Array && dtype === 'bool'); - } - function makeOnesTypedArray(size, dtype) { - const array = makeZerosTypedArray(size, dtype); - for (let i = 0; i < array.length; i++) { - array[i] = 1; - } - return array; - } - function makeZerosTypedArray(size, dtype) { - if (dtype == null || dtype === 'float32' || dtype === 'complex64') { - return new Float32Array(size); - } - else if (dtype === 'int32') { - return new Int32Array(size); - } - else if (dtype === 'bool') { - return new Uint8Array(size); - } - else { - throw new Error(`Unknown data type ${dtype}`); - } - } - /** - * Make nested `TypedArray` filled with zeros. - * @param shape The shape information for the nested array. - * @param dtype dtype of the array element. - */ - function makeZerosNestedTypedArray(shape, dtype) { - const size = shape.reduce((prev, curr) => prev * curr, 1); - if (dtype == null || dtype === 'float32') { - return toNestedArray(shape, new Float32Array(size)); - } - else if (dtype === 'int32') { - return toNestedArray(shape, new Int32Array(size)); - } - else if (dtype === 'bool') { - return toNestedArray(shape, new Uint8Array(size)); - } - else { - throw new Error(`Unknown data type ${dtype}`); - } - } - /** - * Returns the current high-resolution time in milliseconds relative to an - * arbitrary time in the past. It works across different platforms (node.js, - * browsers). - * - * ```js - * console.log(tf.util.now()); - * ``` - * - * @doc {heading: 'Util', namespace: 'util'} - */ - function now() { - return env().platform.now(); - } - function assertNonNegativeIntegerDimensions(shape) { - shape.forEach(dimSize => { - assert(Number.isInteger(dimSize) && dimSize >= 0, () => `Tensor must have a shape comprised of positive integers but got ` + - `shape [${shape}].`); - }); - } - /** - * Returns a platform-specific implementation of - * [`fetch`](https://developer.mozilla.org/en-US/docs/Web/API/Fetch_API). - * - * If `fetch` is defined on the global object (`window`, `process`, etc.), - * `tf.util.fetch` returns that function. - * - * If not, `tf.util.fetch` returns a platform-specific solution. - * - * ```js - * const resource = await tf.util.fetch('https://unpkg.com/@tensorflow/tfjs'); - * // handle response - * ``` - * - * @doc {heading: 'Util'} - */ - function fetch$1(path, requestInits) { - return env().platform.fetch(path, requestInits); - } - /** - * Encodes the provided string into bytes using the provided encoding scheme. - * - * @param s The string to encode. - * @param encoding The encoding scheme. Defaults to utf-8. - * - * @doc {heading: 'Util'} - */ - function encodeString(s, encoding = 'utf-8') { - encoding = encoding || 'utf-8'; - return env().platform.encode(s, encoding); - } - /** - * Decodes the provided bytes into a string using the provided encoding scheme. - * @param bytes The bytes to decode. - * - * @param encoding The encoding scheme. Defaults to utf-8. - * - * @doc {heading: 'Util'} - */ - function decodeString(bytes, encoding = 'utf-8') { - encoding = encoding || 'utf-8'; - return env().platform.decode(bytes, encoding); - } - /** - * Computes flat index for a given location (multidimentionsal index) in a - * Tensor/multidimensional array. - * - * @param locs Location in the tensor. - * @param rank Rank of the tensor. - * @param strides Tensor strides. - */ - function locToIndex(locs, rank, strides) { - if (rank === 0) { - return 0; - } - else if (rank === 1) { - return locs[0]; - } - let index = locs[locs.length - 1]; - for (let i = 0; i < locs.length - 1; ++i) { - index += strides[i] * locs[i]; - } - return index; - } - /** - * Computes the location (multidimensional index) in a tensor/multidimentional - * array for a given flat index. - * - * @param index Index in flat array. - * @param rank Rank of tensor. - * @param strides Strides of tensor. - */ - function indexToLoc(index, rank, strides) { - if (rank === 0) { - return []; - } - else if (rank === 1) { - return [index]; - } - const locs = new Array(rank); - for (let i = 0; i < locs.length - 1; ++i) { - locs[i] = Math.floor(index / strides[i]); - index -= locs[i] * strides[i]; - } - locs[locs.length - 1] = index; - return locs; - } - - var util = /*#__PURE__*/Object.freeze({ - __proto__: null, - shuffle: shuffle, - clamp: clamp, - nearestLargerEven: nearestLargerEven, - sum: sum, - randUniform: randUniform, - distSquared: distSquared, - assert: assert, - assertShapesMatch: assertShapesMatch, - assertNonNull: assertNonNull, - flatten: flatten, - sizeFromShape: sizeFromShape, - isScalarShape: isScalarShape, - arraysEqual: arraysEqual, - isInt: isInt, - tanh: tanh, - sizeToSquarishShape: sizeToSquarishShape, - createShuffledIndices: createShuffledIndices, - rightPad: rightPad, - repeatedTry: repeatedTry, - inferFromImplicitShape: inferFromImplicitShape, - parseAxisParam: parseAxisParam, - squeezeShape: squeezeShape, - getTypedArrayFromDType: getTypedArrayFromDType, - getArrayFromDType: getArrayFromDType, - checkConversionForErrors: checkConversionForErrors, - isValidDtype: isValidDtype, - hasEncodingLoss: hasEncodingLoss, - isTypedArray: isTypedArray, - bytesPerElement: bytesPerElement, - bytesFromStringArray: bytesFromStringArray, - isString: isString, - isBoolean: isBoolean, - isNumber: isNumber, - inferDtype: inferDtype, - isFunction: isFunction, - nearestDivisor: nearestDivisor, - computeStrides: computeStrides, - createScalarValue: createScalarValue, - toTypedArray: toTypedArray, - toNestedArray: toNestedArray, - makeOnesTypedArray: makeOnesTypedArray, - makeZerosTypedArray: makeZerosTypedArray, - makeZerosNestedTypedArray: makeZerosNestedTypedArray, - now: now, - assertNonNegativeIntegerDimensions: assertNonNegativeIntegerDimensions, - fetch: fetch$1, - encodeString: encodeString, - decodeString: decodeString, - locToIndex: locToIndex, - indexToLoc: indexToLoc - }); - - /** - * @license - * Copyright 2018 Google LLC. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - class Profiler { - constructor(backendTimer, logger) { - this.backendTimer = backendTimer; - this.logger = logger; - if (logger == null) { - this.logger = new Logger(); - } - } - profileKernel(kernelName, inputs, f) { - let outputs; - const holdResultWrapperFn = () => { - outputs = f(); - }; - const timer = this.backendTimer.time(holdResultWrapperFn); - for (let i = 0; i < outputs.length; i++) { - const output = outputs[i]; - // Dangling promise here because we don't want to propagate up - // asynchronicity. - output.data().then(tensorVals => { - checkComputationForErrors(tensorVals, output.dtype, kernelName); - }); - } - const kernelProfile = { - kernelName, - outputs, - inputs, - timeMs: timer.then(timing => timing.kernelMs), - extraInfo: timer.then(timing => timing.getExtraProfileInfo != null ? - timing.getExtraProfileInfo() : - '') - }; - return kernelProfile; - } - logKernelProfile(kernelProfile) { - const { kernelName, outputs, timeMs, inputs, extraInfo } = kernelProfile; - outputs.forEach(result => { - Promise.all([result.data(), timeMs, extraInfo]).then(valueContainer => { - this.logger.logKernelProfile(kernelName, result, valueContainer[0], valueContainer[1], inputs, valueContainer[2]); - }); - }); - } - } - function checkComputationForErrors(vals, dtype, kernelName) { - if (dtype !== 'float32') { - // Only floating point computations will generate NaN values - return false; - } - for (let i = 0; i < vals.length; i++) { - const num = vals[i]; - if (isNaN(num) || !isFinite(num)) { - // Throwing custom exception so behavior is testable. - console.warn(`Found ${num} in the result of '${kernelName}'`); - return true; - } - } - return false; - } - class Logger { - logKernelProfile(name, result, vals, timeMs, inputs, extraInfo) { - const time = typeof timeMs === 'number' ? rightPad(`${timeMs}ms`, 9) : - timeMs['error']; - const paddedName = rightPad(name, 25); - const rank = result.rank; - const size = result.size; - const shape = rightPad(result.shape.toString(), 14); - let inputShapesDescription = ''; - for (const name in inputs) { - const input = inputs[name]; - if (input != null) { - // The input might be a non-tensor (e.g HTMLImageElement), in which case - // we claim the output shape as input shape. - const inputShape = input.shape || result.shape; - const inputRank = inputShape.length; - inputShapesDescription += - `${name}: ${inputRank}D ${inputRank > 0 ? inputShape : ''} `; - } - } - console.log(`%c${paddedName}\t%c${time}\t%c${rank}D ${shape}\t%c${size}\t%c${inputShapesDescription}\t%c${extraInfo}`, 'font-weight:bold', 'color:red', 'color:blue', 'color: orange', 'color: green', 'color: steelblue'); - } - } - - /** - * @license - * Copyright 2017 Google LLC. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - /** - * Computes a list of TapeNodes that connect x to y, filtering everything else - * out and preserving the order of the original tape elements. - * - * @param tape The tape elements to filter. - * @param xs The input Tensors. - * @param y The output Tensor. - */ - function getFilteredNodesXToY(tape, xs, y) { - // Forward pass to compute all the nodes and Tensors that are transitively a - // function of x. - const tensorsFromX = {}; - const nodesFromX = {}; - for (let i = 0; i < xs.length; i++) { - tensorsFromX[xs[i].id] = true; - } - for (let i = 0; i < tape.length; i++) { - const node = tape[i]; - const nodeInputs = node.inputs; - for (const inputName in nodeInputs) { - const input = nodeInputs[inputName]; - let anyInputFromX = false; - for (let j = 0; j < xs.length; j++) { - if (tensorsFromX[input.id]) { - node.outputs.forEach(output => tensorsFromX[output.id] = true); - anyInputFromX = true; - nodesFromX[node.id] = true; - break; - } - } - if (anyInputFromX) { - break; - } - } - } - // Backward pass to find all of the nodes and Tensors that lead to y. - const tensorsLeadToY = {}; - tensorsLeadToY[y.id] = true; - const nodesToY = {}; - for (let i = tape.length - 1; i >= 0; i--) { - const node = tape[i]; - const nodeInputs = node.inputs; - // If any of the outputs lead to y, mark all of the inputs as leading to y. - for (let j = 0; j < node.outputs.length; j++) { - if (tensorsLeadToY[node.outputs[j].id]) { - for (const inputName in nodeInputs) { - tensorsLeadToY[nodeInputs[inputName].id] = true; - nodesToY[node.id] = true; - } - break; - } - } - } - // Return the paths that come from x and lead to y. - const filteredTape = []; - for (let i = 0; i < tape.length; i++) { - const node = tape[i]; - if (nodesFromX[node.id] && nodesToY[node.id]) { - // Prune the inputs from the node that aren't a function of x. - const prunedInputs = {}; - for (const inputName in node.inputs) { - const nodeInput = node.inputs[inputName]; - if (tensorsFromX[nodeInput.id]) { - prunedInputs[inputName] = nodeInput; - } - } - // Copy the node and overwrite inputsAndArgs to the pruned version. - const prunedNode = Object.assign({}, node); - prunedNode.inputs = prunedInputs; - prunedNode.outputs = node.outputs; - filteredTape.push(prunedNode); - } - } - return filteredTape; - } - /** - * Backpropagate gradients through the filtered TapeNodes. - * - * @param tensorAccumulatedGradientMap A map of Tensor to its gradient. This map - * is mutated by this method. - * @param filteredTape The filtered TapeNodes to backprop through. - */ - function backpropagateGradients(tensorAccumulatedGradientMap, filteredTape, tidy, add) { - // Walk the tape backward and keep a map of Tensor to its gradient. - for (let i = filteredTape.length - 1; i >= 0; i--) { - const node = filteredTape[i]; - const dys = []; - node.outputs.forEach(o => { - const gradTensor = tensorAccumulatedGradientMap[o.id]; - if (gradTensor != null) { - dys.push(gradTensor); - } - else { - // This particular output is not in the back-propagation subgraph, so it - // does not affect the final output, thus we put null for its dy. - dys.push(null); - } - }); - if (node.gradient == null) { - throw new Error(`Cannot compute gradient: gradient function not found ` + - `for ${node.kernelName}.`); - } - // Backprop dy through this node and accumulate gradients over the inputs. - const inputGradients = node.gradient(dys); - for (const inputName in node.inputs) { - if (!(inputName in inputGradients)) { - throw new Error(`Cannot backprop through input ${inputName}. ` + - `Available gradients found: ${Object.keys(inputGradients)}.`); - } - // Call the gradient function. - const dx = tidy(() => inputGradients[inputName]()); - if (dx.dtype !== 'float32') { - throw new Error(`Error in gradient for op ${node.kernelName}. The gradient of input ` + - `${inputName} must have 'float32' dtype, but has '${dx.dtype}'`); - } - const x = node.inputs[inputName]; - if (!arraysEqual(dx.shape, x.shape)) { - throw new Error(`Error in gradient for op ${node.kernelName}. The gradient of input ` + - `'${inputName}' has shape '${dx.shape}', which does not match ` + - `the shape of the input '${x.shape}'`); - } - if (tensorAccumulatedGradientMap[x.id] == null) { - tensorAccumulatedGradientMap[x.id] = dx; - } - else { - const curGradient = tensorAccumulatedGradientMap[x.id]; - tensorAccumulatedGradientMap[x.id] = add(curGradient, dx); - curGradient.dispose(); - } - } - } - } - - /** - * @license - * Copyright 2018 Google LLC. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - // Maximum number of values before we decide to show ellipsis. - const FORMAT_LIMIT_NUM_VALS = 20; - // Number of first and last values to show when displaying a, b,...,y, z. - const FORMAT_NUM_FIRST_LAST_VALS = 3; - // Number of significant digits to show. - const FORMAT_NUM_SIG_DIGITS = 7; - function tensorToString(vals, shape, dtype, verbose) { - const strides = computeStrides(shape); - const padPerCol = computeMaxSizePerColumn(vals, shape, dtype, strides); - const rank = shape.length; - const valsLines = subTensorToString(vals, shape, dtype, strides, padPerCol); - const lines = ['Tensor']; - if (verbose) { - lines.push(` dtype: ${dtype}`); - lines.push(` rank: ${rank}`); - lines.push(` shape: [${shape}]`); - lines.push(` values:`); - } - lines.push(valsLines.map(l => ' ' + l).join('\n')); - return lines.join('\n'); - } - function computeMaxSizePerColumn(vals, shape, dtype, strides) { - const n = sizeFromShape(shape); - const numCols = strides[strides.length - 1]; - const padPerCol = new Array(numCols).fill(0); - const rank = shape.length; - const valuesOrTuples = dtype === 'complex64' ? createComplexTuples(vals) : vals; - if (rank > 1) { - for (let row = 0; row < n / numCols; row++) { - const offset = row * numCols; - for (let j = 0; j < numCols; j++) { - padPerCol[j] = Math.max(padPerCol[j], valToString(valuesOrTuples[offset + j], 0, dtype).length); - } - } - } - return padPerCol; - } - function valToString(val, pad, dtype) { - let valStr; - if (Array.isArray(val)) { - valStr = `${parseFloat(val[0].toFixed(FORMAT_NUM_SIG_DIGITS))} + ` + - `${parseFloat(val[1].toFixed(FORMAT_NUM_SIG_DIGITS))}j`; - } - else if (isString(val)) { - valStr = `'${val}'`; - } - else if (dtype === 'bool') { - valStr = boolNumToString(val); - } - else { - valStr = parseFloat(val.toFixed(FORMAT_NUM_SIG_DIGITS)).toString(); - } - return rightPad(valStr, pad); - } - function boolNumToString(v) { - return v === 0 ? 'false' : 'true'; - } - function subTensorToString(vals, shape, dtype, strides, padPerCol, isLast = true) { - const storagePerElement = dtype === 'complex64' ? 2 : 1; - const size = shape[0]; - const rank = shape.length; - if (rank === 0) { - if (dtype === 'complex64') { - const complexTuple = createComplexTuples(vals); - return [valToString(complexTuple[0], 0, dtype)]; - } - if (dtype === 'bool') { - return [boolNumToString(vals[0])]; - } - return [vals[0].toString()]; - } - if (rank === 1) { - if (size > FORMAT_LIMIT_NUM_VALS) { - const firstValsSize = FORMAT_NUM_FIRST_LAST_VALS * storagePerElement; - let firstVals = Array.from(vals.slice(0, firstValsSize)); - let lastVals = Array.from(vals.slice((size - FORMAT_NUM_FIRST_LAST_VALS) * storagePerElement, size * storagePerElement)); - if (dtype === 'complex64') { - firstVals = createComplexTuples(firstVals); - lastVals = createComplexTuples(lastVals); - } - return [ - '[' + - firstVals.map((x, i) => valToString(x, padPerCol[i], dtype)) - .join(', ') + - ', ..., ' + - lastVals - .map((x, i) => valToString(x, padPerCol[size - FORMAT_NUM_FIRST_LAST_VALS + i], dtype)) - .join(', ') + - ']' - ]; - } - const displayVals = dtype === 'complex64' ? createComplexTuples(vals) : - Array.from(vals); - return [ - '[' + - displayVals.map((x, i) => valToString(x, padPerCol[i], dtype)) - .join(', ') + - ']' - ]; - } - // The array is rank 2 or more. - const subshape = shape.slice(1); - const substrides = strides.slice(1); - const stride = strides[0] * storagePerElement; - const lines = []; - if (size > FORMAT_LIMIT_NUM_VALS) { - for (let i = 0; i < FORMAT_NUM_FIRST_LAST_VALS; i++) { - const start = i * stride; - const end = start + stride; - lines.push(...subTensorToString(vals.slice(start, end), subshape, dtype, substrides, padPerCol, false /* isLast */)); - } - lines.push('...'); - for (let i = size - FORMAT_NUM_FIRST_LAST_VALS; i < size; i++) { - const start = i * stride; - const end = start + stride; - lines.push(...subTensorToString(vals.slice(start, end), subshape, dtype, substrides, padPerCol, i === size - 1 /* isLast */)); - } - } - else { - for (let i = 0; i < size; i++) { - const start = i * stride; - const end = start + stride; - lines.push(...subTensorToString(vals.slice(start, end), subshape, dtype, substrides, padPerCol, i === size - 1 /* isLast */)); - } - } - const sep = rank === 2 ? ',' : ''; - lines[0] = '[' + lines[0] + sep; - for (let i = 1; i < lines.length - 1; i++) { - lines[i] = ' ' + lines[i] + sep; - } - let newLineSep = ',\n'; - for (let i = 2; i < rank; i++) { - newLineSep += '\n'; - } - lines[lines.length - 1] = - ' ' + lines[lines.length - 1] + ']' + (isLast ? '' : newLineSep); - return lines; - } - function createComplexTuples(vals) { - const complexTuples = []; - for (let i = 0; i < vals.length; i += 2) { - complexTuples.push([vals[i], vals[i + 1]]); - } - return complexTuples; - } - - /** - * @license - * Copyright 2017 Google LLC. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - /** - * A mutable object, similar to `tf.Tensor`, that allows users to set values - * at locations before converting to an immutable `tf.Tensor`. - * - * See `tf.buffer` for creating a tensor buffer. - * - * @doc {heading: 'Tensors', subheading: 'Classes'} - */ - class TensorBuffer { - constructor(shape, dtype, values) { - this.dtype = dtype; - this.shape = shape.slice(); - this.size = sizeFromShape(shape); - if (values != null) { - const n = values.length; - assert(n === this.size, () => `Length of values '${n}' does not match the size ` + - `inferred by the shape '${this.size}'.`); - } - if (dtype === 'complex64') { - throw new Error(`complex64 dtype TensorBuffers are not supported. Please create ` + - `a TensorBuffer for the real and imaginary parts separately and ` + - `call tf.complex(real, imag).`); - } - this.values = values || getArrayFromDType(dtype, this.size); - this.strides = computeStrides(shape); - } - /** - * Sets a value in the buffer at a given location. - * - * @param value The value to set. - * @param locs The location indices. - * - * @doc {heading: 'Tensors', subheading: 'Creation'} - */ - set(value, ...locs) { - if (locs.length === 0) { - locs = [0]; - } - assert(locs.length === this.rank, () => `The number of provided coordinates (${locs.length}) must ` + - `match the rank (${this.rank})`); - const index = this.locToIndex(locs); - this.values[index] = value; - } - /** - * Returns the value in the buffer at the provided location. - * - * @param locs The location indices. - * - * @doc {heading: 'Tensors', subheading: 'Creation'} - */ - get(...locs) { - if (locs.length === 0) { - locs = [0]; - } - let i = 0; - for (const loc of locs) { - if (loc < 0 || loc >= this.shape[i]) { - const msg = `Requested out of range element at ${locs}. ` + - ` Buffer shape=${this.shape}`; - throw new Error(msg); - } - i++; - } - let index = locs[locs.length - 1]; - for (let i = 0; i < locs.length - 1; ++i) { - index += this.strides[i] * locs[i]; - } - return this.values[index]; - } - locToIndex(locs) { - if (this.rank === 0) { - return 0; - } - else if (this.rank === 1) { - return locs[0]; - } - let index = locs[locs.length - 1]; - for (let i = 0; i < locs.length - 1; ++i) { - index += this.strides[i] * locs[i]; - } - return index; - } - indexToLoc(index) { - if (this.rank === 0) { - return []; - } - else if (this.rank === 1) { - return [index]; - } - const locs = new Array(this.shape.length); - for (let i = 0; i < locs.length - 1; ++i) { - locs[i] = Math.floor(index / this.strides[i]); - index -= locs[i] * this.strides[i]; - } - locs[locs.length - 1] = index; - return locs; - } - get rank() { - return this.shape.length; - } - /** - * Creates an immutable `tf.Tensor` object from the buffer. - * - * @doc {heading: 'Tensors', subheading: 'Creation'} - */ - toTensor() { - return trackerFn().makeTensor(this.values, this.shape, this.dtype); - } - } - // For tracking tensor creation and disposal. - let trackerFn = null; - // Used by chaining methods to call into ops. - let opHandler = null; - // Used to warn about deprecated methods. - let deprecationWarningFn = null; - // This here so that we can use this method on dev branches and keep the - // functionality at master. - // tslint:disable-next-line:no-unused-expression - [deprecationWarningFn]; - /** - * An external consumer can register itself as the tensor tracker. This way - * the Tensor class can notify the tracker for every tensor created and - * disposed. - */ - function setTensorTracker(fn) { - trackerFn = fn; - } - /** - * An external consumer can register itself as the op handler. This way the - * Tensor class can have chaining methods that call into ops via the op - * handler. - */ - function setOpHandler(handler) { - opHandler = handler; - } - /** - * Sets the deprecation warning function to be used by this file. This way the - * Tensor class can be a leaf but still use the environment. - */ - function setDeprecationWarningFn(fn) { - deprecationWarningFn = fn; - } - /** - * A `tf.Tensor` object represents an immutable, multidimensional array of - * numbers that has a shape and a data type. - * - * See `tf.tensor` for details on how to create a `tf.Tensor`. - * - * @doc {heading: 'Tensors', subheading: 'Classes'} - */ - class Tensor { - constructor(shape, dtype, dataId, id) { - /** Whether this tensor has been globally kept. */ - this.kept = false; - this.isDisposedInternal = false; - this.shape = shape.slice(); - this.dtype = dtype || 'float32'; - this.size = sizeFromShape(shape); - this.strides = computeStrides(shape); - this.dataId = dataId; - this.id = id; - this.rankType = (this.rank < 5 ? this.rank.toString() : 'higher'); - } - get rank() { - return this.shape.length; - } - /** - * Returns a promise of `tf.TensorBuffer` that holds the underlying data. - * - * @doc {heading: 'Tensors', subheading: 'Classes'} - */ - async buffer() { - const vals = await this.data(); - return opHandler.buffer(this.shape, this.dtype, vals); - } - /** - * Returns a `tf.TensorBuffer` that holds the underlying data. - * @doc {heading: 'Tensors', subheading: 'Classes'} - */ - bufferSync() { - return opHandler.buffer(this.shape, this.dtype, this.dataSync()); - } - /** - * Returns the tensor data as a nested array. The transfer of data is done - * asynchronously. - * - * @doc {heading: 'Tensors', subheading: 'Classes'} - */ - async array() { - const vals = await this.data(); - return toNestedArray(this.shape, vals); - } - /** - * Returns the tensor data as a nested array. The transfer of data is done - * synchronously. - * - * @doc {heading: 'Tensors', subheading: 'Classes'} - */ - arraySync() { - return toNestedArray(this.shape, this.dataSync()); - } - /** - * Asynchronously downloads the values from the `tf.Tensor`. Returns a - * promise of `TypedArray` that resolves when the computation has finished. - * - * @doc {heading: 'Tensors', subheading: 'Classes'} - */ - async data() { - this.throwIfDisposed(); - const data = trackerFn().read(this.dataId); - if (this.dtype === 'string') { - const bytes = await data; - try { - return bytes.map(b => decodeString(b)); - } - catch { - throw new Error('Failed to decode the string bytes into utf-8. ' + - 'To get the original bytes, call tensor.bytes().'); - } - } - return data; - } - /** - * Synchronously downloads the values from the `tf.Tensor`. This blocks the - * UI thread until the values are ready, which can cause performance issues. - * - * @doc {heading: 'Tensors', subheading: 'Classes'} - */ - dataSync() { - this.throwIfDisposed(); - const data = trackerFn().readSync(this.dataId); - if (this.dtype === 'string') { - try { - return data.map(b => decodeString(b)); - } - catch { - throw new Error('Failed to decode the string bytes into utf-8. ' + - 'To get the original bytes, call tensor.bytes().'); - } - } - return data; - } - /** Returns the underlying bytes of the tensor's data. */ - async bytes() { - this.throwIfDisposed(); - const data = await trackerFn().read(this.dataId); - if (this.dtype === 'string') { - return data; - } - else { - return new Uint8Array(data.buffer); - } - } - /** - * Disposes `tf.Tensor` from memory. - * - * @doc {heading: 'Tensors', subheading: 'Classes'} - */ - dispose() { - if (this.isDisposed) { - return; - } - trackerFn().disposeTensor(this); - this.isDisposedInternal = true; - } - get isDisposed() { - return this.isDisposedInternal; - } - throwIfDisposed() { - if (this.isDisposed) { - throw new Error(`Tensor is disposed.`); - } - } - /** - * Prints the `tf.Tensor`. See `tf.print` for details. - * - * @param verbose Whether to print verbose information about the tensor, - * including dtype and size. - * - * @doc {heading: 'Tensors', subheading: 'Classes'} - */ - print(verbose = false) { - return opHandler.print(this, verbose); - } - /** - * Returns a copy of the tensor. See `tf.clone` for details. - * @doc {heading: 'Tensors', subheading: 'Classes'} - */ - clone() { - this.throwIfDisposed(); - return opHandler.clone(this); - } - /** - * Returns a human-readable description of the tensor. Useful for logging. - * - * @doc {heading: 'Tensors', subheading: 'Classes'} - */ - toString(verbose = false) { - const vals = this.dataSync(); - return tensorToString(vals, this.shape, this.dtype, verbose); - } - cast(dtype) { - this.throwIfDisposed(); - return opHandler.cast(this, dtype); - } - variable(trainable = true, name, dtype) { - this.throwIfDisposed(); - return trackerFn().makeVariable(this, trainable, name, dtype); - } - } - Object.defineProperty(Tensor, Symbol.hasInstance, { - value: (instance) => { - // Implementation note: we should use properties of the object that will be - // defined before the constructor body has finished executing (methods). - // This is because when this code is transpiled by babel, babel will call - // classCallCheck before the constructor body is run. - // See https://github.com/tensorflow/tfjs/issues/3384 for backstory. - return !!instance && instance.data != null && instance.dataSync != null && - instance.throwIfDisposed != null; - } - }); - /** - * A mutable `tf.Tensor`, useful for persisting state, e.g. for training. - * - * @doc {heading: 'Tensors', subheading: 'Classes'} - */ - class Variable extends Tensor { - constructor(initialValue, trainable, name, tensorId) { - super(initialValue.shape, initialValue.dtype, initialValue.dataId, tensorId); - this.trainable = trainable; - this.name = name; - } - /** - * Assign a new `tf.Tensor` to this variable. The new `tf.Tensor` must have - * the same shape and dtype as the old `tf.Tensor`. - * - * @param newValue New tensor to be assigned to this variable. - * - * @doc {heading: 'Tensors', subheading: 'Classes'} - */ - assign(newValue) { - if (newValue.dtype !== this.dtype) { - throw new Error(`dtype of the new value (${newValue.dtype}) and ` + - `previous value (${this.dtype}) must match`); - } - if (!arraysEqual(newValue.shape, this.shape)) { - throw new Error(`shape of the new value (${newValue.shape}) and ` + - `previous value (${this.shape}) must match`); - } - trackerFn().disposeTensor(this); - this.dataId = newValue.dataId; - trackerFn().incRef(this, null /* backend */); - } - dispose() { - trackerFn().disposeVariable(this); - this.isDisposedInternal = true; - } - } - Object.defineProperty(Variable, Symbol.hasInstance, { - value: (instance) => { - return instance instanceof Tensor && instance.assign != null && - instance.assign instanceof Function; - } - }); - - /** - * @license - * Copyright 2017 Google LLC. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - (function (Rank) { - Rank["R0"] = "R0"; - Rank["R1"] = "R1"; - Rank["R2"] = "R2"; - Rank["R3"] = "R3"; - Rank["R4"] = "R4"; - Rank["R5"] = "R5"; - Rank["R6"] = "R6"; - })(exports.Rank || (exports.Rank = {})); - // Looks for upcasting types. Used, for example, in operations with mixed dtype - // inputs. - var UpcastInt32AndMap; - (function (UpcastInt32AndMap) { - UpcastInt32AndMap["float32"] = "float32"; - UpcastInt32AndMap["int32"] = "int32"; - UpcastInt32AndMap["bool"] = "int32"; - UpcastInt32AndMap["complex64"] = "complex64"; - })(UpcastInt32AndMap || (UpcastInt32AndMap = {})); - var UpcastBoolAndMap; - (function (UpcastBoolAndMap) { - UpcastBoolAndMap["float32"] = "float32"; - UpcastBoolAndMap["int32"] = "int32"; - UpcastBoolAndMap["bool"] = "bool"; - UpcastBoolAndMap["complex64"] = "complex64"; - })(UpcastBoolAndMap || (UpcastBoolAndMap = {})); - var UpcastFloat32AndMap; - (function (UpcastFloat32AndMap) { - UpcastFloat32AndMap["float32"] = "float32"; - UpcastFloat32AndMap["int32"] = "float32"; - UpcastFloat32AndMap["bool"] = "float32"; - UpcastFloat32AndMap["complex64"] = "complex64"; - })(UpcastFloat32AndMap || (UpcastFloat32AndMap = {})); - var UpcastComplex64AndMap; - (function (UpcastComplex64AndMap) { - UpcastComplex64AndMap["float32"] = "complex64"; - UpcastComplex64AndMap["int32"] = "complex64"; - UpcastComplex64AndMap["bool"] = "complex64"; - UpcastComplex64AndMap["complex64"] = "complex64"; - })(UpcastComplex64AndMap || (UpcastComplex64AndMap = {})); - const upcastTypeMap = { - 'float32': UpcastFloat32AndMap, - 'int32': UpcastInt32AndMap, - 'bool': UpcastBoolAndMap, - 'complex64': UpcastComplex64AndMap - }; - function upcastType(typeA, typeB) { - if (typeA === 'string' || typeB === 'string') { - if (typeA === 'string' && typeB === 'string') { - return 'string'; - } - throw new Error(`Can not upcast ${typeA} with ${typeB}`); - } - return upcastTypeMap[typeA][typeB]; - } - /** Returns the output type after summation. */ - function sumOutType(type) { - return upcastType(type, 'int32'); - } - - /** - * @license - * Copyright 2018 Google LLC. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - function makeTypesMatch(a, b) { - if (a.dtype === b.dtype) { - return [a, b]; - } - const dtype = upcastType(a.dtype, b.dtype); - return [a.cast(dtype), b.cast(dtype)]; - } - function assertTypesMatch(a, b) { - assert(a.dtype === b.dtype, () => `The dtypes of the first(${a.dtype}) and` + - ` second(${b.dtype}) input must match`); - } - function isTensorInList(tensor, tensorList) { - return tensorList.some(x => x.id === tensor.id); - } - /** - * Extracts any `Tensor`s found within the provided object. - * - * @param container an object that may be a `Tensor` or may directly contain - * `Tensor`s, such as a `Tensor[]` or `{key: Tensor, ...}`. In general it - * is safe to pass any object here, except that `Promise`s are not - * supported. - * @returns An array of `Tensors` found within the passed object. If the - * argument is simply a `Tensor', a list containing that `Tensor` is - * returned. If the object is not a `Tensor` or does not - * contain `Tensors`, an empty list is returned. - */ - function getTensorsInContainer(result) { - const list = []; - const seen = new Set(); - walkTensorContainer(result, list, seen); - return list; - } - function walkTensorContainer(container, list, seen) { - if (container == null) { - return; - } - if (container instanceof Tensor) { - list.push(container); - return; - } - if (!isIterable(container)) { - return; - } - // Iteration over keys works also for arrays. - const iterable = container; - for (const k in iterable) { - const val = iterable[k]; - if (!seen.has(val)) { - seen.add(val); - walkTensorContainer(val, list, seen); - } - } - } - // tslint:disable-next-line:no-any - function isIterable(obj) { - return Array.isArray(obj) || typeof obj === 'object'; - } - - var tensor_util = /*#__PURE__*/Object.freeze({ - __proto__: null, - makeTypesMatch: makeTypesMatch, - assertTypesMatch: assertTypesMatch, - isTensorInList: isTensorInList, - getTensorsInContainer: getTensorsInContainer - }); - - /** - * @license - * Copyright 2018 Google LLC. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - class EngineState { - constructor() { - // Public since optimizers will use it. - this.registeredVariables = {}; - this.nextTapeNodeId = 0; - this.numBytes = 0; - this.numTensors = 0; - this.numStringTensors = 0; - this.numDataBuffers = 0; - // Number of nested tf.grad() statements when computing higher-order - // gradients. E.g. `1` for first-order gradients and `2` for second-order - // gradients. Used to track if the tape should be removed after a backprop. - this.gradientDepth = 0; - // Number of nested kernel calls. When kernel depth is greater than 1, we turn - // off the tape. - this.kernelDepth = 0; - this.scopeStack = []; - /** - * Keeps track of the number of data moves during a kernel execution. We - * maintain a stack since kernels can call other kernels, recursively. - */ - this.numDataMovesStack = []; - this.nextScopeId = 0; - this.tensorInfo = new WeakMap(); - this.profiling = false; - this.activeProfile = { newBytes: 0, newTensors: 0, peakBytes: 0, kernels: [], result: null }; - } - dispose() { - for (const variableName in this.registeredVariables) { - this.registeredVariables[variableName].dispose(); - } - } - } - class Engine { - constructor(ENV) { - this.ENV = ENV; - this.registry = {}; - this.registryFactory = {}; - this.pendingBackendInitId = 0; - this.state = new EngineState(); - } - async ready() { - if (this.pendingBackendInit != null) { - return this.pendingBackendInit.then(() => { }); - } - if (this.backendInstance != null) { - return; - } - const sortedBackends = this.getSortedBackends(); - for (let i = 0; i < sortedBackends.length; i++) { - const backendName = sortedBackends[i]; - const success = await this.initializeBackend(backendName).success; - if (success) { - await this.setBackend(backendName); - return; - } - } - throw new Error(`Could not initialize any backends, all backend initializations ` + - `failed.`); - } - get backend() { - if (this.pendingBackendInit != null) { - throw new Error(`Backend '${this.backendName}' has not yet been initialized. Make ` + - `sure to await tf.ready() or await tf.setBackend() before calling ` + - `other methods`); - } - if (this.backendInstance == null) { - const { name, asyncInit } = this.initializeBackendsAndReturnBest(); - if (asyncInit) { - throw new Error(`The highest priority backend '${name}' has not yet been ` + - `initialized. Make sure to await tf.ready() or ` + - `await tf.setBackend() before calling other methods`); - } - this.setBackend(name); - } - return this.backendInstance; - } - backendNames() { - return Object.keys(this.registryFactory); - } - findBackend(backendName) { - if (!(backendName in this.registry)) { - // If the backend hasn't been initialized but we have a registry entry for - // it, initialize it and return it. - if (backendName in this.registryFactory) { - const { asyncInit } = this.initializeBackend(backendName); - if (asyncInit) { - // Backend is not ready yet. - return null; - } - } - else { - return null; - } - } - return this.registry[backendName]; - } - findBackendFactory(backendName) { - if (!(backendName in this.registryFactory)) { - return null; - } - return this.registryFactory[backendName].factory; - } - registerBackend(backendName, factory, priority = 1) { - if (backendName in this.registryFactory) { - console.warn(`${backendName} backend was already registered. ` + - `Reusing existing backend factory.`); - return false; - } - this.registryFactory[backendName] = { factory, priority }; - return true; - } - async setBackend(backendName) { - if (this.registryFactory[backendName] == null) { - throw new Error(`Backend name '${backendName}' not found in registry`); - } - this.backendName = backendName; - if (this.registry[backendName] == null) { - this.backendInstance = null; - const { success, asyncInit } = this.initializeBackend(backendName); - const result = asyncInit ? await success : success; - if (!result) { - return false; - } - } - this.backendInstance = this.registry[backendName]; - this.setupRegisteredKernels(); - // Reset the profiler. - this.profiler = new Profiler(this.backendInstance); - return true; - } - setupRegisteredKernels() { - const kernels = getKernelsForBackend(this.backendName); - kernels.forEach(kernel => { - if (kernel.setupFunc != null) { - kernel.setupFunc(this.backendInstance); - } - }); - } - disposeRegisteredKernels(backendName) { - const kernels = getKernelsForBackend(backendName); - kernels.forEach(kernel => { - if (kernel.disposeFunc != null) { - kernel.disposeFunc(this.registry[backendName]); - } - }); - } - /** - * Initializes a backend by looking up the backend name in the factory - * registry and calling the factory method. Returns a boolean representing - * whether the initialization of the backend suceeded. Throws an error if - * there is no backend in the factory registry. - */ - initializeBackend(backendName) { - const registryFactoryEntry = this.registryFactory[backendName]; - if (registryFactoryEntry == null) { - throw new Error(`Cannot initialize backend ${backendName}, no registration found.`); - } - try { - const backend = registryFactoryEntry.factory(); - /* Test if the factory returns a promise. - Done in a more liberal way than - previous 'Promise.resolve(backend)===backend' - as we needed to account for custom Promise - implementations (e.g. Angular) */ - if (backend && !(backend instanceof KernelBackend) - && typeof backend.then === 'function') { - const promiseId = ++this.pendingBackendInitId; - const success = backend - .then(backendInstance => { - // Outdated promise. Another backend was set in the meantime. - if (promiseId < this.pendingBackendInitId) { - return false; - } - this.registry[backendName] = backendInstance; - this.pendingBackendInit = null; - return true; - }) - .catch(err => { - // Outdated promise. Another backend was set in the meantime. - if (promiseId < this.pendingBackendInitId) { - return false; - } - this.pendingBackendInit = null; - console.warn(`Initialization of backend ${backendName} failed`); - console.warn(err.stack || err.message); - return false; - }); - this.pendingBackendInit = success; - return { success, asyncInit: true }; - } - else { - this.registry[backendName] = backend; - return { success: true, asyncInit: false }; - } - } - catch (err) { - console.warn(`Initialization of backend ${backendName} failed`); - console.warn(err.stack || err.message); - return { success: false, asyncInit: false }; - } - } - removeBackend(backendName) { - if (!(backendName in this.registryFactory)) { - throw new Error(`${backendName} backend not found in registry`); - } - if (this.backendName === backendName && this.pendingBackendInit != null) { - // There is a pending promise of the backend we want to remove. Make it - // obsolete. - this.pendingBackendInitId++; - } - if (backendName in this.registry) { - this.disposeRegisteredKernels(backendName); - this.registry[backendName].dispose(); - delete this.registry[backendName]; - } - delete this.registryFactory[backendName]; - // Unset the backend if it is active. - if (this.backendName === backendName) { - this.pendingBackendInit = null; - this.backendName = null; - this.backendInstance = null; - } - } - getSortedBackends() { - if (Object.keys(this.registryFactory).length === 0) { - throw new Error('No backend found in registry.'); - } - return Object.keys(this.registryFactory).sort((a, b) => { - // Highest priority comes first. - return this.registryFactory[b].priority - - this.registryFactory[a].priority; - }); - } - initializeBackendsAndReturnBest() { - const sortedBackends = this.getSortedBackends(); - for (let i = 0; i < sortedBackends.length; i++) { - const backendName = sortedBackends[i]; - const { success, asyncInit } = this.initializeBackend(backendName); - if (asyncInit || success) { - return { name: backendName, asyncInit }; - } - } - throw new Error(`Could not initialize any backends, all backend initializations ` + - `failed.`); - } - moveData(backend, dataId) { - const info = this.state.tensorInfo.get(dataId); - const srcBackend = info.backend; - const values = this.readSync(dataId); - // Delete the tensor from the old backend and move it to the new - // backend. - srcBackend.disposeData(dataId); - info.backend = backend; - backend.move(dataId, values, info.shape, info.dtype); - if (this.shouldCheckForMemLeaks()) { - // Track the number of moves during a kernel execution to correctly - // detect memory leaks. - this.state.numDataMovesStack[this.state.numDataMovesStack.length - 1]++; - } - } - tidy(nameOrFn, fn) { - let name = null; - if (fn == null) { - // Called with only 1 argument. - if (typeof nameOrFn !== 'function') { - throw new Error('Please provide a function to tidy()'); - } - fn = nameOrFn; - } - else { - // Called with 2 arguments. - if (typeof nameOrFn !== 'string' && !(nameOrFn instanceof String)) { - throw new Error('When calling with two arguments, the first argument ' + - 'to tidy() must be a string'); - } - if (typeof fn !== 'function') { - throw new Error('When calling with two arguments, the 2nd argument ' + - 'to tidy() must be a function'); - } - name = nameOrFn; - // TODO(nsthorat,smilkov): Do operation logging and performance - // profiling. - } - let result; - return this.scopedRun(() => this.startScope(name), () => this.endScope(result), () => { - result = fn(); - if (result instanceof Promise) { - console.error('Cannot return a Promise inside of tidy.'); - } - return result; - }); - } - scopedRun(start, end, f) { - start(); - try { - const res = f(); - end(); - return res; - } - catch (ex) { - end(); - throw ex; - } - } - nextTensorId() { - return Engine.nextTensorId++; - } - nextVariableId() { - return Engine.nextVariableId++; - } - /** - * This method is called instead of the public-facing tensor.clone() when - * saving a tensor for backwards pass. It makes sure to add the clone - * operation to the tape regardless of being called inside a kernel - * execution. - * - * This method will go away once all kernels are modularized since we won't - * need to turn off the tape inside runKernel(). - */ - clone(x) { - const y = this.makeTensorFromDataId(x.dataId, x.shape, x.dtype); - const inputs = { x }; - const grad = (dy) => ({ - x: () => { - const dtype = 'float32'; - const gradInputs = { x: dy }; - const attrs = { dtype }; - return ENGINE.runKernelFunc(backend => backend.cast(dy, dtype), gradInputs, null /* grad */, Cast, attrs); - } - }); - const saved = []; - this.addTapeNode(this.state.activeScope.name, inputs, [y], grad, saved, {}); - return y; - } - /** - * Execute a kernel with the given name and return the output tensor. - * - * @param kernelName The name of the kernel to execute. - * @param inputs A map of input names to tensors. - * @param attrs A map of attribute names to their values. An attribute is a - * primitive (non-tensor) input to the kernel. - * @param inputsToSave A list of tensors, inputs to save for the backprop - * computation. - * @param outputsToSave A list of booleans, specifying which output to save - * for the backprop computation. These are booleans since the output - * tensors are not visible to the user. - */ - runKernel(kernelName, inputs, attrs, inputsToSave, outputsToSave) { - const forwardFunc = null; - const backwardsFunc = null; - // Call runKernel as a stop-gap until we modularize all kernels. - // Once we modularize all kernels, we will remove the existing - // `runKernelFunc`. - return this.runKernelFunc(forwardFunc, inputs, backwardsFunc, kernelName, attrs, inputsToSave, outputsToSave); - } - shouldCheckForMemLeaks() { - return this.ENV.getBool('IS_TEST'); - } - checkKernelForMemLeak(kernelName, numDataIdsBefore, outInfos) { - const numDataIdsAfter = this.backend.numDataIds(); - // Count the number of data ids associated with the result of the kernel. - let numOutputDataIds = 0; - outInfos.forEach(info => { - // Complex numbers allocate 3 data ids, one for 'real', one for - // 'imaginary', and one for the container that holds the former two. - numOutputDataIds += (info.dtype === 'complex64' ? 3 : 1); - }); - // Account for the number of moves during kernel execution. A "data move" - // can happen in the middle of a kernel execution, placing a new (key,value) - // pair in the data storage. Since data moves have net zero effect (we - // always remove the data from the old backend), we have to cancel them out - // when detecting memory leaks. - const numMoves = this.state.numDataMovesStack[this.state.numDataMovesStack.length - 1]; - const dataIdsLeaked = numDataIdsAfter - numDataIdsBefore - numOutputDataIds - numMoves; - if (dataIdsLeaked > 0) { - throw new Error(`Backend '${this.backendName}' has an internal memory leak ` + - `(${dataIdsLeaked} data ids) after running '${kernelName}'`); - } - } - /** - * @deprecated Use `runKernel` for newly added kernels. Keep using this method - * only for kernels that are not yet fully modularized. - */ - runKernelFunc(forwardFunc, inputs, backwardsFunc, kernelName, attrs, inputsToSave, outputsToSave) { - let outputs; - let saved = []; - const isTapeOn = this.isTapeOn(); - if (kernelName == null) { - kernelName = - this.state.activeScope != null ? this.state.activeScope.name : ''; - } - const startingBytecount = this.state.numBytes; - const startingNumTensors = this.state.numTensors; - if (this.shouldCheckForMemLeaks()) { - this.state.numDataMovesStack.push(0); - } - let kernelFunc; - const kernel = getKernel(kernelName, this.backendName); - let out; - if (kernel != null) { - kernelFunc = () => { - const numDataIdsBefore = this.backend.numDataIds(); - out = kernel.kernelFunc({ inputs, attrs, backend: this.backend }); - const outInfos = Array.isArray(out) ? out : [out]; - if (this.shouldCheckForMemLeaks()) { - this.checkKernelForMemLeak(kernelName, numDataIdsBefore, outInfos); - } - const outTensors = outInfos.map(({ dataId, shape, dtype }) => this.makeTensorFromDataId(dataId, shape, dtype)); - // Save the inputs and outputs. - // Do not save unless we are recording to the tape. Otherwise it would - // cause a mem leak since we would never run backprop, which disposes - // the kept tensors. - if (isTapeOn) { - let tensorsToSave = this.getTensorsForGradient(kernelName, inputs, outTensors); - if (tensorsToSave == null) { - // Fallback for ops that call runKernelFunc and pass in - // inputsToSave and outputsToSave. Currently this is the set of ops - // with kernel support in the WASM backend. Once those ops and - // respective gradients are modularised we can remove this path. - if (outputsToSave == null) { - outputsToSave = []; - } - const outsToSave = outTensors.filter((_, i) => outputsToSave[i]); - tensorsToSave = (inputsToSave || []).slice().concat(outsToSave); - } - saved = this.saveTensorsForBackwardMode(tensorsToSave); - } - return outTensors; - }; - } - else { - const saveFunc = (tensors) => { - // Do not save unless we are recording to the tape. Otherwise it would - // cause a mem leak since we would never run backprop, which disposes - // the kept tensors. - if (!isTapeOn) { - return; - } - saved = tensors.map(tensor => this.keep(this.clone(tensor))); - }; - kernelFunc = () => { - const numDataIdsBefore = this.backend.numDataIds(); - out = this.tidy(() => forwardFunc(this.backend, saveFunc)); - const outs = (Array.isArray(out) ? out : [out]); - if (this.shouldCheckForMemLeaks()) { - this.checkKernelForMemLeak(kernelName, numDataIdsBefore, outs); - } - return outs; - }; - } - // Stop recording to a tape when running a kernel. - let kernelProfile; - this.scopedRun(() => this.state.kernelDepth++, () => this.state.kernelDepth--, () => { - if (!this.ENV.getBool('DEBUG') && !this.state.profiling) { - outputs = kernelFunc(); - } - else { - kernelProfile = this.profiler.profileKernel(kernelName, inputs, () => kernelFunc()); - if (this.ENV.getBool('DEBUG')) { - this.profiler.logKernelProfile(kernelProfile); - } - outputs = kernelProfile.outputs; - } - }); - if (isTapeOn) { - this.addTapeNode(kernelName, inputs, outputs, backwardsFunc, saved, attrs); - } - if (this.state.profiling) { - this.state.activeProfile.kernels.push({ - name: kernelName, - bytesAdded: this.state.numBytes - startingBytecount, - totalBytesSnapshot: this.state.numBytes, - tensorsAdded: this.state.numTensors - startingNumTensors, - totalTensorsSnapshot: this.state.numTensors, - inputShapes: Object.keys(inputs).map(key => inputs[key] != null ? inputs[key].shape : null), - outputShapes: outputs.map(item => item.shape), - kernelTimeMs: kernelProfile.timeMs, - extraInfo: kernelProfile.extraInfo - }); - } - return (Array.isArray(out) ? outputs : outputs[0]); - } - /** - * Saves tensors used in forward mode for use in backward mode. - * - * @param tensors the list of tensors to save. - */ - saveTensorsForBackwardMode(tensors) { - const saved = tensors.map(tensor => this.keep(this.clone(tensor))); - return saved; - } - /** - * Returns a list of tensors to save for a given gradient calculation. - * - * Returns undefined if their is no registered gradient for this kernel in the - * gradient registry. - * - * @param kernelName name of kernel to look up gradient for. - * @param inputs a map of input tensors. - * @param outputs an array of output tensors from forward mode of kernel. - */ - getTensorsForGradient(kernelName, inputs, outputs) { - const gradConfig = getGradient(kernelName); - if (gradConfig != null) { - const inputsToSave = gradConfig.inputsToSave || []; - const outputsToSave = gradConfig.outputsToSave || []; - // If saveAllInputs is true, all inputs will be saved. Otherwise, inputs - // specified in inputsToSave will be saved. - let inputTensorsToSave; - if (gradConfig.saveAllInputs) { - assert(Array.isArray(inputs), () => 'saveAllInputs is true, expected inputs to be an array.'); - inputTensorsToSave = Object.keys(inputs).map((key) => inputs[key]); - } - else { - inputTensorsToSave = inputsToSave.map((inputName) => inputs[inputName]); - } - const outputTensorsToSave = outputs.filter((_, i) => outputsToSave[i]); - return inputTensorsToSave.concat(outputTensorsToSave); - } - // TODO(yassogba) throw exception here once all runkernelFunc calls with - // inputsToSave/outputsToSave are removed - return null; - } - /** - * Internal method used by public APIs for tensor creation. Makes a new - * tensor with the provided shape, dtype and values. It always - * creates a new data id and writes the values to the underlying backend. - */ - makeTensor(values, shape, dtype, backend) { - if (values == null) { - throw new Error('Values passed to engine.makeTensor() are null'); - } - dtype = dtype || 'float32'; - backend = backend || this.backend; - let backendVals = values; - if (dtype === 'string' && isString(values[0])) { - backendVals = values.map(d => encodeString(d)); - } - const dataId = backend.write(backendVals, shape, dtype); - const t = new Tensor(shape, dtype, dataId, this.nextTensorId()); - this.incRef(t, backend); - // Count bytes for string tensors. - if (dtype === 'string') { - const info = this.state.tensorInfo.get(dataId); - const newBytes = bytesFromStringArray(backendVals); - this.state.numBytes += newBytes - info.bytes; - info.bytes = newBytes; - } - return t; - } - /** - * Internal method used by backends. Makes a new tensor - * that is a wrapper around an existing data id. It doesn't create - * a new data id, only increments the ref count used in memory tracking. - */ - makeTensorFromDataId(dataId, shape, dtype, backend) { - dtype = dtype || 'float32'; - const t = new Tensor(shape, dtype, dataId, this.nextTensorId()); - this.incRef(t, backend); - return t; - } - makeVariable(initialValue, trainable = true, name, dtype) { - name = name || this.nextVariableId().toString(); - if (dtype != null && dtype !== initialValue.dtype) { - initialValue = initialValue.cast(dtype); - } - const v = new Variable(initialValue, trainable, name, this.nextTensorId()); - if (this.state.registeredVariables[v.name] != null) { - throw new Error(`Variable with name ${v.name} was already registered`); - } - this.state.registeredVariables[v.name] = v; - this.incRef(v, this.backend); - return v; - } - incRef(a, backend) { - const refCount = this.state.tensorInfo.has(a.dataId) ? - this.state.tensorInfo.get(a.dataId).refCount : - 0; - this.state.numTensors++; - if (a.dtype === 'string') { - this.state.numStringTensors++; - } - if (refCount === 0) { - this.state.numDataBuffers++; - // Bytes for complex numbers are counted by their components. Bytes for - // string tensors are counted when writing values. - let bytes = 0; - if (a.dtype !== 'complex64' && a.dtype !== 'string') { - bytes = a.size * bytesPerElement(a.dtype); - } - this.state.tensorInfo.set(a.dataId, { - backend: backend || this.backend, - dtype: a.dtype, - shape: a.shape, - bytes, - refCount: 0 - }); - this.state.numBytes += bytes; - } - this.state.tensorInfo.get(a.dataId).refCount++; - if (!(a instanceof Variable)) { - this.track(a); - } - } - disposeTensor(a) { - if (!this.state.tensorInfo.has(a.dataId)) { - return; - } - this.state.numTensors--; - if (a.dtype === 'string') { - this.state.numStringTensors--; - } - const info = this.state.tensorInfo.get(a.dataId); - const refCount = info.refCount; - if (refCount <= 1) { - // Don't count bytes for complex numbers as they are counted by their - // components. - if (a.dtype !== 'complex64') { - this.state.numBytes -= info.bytes; - } - this.state.numDataBuffers--; - info.backend.disposeData(a.dataId); - this.state.tensorInfo.delete(a.dataId); - } - else { - this.state.tensorInfo.get(a.dataId).refCount--; - } - // TODO(nsthorat): Construct an error and save the stack trace for - // debugging when in debug mode. Creating a stack trace is too expensive - // to do unconditionally. - } - disposeVariables() { - for (const varName in this.state.registeredVariables) { - const v = this.state.registeredVariables[varName]; - this.disposeVariable(v); - } - } - disposeVariable(v) { - this.disposeTensor(v); - if (this.state.registeredVariables[v.name] != null) { - delete this.state.registeredVariables[v.name]; - } - } - memory() { - const info = this.backend.memory(); - info.numTensors = this.state.numTensors; - info.numDataBuffers = this.state.numDataBuffers; - info.numBytes = this.state.numBytes; - if (this.state.numStringTensors > 0) { - info.unreliable = true; - if (info.reasons == null) { - info.reasons = []; - } - info.reasons.push('Memory usage by string tensors is approximate ' + - '(2 bytes per character)'); - } - return info; - } - async profile(query) { - this.state.profiling = true; - const startBytes = this.state.numBytes; - const startNumTensors = this.state.numTensors; - this.state.activeProfile.kernels = []; - this.state.activeProfile.result = await query(); - this.state.profiling = false; - this.state.activeProfile.peakBytes = Math.max(...this.state.activeProfile.kernels.map(d => d.totalBytesSnapshot)); - this.state.activeProfile.newBytes = this.state.numBytes - startBytes; - this.state.activeProfile.newTensors = - this.state.numTensors - startNumTensors; - for (const kernel of this.state.activeProfile.kernels) { - kernel.kernelTimeMs = await kernel.kernelTimeMs; - kernel.extraInfo = await kernel.extraInfo; - } - return this.state.activeProfile; - } - isTapeOn() { - return this.state.gradientDepth > 0 && this.state.kernelDepth === 0; - } - addTapeNode(kernelName, inputs, outputs, gradientsFunc, saved, attrs) { - const tapeNode = { id: this.state.nextTapeNodeId++, kernelName, inputs, outputs, saved }; - const gradConfig = getGradient(kernelName); - if (gradConfig != null) { - gradientsFunc = gradConfig.gradFunc; - } - if (gradientsFunc != null) { - tapeNode.gradient = (dys) => { - // TODO(smilkov): To optimize back-prop, pass dys that are not used in - // the backprop graph to the user as null instead of zeros - dys = dys.map((dy, i) => { - if (dy == null) { - const output = outputs[i]; - const vals = makeZerosTypedArray(output.size, output.dtype); - return this.makeTensor(vals, output.shape, output.dtype); - } - return dy; - }); - // Grad functions of ops with single outputs expect a dy, while ops - // with multiple outputs expect dys (array of dy). - return gradientsFunc(dys.length > 1 ? dys : dys[0], saved, attrs); - }; - } - this.state.activeTape.push(tapeNode); - } - keep(result) { - result.kept = true; - return result; - } - startTape() { - if (this.state.gradientDepth === 0) { - this.state.activeTape = []; - } - this.state.gradientDepth++; - } - endTape() { - this.state.gradientDepth--; - } - /** - * Start a scope. Use this with endScope() to achieve the same functionality - * as scope() without the need for a function closure. - */ - startScope(name) { - const scopeInfo = { - track: [], - name: 'unnamed scope', - id: this.state.nextScopeId++ - }; - if (name) { - scopeInfo.name = name; - } - this.state.scopeStack.push(scopeInfo); - this.state.activeScope = scopeInfo; - } - /** - * End a scope. Use this with startScope() to achieve the same functionality - * as scope() without the need for a function closure. - */ - endScope(result) { - const tensorsToTrackInParent = getTensorsInContainer(result); - const tensorsToTrackInParentSet = new Set(tensorsToTrackInParent.map(t => t.id)); - // Dispose the arrays tracked in this scope. - for (let i = 0; i < this.state.activeScope.track.length; i++) { - const tensor = this.state.activeScope.track[i]; - if (!tensor.kept && !tensorsToTrackInParentSet.has(tensor.id)) { - tensor.dispose(); - } - } - const oldScope = this.state.scopeStack.pop(); - this.state.activeScope = this.state.scopeStack.length === 0 ? - null : - this.state.scopeStack[this.state.scopeStack.length - 1]; - // Track the current result in the parent scope. - tensorsToTrackInParent.forEach(tensor => { - // Only track the tensor if was allocated in the inner scope and is not - // globally kept. - if (!tensor.kept && tensor.scopeId === oldScope.id) { - this.track(tensor); - } - }); - } - /** - * Returns gradients of `f` with respect to each of the `xs`. The gradients - * returned are of the same length as `xs`, but some might be null if `f` - * was not a function of that `x`. It also takes optional dy to multiply the - * gradient, which defaults to `1`. - */ - gradients(f, xs, dy, allowNoGradients = false) { - assert(xs.length > 0, () => 'gradients() received an empty list of xs.'); - if (dy != null && dy.dtype !== 'float32') { - throw new Error(`dy must have 'float32' dtype, but has '${dy.dtype}'`); - } - const y = this.scopedRun(() => this.startTape(), () => this.endTape(), () => this.tidy('forward', f)); - assert(y instanceof Tensor, () => 'The result y returned by f() must be a tensor.'); - // Filter out the nodes that don't connect x => y. - const filteredTape = getFilteredNodesXToY(this.state.activeTape, xs, y); - if (!allowNoGradients && filteredTape.length === 0 && xs.length > 0) { - throw new Error('Cannot compute gradient of y=f(x) with respect to x. Make sure ' + - 'that the f you passed encloses all operations that lead from x ' + - 'to y.'); - } - return this.tidy('backward', () => { - const accumulatedGradientMap = {}; - accumulatedGradientMap[y.id] = (dy == null) ? ones(y.shape) : dy; - // Backprop gradients through the filtered nodes. - backpropagateGradients(accumulatedGradientMap, filteredTape, - // Pass the tidy function to avoid circular dep with `tape.ts`. - f => this.tidy(f), - // Pass an add function to avoide a circular dep with `tape.ts`. - add); - const grads = xs.map(x => accumulatedGradientMap[x.id]); - if (this.state.gradientDepth === 0) { - // This means that we are not computing higher-order gradients - // and can clean up the tape. - this.state.activeTape.forEach(node => { - for (const tensor of node.saved) { - tensor.dispose(); - } - }); - this.state.activeTape = null; - } - return { value: y, grads }; - }); - } - customGrad(f) { - assert(isFunction(f), () => 'The f passed in customGrad(f) must be a function.'); - return (...inputs) => { - assert(inputs.every(t => t instanceof Tensor), () => 'The args passed in customGrad(f)(x1, x2,...) must all be ' + - 'tensors'); - let res; - const inputMap = {}; - inputs.forEach((input, i) => { - inputMap[i] = input; - }); - return this.runKernelFunc((_, save) => { - res = f(...[...inputs, save]); - assert(res.value instanceof Tensor, () => 'The function f passed in customGrad(f) must return an ' + - 'object where `obj.value` is a tensor'); - assert(isFunction(res.gradFunc), () => 'The function f passed in customGrad(f) must return an ' + - 'object where `obj.gradFunc` is a function.'); - return res.value; - }, inputMap, (dy, saved) => { - const gradRes = res.gradFunc(dy, saved); - const grads = Array.isArray(gradRes) ? gradRes : [gradRes]; - assert(grads.length === inputs.length, () => 'The function f passed in customGrad(f) must return an ' + - 'object where `obj.gradFunc` is a function that returns ' + - 'the same number of tensors as inputs passed to f(...).'); - assert(grads.every(t => t instanceof Tensor), () => 'The function f passed in customGrad(f) must return an ' + - 'object where `obj.gradFunc` is a function that returns ' + - 'a list of only tensors.'); - const gradMap = {}; - grads.forEach((grad, i) => { - gradMap[i] = () => grad; - }); - return gradMap; - }); - }; - } - readSync(dataId) { - // Route the read to the correct backend. - const info = this.state.tensorInfo.get(dataId); - return info.backend.readSync(dataId); - } - read(dataId) { - // Route the read to the correct backend. - const info = this.state.tensorInfo.get(dataId); - return info.backend.read(dataId); - } - async time(query) { - const start = now(); - const timingInfo = await this.backend.time(query); - timingInfo.wallMs = now() - start; - return timingInfo; - } - /** - * Tracks a Tensor in the current scope to be automatically cleaned up - * when the current scope ends, and returns the value. - * - * @param result The Tensor to track in the current scope. - */ - track(result) { - if (this.state.activeScope != null) { - result.scopeId = this.state.activeScope.id; - this.state.activeScope.track.push(result); - } - return result; - } - get registeredVariables() { - return this.state.registeredVariables; - } - /** - * Resets the engine state. Removes all backends but does not remove - * registered backend factories. - */ - reset() { - // Make any pending promise obsolete. - this.pendingBackendInitId++; - this.state.dispose(); - this.ENV.reset(); - this.state = new EngineState(); - for (const backendName in this.registry) { - this.disposeRegisteredKernels(backendName); - this.registry[backendName].dispose(); - delete this.registry[backendName]; - } - this.backendName = null; - this.backendInstance = null; - this.pendingBackendInit = null; - } - } - Engine.nextTensorId = 0; - Engine.nextVariableId = 0; - function ones(shape) { - const values = makeOnesTypedArray(sizeFromShape(shape), 'float32'); - return ENGINE.makeTensor(values, shape, 'float32'); - } - function getOrMakeEngine() { - const ns = getGlobalNamespace(); - if (ns._tfengine == null) { - const environment = new Environment(ns); - ns._tfengine = new Engine(environment); - } - setEnvironmentGlobal(ns._tfengine.ENV); - // Tell the current tensor interface that the global engine is responsible - // for tracking. - setTensorTracker(() => ns._tfengine); - return ns._tfengine; - } - const ENGINE = getOrMakeEngine(); - /** - * A implementation of the add op for use within engine and tape. - * - * This allows us to avoid a circular dependency between add.ts and engine. - * It is exported to be available in tape tests. - */ - function add(a, b) { - // We duplicate Add here to avoid a circular dependency with add.ts. - const inputs = { a, b }; - return ENGINE.runKernelFunc((backend, save) => { - const res = backend.add(a, b); - save([a, b]); - return res; - }, inputs, null /* gradient */, Add); - } - - /** - * @license - * Copyright 2017 Google LLC. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - // tslint:disable-next-line:no-any - function _isNavigatorDefined() { - return typeof navigator !== 'undefined' && navigator != null; - } - function isMobile() { - if (_isNavigatorDefined()) { - // tslint:disable-next-line:no-any - const a = navigator.userAgent || navigator.vendor || window.opera; - // tslint:disable-next-line:max-line-length - return /(android|bb\d+|meego).+mobile|avantgo|bada\/|blackberry|blazer|compal|elaine|fennec|hiptop|iemobile|ip(hone|od)|iris|kindle|lge |maemo|midp|mmp|mobile.+firefox|netfront|opera m(ob|in)i|palm( os)?|phone|p(ixi|re)\/|plucker|pocket|psp|series(4|6)0|symbian|treo|up\.(browser|link)|vodafone|wap|windows ce|xda|xiino/i - .test(a) || - // tslint:disable-next-line:max-line-length - /1207|6310|6590|3gso|4thp|50[1-6]i|770s|802s|a wa|abac|ac(er|oo|s\-)|ai(ko|rn)|al(av|ca|co)|amoi|an(ex|ny|yw)|aptu|ar(ch|go)|as(te|us)|attw|au(di|\-m|r |s )|avan|be(ck|ll|nq)|bi(lb|rd)|bl(ac|az)|br(e|v)w|bumb|bw\-(n|u)|c55\/|capi|ccwa|cdm\-|cell|chtm|cldc|cmd\-|co(mp|nd)|craw|da(it|ll|ng)|dbte|dc\-s|devi|dica|dmob|do(c|p)o|ds(12|\-d)|el(49|ai)|em(l2|ul)|er(ic|k0)|esl8|ez([4-7]0|os|wa|ze)|fetc|fly(\-|_)|g1 u|g560|gene|gf\-5|g\-mo|go(\.w|od)|gr(ad|un)|haie|hcit|hd\-(m|p|t)|hei\-|hi(pt|ta)|hp( i|ip)|hs\-c|ht(c(\-| |_|a|g|p|s|t)|tp)|hu(aw|tc)|i\-(20|go|ma)|i230|iac( |\-|\/)|ibro|idea|ig01|ikom|im1k|inno|ipaq|iris|ja(t|v)a|jbro|jemu|jigs|kddi|keji|kgt( |\/)|klon|kpt |kwc\-|kyo(c|k)|le(no|xi)|lg( g|\/(k|l|u)|50|54|\-[a-w])|libw|lynx|m1\-w|m3ga|m50\/|ma(te|ui|xo)|mc(01|21|ca)|m\-cr|me(rc|ri)|mi(o8|oa|ts)|mmef|mo(01|02|bi|de|do|t(\-| |o|v)|zz)|mt(50|p1|v )|mwbp|mywa|n10[0-2]|n20[2-3]|n30(0|2)|n50(0|2|5)|n7(0(0|1)|10)|ne((c|m)\-|on|tf|wf|wg|wt)|nok(6|i)|nzph|o2im|op(ti|wv)|oran|owg1|p800|pan(a|d|t)|pdxg|pg(13|\-([1-8]|c))|phil|pire|pl(ay|uc)|pn\-2|po(ck|rt|se)|prox|psio|pt\-g|qa\-a|qc(07|12|21|32|60|\-[2-7]|i\-)|qtek|r380|r600|raks|rim9|ro(ve|zo)|s55\/|sa(ge|ma|mm|ms|ny|va)|sc(01|h\-|oo|p\-)|sdk\/|se(c(\-|0|1)|47|mc|nd|ri)|sgh\-|shar|sie(\-|m)|sk\-0|sl(45|id)|sm(al|ar|b3|it|t5)|so(ft|ny)|sp(01|h\-|v\-|v )|sy(01|mb)|t2(18|50)|t6(00|10|18)|ta(gt|lk)|tcl\-|tdg\-|tel(i|m)|tim\-|t\-mo|to(pl|sh)|ts(70|m\-|m3|m5)|tx\-9|up(\.b|g1|si)|utst|v400|v750|veri|vi(rg|te)|vk(40|5[0-3]|\-v)|vm40|voda|vulc|vx(52|53|60|61|70|80|81|83|85|98)|w3c(\-| )|webc|whit|wi(g |nc|nw)|wmlb|wonu|x700|yas\-|your|zeto|zte\-/i - .test(a.substr(0, 4)); - } - return false; - } - function isBrowser() { - return (typeof window !== 'undefined' && window.document != null) || - //@ts-ignore - (typeof WorkerGlobalScope !== 'undefined'); - } - - var device_util = /*#__PURE__*/Object.freeze({ - __proto__: null, - isMobile: isMobile, - isBrowser: isBrowser - }); - - /** - * @license - * Copyright 2019 Google LLC. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - const ENV = env(); - /** - * This file contains environment-related flag registrations. - */ - /** Whether to enable debug mode. */ - ENV.registerFlag('DEBUG', () => false, debugValue => { - if (debugValue) { - console.warn('Debugging mode is ON. The output of every math call will ' + - 'be downloaded to CPU and checked for NaNs. ' + - 'This significantly impacts performance.'); - } - }); - /** Whether we are in a browser (as versus, say, node.js) environment. */ - ENV.registerFlag('IS_BROWSER', () => isBrowser()); - /** Whether we are in a browser (as versus, say, node.js) environment. */ - ENV.registerFlag('IS_NODE', () => (typeof process !== 'undefined') && - (typeof process.versions !== 'undefined') && - (typeof process.versions.node !== 'undefined')); - /** Whether this browser is Chrome. */ - ENV.registerFlag('IS_CHROME', () => typeof navigator !== 'undefined' && navigator != null && - navigator.userAgent != null && /Chrome/.test(navigator.userAgent) && - /Google Inc/.test(navigator.vendor)); - /** - * True when the environment is "production" where we disable safety checks - * to gain performance. - */ - ENV.registerFlag('PROD', () => false); - /** - * Whether to do sanity checks when inferring a shape from user-provided - * values, used when creating a new tensor. - */ - ENV.registerFlag('TENSORLIKE_CHECK_SHAPE_CONSISTENCY', () => ENV.getBool('DEBUG')); - /** Whether deprecation warnings are enabled. */ - ENV.registerFlag('DEPRECATION_WARNINGS_ENABLED', () => true); - /** True if running unit tests. */ - ENV.registerFlag('IS_TEST', () => false); - - /** - * @license - * Copyright 2018 Google LLC. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - function inferShape(val, dtype) { - let firstElem = val; - if (isTypedArray(val)) { - return dtype === 'string' ? [] : [val.length]; - } - if (!Array.isArray(val)) { - return []; // Scalar. - } - const shape = []; - while (Array.isArray(firstElem) || - isTypedArray(firstElem) && dtype !== 'string') { - shape.push(firstElem.length); - firstElem = firstElem[0]; - } - if (Array.isArray(val) && - env().getBool('TENSORLIKE_CHECK_SHAPE_CONSISTENCY')) { - deepAssertShapeConsistency(val, shape, []); - } - return shape; - } - function deepAssertShapeConsistency(val, shape, indices) { - indices = indices || []; - if (!(Array.isArray(val)) && !isTypedArray(val)) { - assert(shape.length === 0, () => `Element arr[${indices.join('][')}] is a primitive, ` + - `but should be an array/TypedArray of ${shape[0]} elements`); - return; - } - assert(shape.length > 0, () => `Element arr[${indices.join('][')}] should be a primitive, ` + - `but is an array of ${val.length} elements`); - assert(val.length === shape[0], () => `Element arr[${indices.join('][')}] should have ${shape[0]} ` + - `elements, but has ${val.length} elements`); - const subShape = shape.slice(1); - for (let i = 0; i < val.length; ++i) { - deepAssertShapeConsistency(val[i], subShape, indices.concat(i)); - } - } - function assertDtype(expectedDtype, actualDType, argName, functionName) { - if (expectedDtype == null) { - return; - } - if (expectedDtype !== 'numeric' && expectedDtype !== actualDType || - expectedDtype === 'numeric' && actualDType === 'string') { - throw new Error(`Argument '${argName}' passed to '${functionName}' must ` + - `be ${expectedDtype} tensor, but got ${actualDType} tensor`); - } - } - function convertToTensor(x, argName, functionName, parseAsDtype = 'numeric') { - if (x instanceof Tensor) { - assertDtype(parseAsDtype, x.dtype, argName, functionName); - return x; - } - let inferredDtype = inferDtype(x); - // If the user expects a bool/int/float, use that info to update the - // inferredDtype when it is not a string. - if (inferredDtype !== 'string' && - ['bool', 'int32', 'float32'].indexOf(parseAsDtype) >= 0) { - inferredDtype = parseAsDtype; - } - assertDtype(parseAsDtype, inferredDtype, argName, functionName); - if ((x == null) || - (!isTypedArray(x) && !Array.isArray(x) && typeof x !== 'number' && - typeof x !== 'boolean' && typeof x !== 'string')) { - const type = x == null ? 'null' : x.constructor.name; - throw new Error(`Argument '${argName}' passed to '${functionName}' must be a ` + - `Tensor or TensorLike, but got '${type}'`); - } - const inferredShape = inferShape(x, inferredDtype); - if (!isTypedArray(x) && !Array.isArray(x)) { - x = [x]; - } - const skipTypedArray = true; - const values = inferredDtype !== 'string' ? - toTypedArray(x, inferredDtype) : - flatten(x, [], skipTypedArray); - return ENGINE.makeTensor(values, inferredShape, inferredDtype); - } - function convertToTensorArray(arg, argName, functionName, parseAsDtype = 'numeric') { - if (!Array.isArray(arg)) { - throw new Error(`Argument ${argName} passed to ${functionName} must be a ` + - '`Tensor[]` or `TensorLike[]`'); - } - const tensors = arg; - return tensors.map((t, i) => convertToTensor(t, `${argName}[${i}]`, functionName), parseAsDtype); - } - - /** - * @license - * Copyright 2018 Google LLC. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - const OP_SCOPE_SUFFIX = '__op'; - /** - * Used for wrapping functions that perform math operations on - * Tensors. The function will be wrapped in a named scope that cleans all - * memory usage after the function is done. - */ - function op(f) { - const keys = Object.keys(f); - if (keys.length !== 1) { - throw new Error(`Please provide an object with a single key ` + - `(operation name) mapping to a function. Got an object with ` + - `${keys.length} keys.`); - } - let opName = keys[0]; - const fn = f[opName]; - // Strip the underscore from the end of the function name. - if (opName.endsWith('_')) { - opName = opName.substring(0, opName.length - 1); - } - // add an __op suffix to distinguish ops from kernels in tf.profile - opName = opName + OP_SCOPE_SUFFIX; - // tslint:disable-next-line:no-any - const f2 = (...args) => { - ENGINE.startScope(opName); - try { - const result = fn(...args); - if (result instanceof Promise) { - console.error('Cannot return a Promise inside of tidy.'); - } - ENGINE.endScope(result); - return result; - } - catch (ex) { - ENGINE.endScope(null); - throw ex; - } - }; - Object.defineProperty(f2, 'name', { value: opName, configurable: true }); - // tslint:disable-next-line:no-any - return f2; - } - - /** - * @license - * Copyright 2020 Google LLC. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - /** - * Converts two real numbers to a complex number. - * - * Given a tensor `real` representing the real part of a complex number, and a - * tensor `imag` representing the imaginary part of a complex number, this - * operation returns complex numbers elementwise of the form [r0, i0, r1, i1], - * where r represents the real part and i represents the imag part. - * - * The input tensors real and imag must have the same shape. - * - * ```js - * const real = tf.tensor1d([2.25, 3.25]); - * const imag = tf.tensor1d([4.75, 5.75]); - * const complex = tf.complex(real, imag); - * - * complex.print(); - * ``` - * - * @doc {heading: 'Tensors', subheading: 'Creation'} - */ - function complex_(real, imag) { - const $real = convertToTensor(real, 'real', 'complex'); - const $imag = convertToTensor(imag, 'imag', 'complex'); - assertShapesMatch($real.shape, $imag.shape, `real and imag shapes, ${$real.shape} and ${$imag.shape}, ` + - `must match in call to tf.complex().`); - const forward = (backend) => { - return backend.complex($real, $imag); - }; - const inputs = { real: $real, imag: $imag }; - return ENGINE.runKernelFunc(forward, inputs, null /* gradient */, Complex); - } - const complex = op({ complex_ }); - - /** - * @license - * Copyright 2018 Google LLC. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - /** This is shared code across all tensor creation methods. */ - function makeTensor(values, shape, inferredShape, dtype) { - if (dtype == null) { - dtype = inferDtype(values); - } - if (dtype === 'complex64') { - throw new Error(`Cannot construct a complex64 tensor directly. ` + - `Please use tf.complex(real, imag).`); - } - if (!isTypedArray(values) && !Array.isArray(values) && - typeof values !== 'number' && typeof values !== 'boolean' && - typeof values !== 'string') { - throw new Error('values passed to tensor(values) must be a number/boolean/string or ' + - 'an array of numbers/booleans/strings, or a TypedArray'); - } - if (shape != null) { - assertNonNegativeIntegerDimensions(shape); - const providedSize = sizeFromShape(shape); - const inferredSize = sizeFromShape(inferredShape); - assert(providedSize === inferredSize, () => `Based on the provided shape, [${shape}], the tensor should have ` + - `${providedSize} values but has ${inferredSize}`); - for (let i = 0; i < inferredShape.length; ++i) { - const inferred = inferredShape[i]; - const flatDimsDontMatch = i === inferredShape.length - 1 ? - inferred !== sizeFromShape(shape.slice(i)) : - true; - assert(inferredShape[i] === shape[i] || !flatDimsDontMatch, () => `Error creating a new Tensor. Inferred shape ` + - `(${inferredShape}) does not match the provided ` + - `shape (${shape}). `); - } - } - if (!isTypedArray(values) && !Array.isArray(values)) { - values = [values]; - } - shape = shape || inferredShape; - values = dtype !== 'string' ? - toTypedArray(values, dtype) : - flatten(values, [], true); - return ENGINE.makeTensor(values, shape, dtype); - } - - /** - * @license - * Copyright 2018 Google LLC. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - /** - * Creates a `tf.Tensor` with the provided values, shape and dtype. - * - * ```js - * // Pass an array of values to create a vector. - * tf.tensor([1, 2, 3, 4]).print(); - * ``` - * - * ```js - * // Pass a nested array of values to make a matrix or a higher - * // dimensional tensor. - * tf.tensor([[1, 2], [3, 4]]).print(); - * ``` - * - * ```js - * // Pass a flat array and specify a shape yourself. - * tf.tensor([1, 2, 3, 4], [2, 2]).print(); - * ``` - * - * @param values The values of the tensor. Can be nested array of numbers, - * or a flat array, or a `TypedArray`. If the values are strings, - * they will be encoded as utf-8 and kept as `Uint8Array[]`. - * @param shape The shape of the tensor. Optional. If not provided, - * it is inferred from `values`. - * @param dtype The data type. - * - * @doc {heading: 'Tensors', subheading: 'Creation'} - */ - function tensor(values, shape, dtype) { - const inferredShape = inferShape(values, dtype); - return makeTensor(values, shape, inferredShape, dtype); - } - - /** - * @license - * Copyright 2018 Google LLC. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - /* Type definitions for exporting and importing of models. */ - /** - * A map from Tensor dtype to number of bytes per element of the Tensor. - */ - const DTYPE_VALUE_SIZE_MAP = { - 'float32': 4, - 'float16': 2, - 'int32': 4, - 'uint16': 2, - 'uint8': 1, - 'bool': 1, - 'complex64': 8 - }; - - /** - * @license - * Copyright 2018 Google LLC. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - /** Number of bytes reserved for the length of the string. (32bit integer). */ - const NUM_BYTES_STRING_LENGTH = 4; - /** - * Encode a map from names to weight values as an ArrayBuffer, along with an - * `Array` of `WeightsManifestEntry` as specification of the encoded weights. - * - * This function does not perform sharding. - * - * This function is the reverse of `decodeWeights`. - * - * @param tensors A map ("dict") from names to tensors. - * @param group Group to which the weights belong (optional). - * @returns A `Promise` of - * - A flat `ArrayBuffer` with all the binary values of the `Tensor`s - * concatenated. - * - An `Array` of `WeightManifestEntry`s, carrying information including - * tensor names, `dtype`s and shapes. - * @throws Error: on unsupported tensor `dtype`. - */ - async function encodeWeights(tensors, group) { - // TODO(adarob, cais): Support quantization. - const specs = []; - const dataPromises = []; - const names = Array.isArray(tensors) ? - tensors.map(tensor => tensor.name) : - Object.keys(tensors); - for (let i = 0; i < names.length; ++i) { - const name = names[i]; - const t = Array.isArray(tensors) ? tensors[i].tensor : tensors[name]; - if (t.dtype !== 'float32' && t.dtype !== 'int32' && t.dtype !== 'bool' && - t.dtype !== 'string' && t.dtype !== 'complex64') { - throw new Error(`Unsupported dtype in weight '${name}': ${t.dtype}`); - } - const spec = { name, shape: t.shape, dtype: t.dtype }; - if (t.dtype === 'string') { - const utf8bytes = new Promise(async (resolve) => { - const vals = await t.bytes(); - const totalNumBytes = vals.reduce((p, c) => p + c.length, 0) + - NUM_BYTES_STRING_LENGTH * vals.length; - const bytes = new Uint8Array(totalNumBytes); - let offset = 0; - for (let i = 0; i < vals.length; i++) { - const val = vals[i]; - const bytesOfLength = new Uint8Array(new Uint32Array([val.length]).buffer); - bytes.set(bytesOfLength, offset); - offset += NUM_BYTES_STRING_LENGTH; - bytes.set(val, offset); - offset += val.length; - } - resolve(bytes); - }); - dataPromises.push(utf8bytes); - } - else { - dataPromises.push(t.data()); - } - if (group != null) { - spec.group = group; - } - specs.push(spec); - } - const tensorValues = await Promise.all(dataPromises); - return { data: concatenateTypedArrays(tensorValues), specs }; - } - /** - * Decode flat ArrayBuffer as weights. - * - * This function does not handle sharding. - * - * This function is the reverse of `encodeWeights`. - * - * @param buffer A flat ArrayBuffer carrying the binary values of the tensors - * concatenated in the order specified in `specs`. - * @param specs Specifications of the names, dtypes and shapes of the tensors - * whose value are encoded by `buffer`. - * @return A map from tensor name to tensor value, with the names corresponding - * to names in `specs`. - * @throws Error, if any of the tensors has unsupported dtype. - */ - function decodeWeights(buffer, specs) { - // TODO(adarob, cais): Support quantization. - const out = {}; - let float16Decode; - let offset = 0; - for (const spec of specs) { - const name = spec.name; - const dtype = spec.dtype; - const shape = spec.shape; - const size = sizeFromShape(shape); - let values; - if ('quantization' in spec) { - const quantization = spec.quantization; - if (quantization.dtype === 'uint8' || quantization.dtype === 'uint16') { - if (!('min' in quantization && 'scale' in quantization)) { - throw new Error(`Weight ${spec.name} with quantization ${quantization.dtype} ` + - `doesn't have corresponding metadata min and scale.`); - } - } - else if (quantization.dtype === 'float16') { - if (dtype !== 'float32') { - throw new Error(`Weight ${spec.name} is quantized with ${quantization.dtype} ` + - `which only supports weights of type float32 not ${dtype}.`); - } - } - else { - throw new Error(`Weight ${spec.name} has unknown ` + - `quantization dtype ${quantization.dtype}. ` + - `Supported quantization dtypes are: ` + - `'uint8', 'uint16', and 'float16'.`); - } - const quantizationSizeFactor = DTYPE_VALUE_SIZE_MAP[quantization.dtype]; - const byteBuffer = buffer.slice(offset, offset + size * quantizationSizeFactor); - const quantizedArray = (quantization.dtype === 'uint8') ? - new Uint8Array(byteBuffer) : - new Uint16Array(byteBuffer); - if (dtype === 'float32') { - if (quantization.dtype === 'uint8' || quantization.dtype === 'uint16') { - values = new Float32Array(quantizedArray.length); - for (let i = 0; i < quantizedArray.length; i++) { - const v = quantizedArray[i]; - values[i] = v * quantization.scale + quantization.min; - } - } - else if (quantization.dtype === 'float16') { - if (float16Decode === undefined) { - float16Decode = getFloat16Decoder(); - } - values = float16Decode(quantizedArray); - } - else { - throw new Error(`Unsupported quantization type ${quantization.dtype} ` + - `for weight type float32.`); - } - } - else if (dtype === 'int32') { - if (quantization.dtype !== 'uint8' && quantization.dtype !== 'uint16') { - throw new Error(`Unsupported quantization type ${quantization.dtype} ` + - `for weight type int32.`); - } - values = new Int32Array(quantizedArray.length); - for (let i = 0; i < quantizedArray.length; i++) { - const v = quantizedArray[i]; - values[i] = Math.round(v * quantization.scale + quantization.min); - } - } - else { - throw new Error(`Unsupported dtype in weight '${name}': ${dtype}`); - } - offset += size * quantizationSizeFactor; - } - else if (dtype === 'string') { - const size = sizeFromShape(spec.shape); - values = []; - for (let i = 0; i < size; i++) { - const byteLength = new Uint32Array(buffer.slice(offset, offset + NUM_BYTES_STRING_LENGTH))[0]; - offset += NUM_BYTES_STRING_LENGTH; - const bytes = new Uint8Array(buffer.slice(offset, offset + byteLength)); - values.push(bytes); - offset += byteLength; - } - } - else { - const dtypeFactor = DTYPE_VALUE_SIZE_MAP[dtype]; - const byteBuffer = buffer.slice(offset, offset + size * dtypeFactor); - if (dtype === 'float32') { - values = new Float32Array(byteBuffer); - } - else if (dtype === 'int32') { - values = new Int32Array(byteBuffer); - } - else if (dtype === 'bool') { - values = new Uint8Array(byteBuffer); - } - else if (dtype === 'complex64') { - values = new Float32Array(byteBuffer); - const real = new Float32Array(values.length / 2); - const image = new Float32Array(values.length / 2); - for (let i = 0; i < real.length; i++) { - real[i] = values[i * 2]; - image[i] = values[i * 2 + 1]; - } - const realTensor = tensor(real, shape, 'float32'); - const imageTensor = tensor(image, shape, 'float32'); - out[name] = complex(realTensor, imageTensor); - realTensor.dispose(); - imageTensor.dispose(); - } - else { - throw new Error(`Unsupported dtype in weight '${name}': ${dtype}`); - } - offset += size * dtypeFactor; - } - if (dtype !== 'complex64') { - out[name] = tensor(values, shape, dtype); - } - } - return out; - } - /** - * Concatenate TypedArrays into an ArrayBuffer. - */ - function concatenateTypedArrays(xs) { - // TODO(adarob, cais): Support quantization. - if (xs === null) { - throw new Error(`Invalid input value: ${JSON.stringify(xs)}`); - } - let totalByteLength = 0; - // `normalizedXs` is here for this reason: a `TypedArray`'s `buffer' - // can have a different byte length from that of the `TypedArray` itself, - // for example, when the `TypedArray` is created from an offset in an - // `ArrayBuffer`. `normliazedXs` holds `TypedArray`s whose `buffer`s match - // the `TypedArray` in byte length. If an element of `xs` does not show - // this property, a new `TypedArray` that satisfy this property will be - // constructed and pushed into `normalizedXs`. - const normalizedXs = []; - xs.forEach((x) => { - totalByteLength += x.byteLength; - // tslint:disable:no-any - normalizedXs.push(x.byteLength === x.buffer.byteLength ? x : - new x.constructor(x)); - if (!(x instanceof Float32Array || x instanceof Int32Array || - x instanceof Uint8Array)) { - throw new Error(`Unsupported TypedArray subtype: ${x.constructor.name}`); - } - // tslint:enable:no-any - }); - const y = new Uint8Array(totalByteLength); - let offset = 0; - normalizedXs.forEach((x) => { - y.set(new Uint8Array(x.buffer), offset); - offset += x.byteLength; - }); - return y.buffer; - } - // Use Buffer on Node.js instead of Blob/atob/btoa - const useNodeBuffer = typeof Buffer !== 'undefined' && - (typeof Blob === 'undefined' || typeof atob === 'undefined' || - typeof btoa === 'undefined'); - /** - * Calculate the byte length of a JavaScript string. - * - * Note that a JavaScript string can contain wide characters, therefore the - * length of the string is not necessarily equal to the byte length. - * - * @param str Input string. - * @returns Byte length. - */ - function stringByteLength(str) { - if (useNodeBuffer) { - return Buffer.byteLength(str); - } - return new Blob([str]).size; - } - /** - * Encode an ArrayBuffer as a base64 encoded string. - * - * @param buffer `ArrayBuffer` to be converted. - * @returns A string that base64-encodes `buffer`. - */ - function arrayBufferToBase64String(buffer) { - if (useNodeBuffer) { - return Buffer.from(buffer).toString('base64'); - } - const buf = new Uint8Array(buffer); - let s = ''; - for (let i = 0, l = buf.length; i < l; i++) { - s += String.fromCharCode(buf[i]); - } - return btoa(s); - } - /** - * Decode a base64 string as an ArrayBuffer. - * - * @param str Base64 string. - * @returns Decoded `ArrayBuffer`. - */ - function base64StringToArrayBuffer(str) { - if (useNodeBuffer) { - const buf = Buffer.from(str, 'base64'); - return buf.buffer.slice(buf.byteOffset, buf.byteOffset + buf.byteLength); - } - const s = atob(str); - const buffer = new Uint8Array(s.length); - for (let i = 0; i < s.length; ++i) { - buffer.set([s.charCodeAt(i)], i); - } - return buffer.buffer; - } - /** - * Concatenate a number of ArrayBuffers into one. - * - * @param buffers A number of array buffers to concatenate. - * @returns Result of concatenating `buffers` in order. - */ - function concatenateArrayBuffers(buffers) { - if (buffers.length === 1) { - return buffers[0]; - } - let totalByteLength = 0; - buffers.forEach((buffer) => { - totalByteLength += buffer.byteLength; - }); - const temp = new Uint8Array(totalByteLength); - let offset = 0; - buffers.forEach((buffer) => { - temp.set(new Uint8Array(buffer), offset); - offset += buffer.byteLength; - }); - return temp.buffer; - } - /** - * Get the basename of a path. - * - * Behaves in a way analogous to Linux's basename command. - * - * @param path - */ - function basename(path) { - const SEPARATOR = '/'; - path = path.trim(); - while (path.endsWith(SEPARATOR)) { - path = path.slice(0, path.length - 1); - } - const items = path.split(SEPARATOR); - return items[items.length - 1]; - } - /** - * Populate ModelArtifactsInfo fields for a model with JSON topology. - * @param modelArtifacts - * @returns A ModelArtifactsInfo object. - */ - function getModelArtifactsInfoForJSON(modelArtifacts) { - if (modelArtifacts.modelTopology instanceof ArrayBuffer) { - throw new Error('Expected JSON model topology, received ArrayBuffer.'); - } - return { - dateSaved: new Date(), - modelTopologyType: 'JSON', - modelTopologyBytes: modelArtifacts.modelTopology == null ? - 0 : - stringByteLength(JSON.stringify(modelArtifacts.modelTopology)), - weightSpecsBytes: modelArtifacts.weightSpecs == null ? - 0 : - stringByteLength(JSON.stringify(modelArtifacts.weightSpecs)), - weightDataBytes: modelArtifacts.weightData == null ? - 0 : - modelArtifacts.weightData.byteLength, - }; - } - /** - * Computes mantisa table for casting Float16 to Float32 - * See http://www.fox-toolkit.org/ftp/fasthalffloatconversion.pdf - * - * @returns Uint32Array, 2048 mantissa lookup values. - */ - function computeFloat16MantisaTable() { - const convertMantissa = (i) => { - let m = i << 13; - let e = 0; - while ((m & 0x00800000) === 0) { - e -= 0x00800000; - m <<= 1; - } - m &= ~0x00800000; - e += 0x38800000; - return m | e; - }; - const mantisaTable = new Uint32Array(2048); - mantisaTable[0] = 0; - for (let i = 1; i < 1024; i++) { - mantisaTable[i] = convertMantissa(i); - } - for (let i = 1024; i < 2048; i++) { - mantisaTable[i] = 0x38000000 + ((i - 1024) << 13); - } - return mantisaTable; - } - /** - * Computes exponent table for casting Float16 to Float32 - * See http://www.fox-toolkit.org/ftp/fasthalffloatconversion.pdf - * - * @returns Uint32Array, 64 exponent lookup values. - */ - function computeFloat16ExponentTable() { - const exponentTable = new Uint32Array(64); - exponentTable[0] = 0; - exponentTable[31] = 0x47800000; - exponentTable[32] = 0x80000000; - exponentTable[63] = 0xc7800000; - for (let i = 1; i < 31; i++) { - exponentTable[i] = i << 23; - } - for (let i = 33; i < 63; i++) { - exponentTable[i] = 0x80000000 + ((i - 32) << 23); - } - return exponentTable; - } - /** - * Computes offset table for casting Float16 to Float32 - * See http://www.fox-toolkit.org/ftp/fasthalffloatconversion.pdf - * - * @returns Uint32Array, 6d offset values. - */ - function computeFloat16OffsetTable() { - const offsetTable = new Uint32Array(64); - for (let i = 0; i < 64; i++) { - offsetTable[i] = 1024; - } - offsetTable[0] = offsetTable[32] = 0; - return offsetTable; - } - /** - * Retrieve a Float16 decoder which will decode a ByteArray of Float16 values - * to a Float32Array. - * - * @returns Function (buffer: Uint16Array) => Float32Array which decodes - * the Uint16Array of Float16 bytes to a Float32Array. - */ - function getFloat16Decoder() { - // Algorithm is based off of - // http://www.fox-toolkit.org/ftp/fasthalffloatconversion.pdf - // Cache lookup tables - const mantisaTable = computeFloat16MantisaTable(); - const exponentTable = computeFloat16ExponentTable(); - const offsetTable = computeFloat16OffsetTable(); - return (quantizedArray) => { - const buffer = new ArrayBuffer(4 * quantizedArray.length); - const bufferUint32View = new Uint32Array(buffer); - for (let index = 0; index < quantizedArray.length; index++) { - const float16Bits = quantizedArray[index]; - const float32Bits = mantisaTable[offsetTable[float16Bits >> 10] + (float16Bits & 0x3ff)] + - exponentTable[float16Bits >> 10]; - bufferUint32View[index] = float32Bits; - } - return new Float32Array(buffer); - }; - } - - /** - * @license - * Copyright 2018 Google LLC. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - class IORouterRegistry { - constructor() { - this.saveRouters = []; - this.loadRouters = []; - } - static getInstance() { - if (IORouterRegistry.instance == null) { - IORouterRegistry.instance = new IORouterRegistry(); - } - return IORouterRegistry.instance; - } - /** - * Register a save-handler router. - * - * @param saveRouter A function that maps a URL-like string onto an instance - * of `IOHandler` with the `save` method defined or `null`. - */ - static registerSaveRouter(saveRouter) { - IORouterRegistry.getInstance().saveRouters.push(saveRouter); - } - /** - * Register a load-handler router. - * - * @param loadRouter A function that maps a URL-like string onto an instance - * of `IOHandler` with the `load` method defined or `null`. - */ - static registerLoadRouter(loadRouter) { - IORouterRegistry.getInstance().loadRouters.push(loadRouter); - } - /** - * Look up IOHandler for saving, given a URL-like string. - * - * @param url - * @returns If only one match is found, an instance of IOHandler with the - * `save` method defined. If no match is found, `null`. - * @throws Error, if more than one match is found. - */ - static getSaveHandlers(url) { - return IORouterRegistry.getHandlers(url, 'save'); - } - /** - * Look up IOHandler for loading, given a URL-like string. - * - * @param url - * @param loadOptions Optional, custom load options. - * @returns All valid handlers for `url`, given the currently registered - * handler routers. - */ - static getLoadHandlers(url, loadOptions) { - return IORouterRegistry.getHandlers(url, 'load', loadOptions); - } - static getHandlers(url, handlerType, loadOptions) { - const validHandlers = []; - const routers = handlerType === 'load' ? - IORouterRegistry.getInstance().loadRouters : - IORouterRegistry.getInstance().saveRouters; - routers.forEach(router => { - const handler = router(url, loadOptions); - if (handler !== null) { - validHandlers.push(handler); - } - }); - return validHandlers; - } - } - const registerSaveRouter = (loudRouter) => IORouterRegistry.registerSaveRouter(loudRouter); - const registerLoadRouter = (loudRouter) => IORouterRegistry.registerLoadRouter(loudRouter); - const getSaveHandlers = (url) => IORouterRegistry.getSaveHandlers(url); - const getLoadHandlers = (url, loadOptions) => IORouterRegistry.getLoadHandlers(url, loadOptions); - - /** - * @license - * Copyright 2018 Google LLC. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - const DATABASE_NAME = 'tensorflowjs'; - const DATABASE_VERSION = 1; - // Model data and ModelArtifactsInfo (metadata) are stored in two separate - // stores for efficient access of the list of stored models and their metadata. - // 1. The object store for model data: topology, weights and weight manifests. - const MODEL_STORE_NAME = 'models_store'; - // 2. The object store for ModelArtifactsInfo, including meta-information such - // as the type of topology (JSON vs binary), byte size of the topology, byte - // size of the weights, etc. - const INFO_STORE_NAME = 'model_info_store'; - /** - * Delete the entire database for tensorflow.js, including the models store. - */ - async function deleteDatabase() { - const idbFactory = getIndexedDBFactory(); - return new Promise((resolve, reject) => { - const deleteRequest = idbFactory.deleteDatabase(DATABASE_NAME); - deleteRequest.onsuccess = () => resolve(); - deleteRequest.onerror = error => reject(error); - }); - } - function getIndexedDBFactory() { - if (!env().getBool('IS_BROWSER')) { - // TODO(cais): Add more info about what IOHandler subtypes are available. - // Maybe point to a doc page on the web and/or automatically determine - // the available IOHandlers and print them in the error message. - throw new Error('Failed to obtain IndexedDB factory because the current environment' + - 'is not a web browser.'); - } - // tslint:disable-next-line:no-any - const theWindow = typeof window === 'undefined' ? self : window; - const factory = theWindow.indexedDB || theWindow.mozIndexedDB || - theWindow.webkitIndexedDB || theWindow.msIndexedDB || - theWindow.shimIndexedDB; - if (factory == null) { - throw new Error('The current browser does not appear to support IndexedDB.'); - } - return factory; - } - function setUpDatabase(openRequest) { - const db = openRequest.result; - db.createObjectStore(MODEL_STORE_NAME, { keyPath: 'modelPath' }); - db.createObjectStore(INFO_STORE_NAME, { keyPath: 'modelPath' }); - } - /** - * IOHandler subclass: Browser IndexedDB. - * - * See the doc string of `browserIndexedDB` for more details. - */ - class BrowserIndexedDB { - constructor(modelPath) { - this.indexedDB = getIndexedDBFactory(); - if (modelPath == null || !modelPath) { - throw new Error('For IndexedDB, modelPath must not be null, undefined or empty.'); - } - this.modelPath = modelPath; - } - async save(modelArtifacts) { - // TODO(cais): Support saving GraphDef models. - if (modelArtifacts.modelTopology instanceof ArrayBuffer) { - throw new Error('BrowserLocalStorage.save() does not support saving model topology ' + - 'in binary formats yet.'); - } - return this.databaseAction(this.modelPath, modelArtifacts); - } - async load() { - return this.databaseAction(this.modelPath); - } - /** - * Perform database action to put model artifacts into or read model artifacts - * from IndexedDB object store. - * - * Whether the action is put or get depends on whether `modelArtifacts` is - * specified. If it is specified, the action will be put; otherwise the action - * will be get. - * - * @param modelPath A unique string path for the model. - * @param modelArtifacts If specified, it will be the model artifacts to be - * stored in IndexedDB. - * @returns A `Promise` of `SaveResult`, if the action is put, or a `Promise` - * of `ModelArtifacts`, if the action is get. - */ - databaseAction(modelPath, modelArtifacts) { - return new Promise((resolve, reject) => { - const openRequest = this.indexedDB.open(DATABASE_NAME, DATABASE_VERSION); - openRequest.onupgradeneeded = () => setUpDatabase(openRequest); - openRequest.onsuccess = () => { - const db = openRequest.result; - if (modelArtifacts == null) { - // Read model out from object store. - const modelTx = db.transaction(MODEL_STORE_NAME, 'readonly'); - const modelStore = modelTx.objectStore(MODEL_STORE_NAME); - const getRequest = modelStore.get(this.modelPath); - getRequest.onsuccess = () => { - if (getRequest.result == null) { - db.close(); - return reject(new Error(`Cannot find model with path '${this.modelPath}' ` + - `in IndexedDB.`)); - } - else { - resolve(getRequest.result.modelArtifacts); - } - }; - getRequest.onerror = error => { - db.close(); - return reject(getRequest.error); - }; - modelTx.oncomplete = () => db.close(); - } - else { - // Put model into object store. - const modelArtifactsInfo = getModelArtifactsInfoForJSON(modelArtifacts); - // First, put ModelArtifactsInfo into info store. - const infoTx = db.transaction(INFO_STORE_NAME, 'readwrite'); - let infoStore = infoTx.objectStore(INFO_STORE_NAME); - const putInfoRequest = infoStore.put({ modelPath: this.modelPath, modelArtifactsInfo }); - let modelTx; - putInfoRequest.onsuccess = () => { - // Second, put model data into model store. - modelTx = db.transaction(MODEL_STORE_NAME, 'readwrite'); - const modelStore = modelTx.objectStore(MODEL_STORE_NAME); - const putModelRequest = modelStore.put({ - modelPath: this.modelPath, - modelArtifacts, - modelArtifactsInfo - }); - putModelRequest.onsuccess = () => resolve({ modelArtifactsInfo }); - putModelRequest.onerror = error => { - // If the put-model request fails, roll back the info entry as - // well. - infoStore = infoTx.objectStore(INFO_STORE_NAME); - const deleteInfoRequest = infoStore.delete(this.modelPath); - deleteInfoRequest.onsuccess = () => { - db.close(); - return reject(putModelRequest.error); - }; - deleteInfoRequest.onerror = error => { - db.close(); - return reject(putModelRequest.error); - }; - }; - }; - putInfoRequest.onerror = error => { - db.close(); - return reject(putInfoRequest.error); - }; - infoTx.oncomplete = () => { - if (modelTx == null) { - db.close(); - } - else { - modelTx.oncomplete = () => db.close(); - } - }; - } - }; - openRequest.onerror = error => reject(openRequest.error); - }); - } - } - BrowserIndexedDB.URL_SCHEME = 'indexeddb://'; - const indexedDBRouter = (url) => { - if (!env().getBool('IS_BROWSER')) { - return null; - } - else { - if (!Array.isArray(url) && url.startsWith(BrowserIndexedDB.URL_SCHEME)) { - return browserIndexedDB(url.slice(BrowserIndexedDB.URL_SCHEME.length)); - } - else { - return null; - } - } - }; - IORouterRegistry.registerSaveRouter(indexedDBRouter); - IORouterRegistry.registerLoadRouter(indexedDBRouter); - /** - * Creates a browser IndexedDB IOHandler for saving and loading models. - * - * ```js - * const model = tf.sequential(); - * model.add( - * tf.layers.dense({units: 1, inputShape: [100], activation: 'sigmoid'})); - * - * const saveResult = await model.save('indexeddb://MyModel')); - * console.log(saveResult); - * ``` - * - * @param modelPath A unique identifier for the model to be saved. Must be a - * non-empty string. - * @returns An instance of `BrowserIndexedDB` (sublcass of `IOHandler`), - * which can be used with, e.g., `tf.Model.save`. - */ - function browserIndexedDB(modelPath) { - return new BrowserIndexedDB(modelPath); - } - function maybeStripScheme(key) { - return key.startsWith(BrowserIndexedDB.URL_SCHEME) ? - key.slice(BrowserIndexedDB.URL_SCHEME.length) : - key; - } - class BrowserIndexedDBManager { - constructor() { - this.indexedDB = getIndexedDBFactory(); - } - async listModels() { - return new Promise((resolve, reject) => { - const openRequest = this.indexedDB.open(DATABASE_NAME, DATABASE_VERSION); - openRequest.onupgradeneeded = () => setUpDatabase(openRequest); - openRequest.onsuccess = () => { - const db = openRequest.result; - const tx = db.transaction(INFO_STORE_NAME, 'readonly'); - const store = tx.objectStore(INFO_STORE_NAME); - // tslint:disable:max-line-length - // Need to cast `store` as `any` here because TypeScript's DOM - // library does not have the `getAll()` method even though the - // method is supported in the latest version of most mainstream - // browsers: - // https://developer.mozilla.org/en-US/docs/Web/API/IDBObjectStore/getAll - // tslint:enable:max-line-length - // tslint:disable-next-line:no-any - const getAllInfoRequest = store.getAll(); - getAllInfoRequest.onsuccess = () => { - const out = {}; - for (const item of getAllInfoRequest.result) { - out[item.modelPath] = item.modelArtifactsInfo; - } - resolve(out); - }; - getAllInfoRequest.onerror = error => { - db.close(); - return reject(getAllInfoRequest.error); - }; - tx.oncomplete = () => db.close(); - }; - openRequest.onerror = error => reject(openRequest.error); - }); - } - async removeModel(path) { - path = maybeStripScheme(path); - return new Promise((resolve, reject) => { - const openRequest = this.indexedDB.open(DATABASE_NAME, DATABASE_VERSION); - openRequest.onupgradeneeded = () => setUpDatabase(openRequest); - openRequest.onsuccess = () => { - const db = openRequest.result; - const infoTx = db.transaction(INFO_STORE_NAME, 'readwrite'); - const infoStore = infoTx.objectStore(INFO_STORE_NAME); - const getInfoRequest = infoStore.get(path); - let modelTx; - getInfoRequest.onsuccess = () => { - if (getInfoRequest.result == null) { - db.close(); - return reject(new Error(`Cannot find model with path '${path}' ` + - `in IndexedDB.`)); - } - else { - // First, delete the entry in the info store. - const deleteInfoRequest = infoStore.delete(path); - const deleteModelData = () => { - // Second, delete the entry in the model store. - modelTx = db.transaction(MODEL_STORE_NAME, 'readwrite'); - const modelStore = modelTx.objectStore(MODEL_STORE_NAME); - const deleteModelRequest = modelStore.delete(path); - deleteModelRequest.onsuccess = () => resolve(getInfoRequest.result.modelArtifactsInfo); - deleteModelRequest.onerror = error => reject(getInfoRequest.error); - }; - // Proceed with deleting model data regardless of whether deletion - // of info data succeeds or not. - deleteInfoRequest.onsuccess = deleteModelData; - deleteInfoRequest.onerror = error => { - deleteModelData(); - db.close(); - return reject(getInfoRequest.error); - }; - } - }; - getInfoRequest.onerror = error => { - db.close(); - return reject(getInfoRequest.error); - }; - infoTx.oncomplete = () => { - if (modelTx == null) { - db.close(); - } - else { - modelTx.oncomplete = () => db.close(); - } - }; - }; - openRequest.onerror = error => reject(openRequest.error); - }); - } - } - - /** - * @license - * Copyright 2018 Google LLC. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - const PATH_SEPARATOR = '/'; - const PATH_PREFIX = 'tensorflowjs_models'; - const INFO_SUFFIX = 'info'; - const MODEL_TOPOLOGY_SUFFIX = 'model_topology'; - const WEIGHT_SPECS_SUFFIX = 'weight_specs'; - const WEIGHT_DATA_SUFFIX = 'weight_data'; - const MODEL_METADATA_SUFFIX = 'model_metadata'; - /** - * Purge all tensorflow.js-saved model artifacts from local storage. - * - * @returns Paths of the models purged. - */ - function purgeLocalStorageArtifacts() { - if (!env().getBool('IS_BROWSER') || typeof window === 'undefined' || - typeof window.localStorage === 'undefined') { - throw new Error('purgeLocalStorageModels() cannot proceed because local storage is ' + - 'unavailable in the current environment.'); - } - const LS = window.localStorage; - const purgedModelPaths = []; - for (let i = 0; i < LS.length; ++i) { - const key = LS.key(i); - const prefix = PATH_PREFIX + PATH_SEPARATOR; - if (key.startsWith(prefix) && key.length > prefix.length) { - LS.removeItem(key); - const modelName = getModelPathFromKey(key); - if (purgedModelPaths.indexOf(modelName) === -1) { - purgedModelPaths.push(modelName); - } - } - } - return purgedModelPaths; - } - function getModelKeys(path) { - return { - info: [PATH_PREFIX, path, INFO_SUFFIX].join(PATH_SEPARATOR), - topology: [PATH_PREFIX, path, MODEL_TOPOLOGY_SUFFIX].join(PATH_SEPARATOR), - weightSpecs: [PATH_PREFIX, path, WEIGHT_SPECS_SUFFIX].join(PATH_SEPARATOR), - weightData: [PATH_PREFIX, path, WEIGHT_DATA_SUFFIX].join(PATH_SEPARATOR), - modelMetadata: [PATH_PREFIX, path, MODEL_METADATA_SUFFIX].join(PATH_SEPARATOR) - }; - } - /** - * Get model path from a local-storage key. - * - * E.g., 'tensorflowjs_models/my/model/1/info' --> 'my/model/1' - * - * @param key - */ - function getModelPathFromKey(key) { - const items = key.split(PATH_SEPARATOR); - if (items.length < 3) { - throw new Error(`Invalid key format: ${key}`); - } - return items.slice(1, items.length - 1).join(PATH_SEPARATOR); - } - function maybeStripScheme$1(key) { - return key.startsWith(BrowserLocalStorage.URL_SCHEME) ? - key.slice(BrowserLocalStorage.URL_SCHEME.length) : - key; - } - /** - * IOHandler subclass: Browser Local Storage. - * - * See the doc string to `browserLocalStorage` for more details. - */ - class BrowserLocalStorage { - constructor(modelPath) { - if (!env().getBool('IS_BROWSER') || typeof window === 'undefined' || - typeof window.localStorage === 'undefined') { - // TODO(cais): Add more info about what IOHandler subtypes are - // available. - // Maybe point to a doc page on the web and/or automatically determine - // the available IOHandlers and print them in the error message. - throw new Error('The current environment does not support local storage.'); - } - this.LS = window.localStorage; - if (modelPath == null || !modelPath) { - throw new Error('For local storage, modelPath must not be null, undefined or empty.'); - } - this.modelPath = modelPath; - this.keys = getModelKeys(this.modelPath); - } - /** - * Save model artifacts to browser local storage. - * - * See the documentation to `browserLocalStorage` for details on the saved - * artifacts. - * - * @param modelArtifacts The model artifacts to be stored. - * @returns An instance of SaveResult. - */ - async save(modelArtifacts) { - if (modelArtifacts.modelTopology instanceof ArrayBuffer) { - throw new Error('BrowserLocalStorage.save() does not support saving model topology ' + - 'in binary formats yet.'); - } - else { - const topology = JSON.stringify(modelArtifacts.modelTopology); - const weightSpecs = JSON.stringify(modelArtifacts.weightSpecs); - const modelArtifactsInfo = getModelArtifactsInfoForJSON(modelArtifacts); - try { - this.LS.setItem(this.keys.info, JSON.stringify(modelArtifactsInfo)); - this.LS.setItem(this.keys.topology, topology); - this.LS.setItem(this.keys.weightSpecs, weightSpecs); - this.LS.setItem(this.keys.weightData, arrayBufferToBase64String(modelArtifacts.weightData)); - this.LS.setItem(this.keys.modelMetadata, JSON.stringify({ - format: modelArtifacts.format, - generatedBy: modelArtifacts.generatedBy, - convertedBy: modelArtifacts.convertedBy, - userDefinedMetadata: modelArtifacts.userDefinedMetadata - })); - return { modelArtifactsInfo }; - } - catch (err) { - // If saving failed, clean up all items saved so far. - this.LS.removeItem(this.keys.info); - this.LS.removeItem(this.keys.topology); - this.LS.removeItem(this.keys.weightSpecs); - this.LS.removeItem(this.keys.weightData); - this.LS.removeItem(this.keys.modelMetadata); - throw new Error(`Failed to save model '${this.modelPath}' to local storage: ` + - `size quota being exceeded is a possible cause of this failure: ` + - `modelTopologyBytes=${modelArtifactsInfo.modelTopologyBytes}, ` + - `weightSpecsBytes=${modelArtifactsInfo.weightSpecsBytes}, ` + - `weightDataBytes=${modelArtifactsInfo.weightDataBytes}.`); - } - } - } - /** - * Load a model from local storage. - * - * See the documentation to `browserLocalStorage` for details on the saved - * artifacts. - * - * @returns The loaded model (if loading succeeds). - */ - async load() { - const info = JSON.parse(this.LS.getItem(this.keys.info)); - if (info == null) { - throw new Error(`In local storage, there is no model with name '${this.modelPath}'`); - } - if (info.modelTopologyType !== 'JSON') { - throw new Error('BrowserLocalStorage does not support loading non-JSON model ' + - 'topology yet.'); - } - const out = {}; - // Load topology. - const topology = JSON.parse(this.LS.getItem(this.keys.topology)); - if (topology == null) { - throw new Error(`In local storage, the topology of model '${this.modelPath}' ` + - `is missing.`); - } - out.modelTopology = topology; - // Load weight specs. - const weightSpecs = JSON.parse(this.LS.getItem(this.keys.weightSpecs)); - if (weightSpecs == null) { - throw new Error(`In local storage, the weight specs of model '${this.modelPath}' ` + - `are missing.`); - } - out.weightSpecs = weightSpecs; - // Load meta-data fields. - const metadataString = this.LS.getItem(this.keys.modelMetadata); - if (metadataString != null) { - const metadata = JSON.parse(metadataString); - out.format = metadata['format']; - out.generatedBy = metadata['generatedBy']; - out.convertedBy = metadata['convertedBy']; - out.userDefinedMetadata = metadata['userDefinedMetadata']; - } - // Load weight data. - const weightDataBase64 = this.LS.getItem(this.keys.weightData); - if (weightDataBase64 == null) { - throw new Error(`In local storage, the binary weight values of model ` + - `'${this.modelPath}' are missing.`); - } - out.weightData = base64StringToArrayBuffer(weightDataBase64); - return out; - } - } - BrowserLocalStorage.URL_SCHEME = 'localstorage://'; - const localStorageRouter = (url) => { - if (!env().getBool('IS_BROWSER')) { - return null; - } - else { - if (!Array.isArray(url) && url.startsWith(BrowserLocalStorage.URL_SCHEME)) { - return browserLocalStorage(url.slice(BrowserLocalStorage.URL_SCHEME.length)); - } - else { - return null; - } - } - }; - IORouterRegistry.registerSaveRouter(localStorageRouter); - IORouterRegistry.registerLoadRouter(localStorageRouter); - /** - * Factory function for local storage IOHandler. - * - * This `IOHandler` supports both `save` and `load`. - * - * For each model's saved artifacts, four items are saved to local storage. - * - `${PATH_SEPARATOR}/${modelPath}/info`: Contains meta-info about the - * model, such as date saved, type of the topology, size in bytes, etc. - * - `${PATH_SEPARATOR}/${modelPath}/topology`: Model topology. For Keras- - * style models, this is a stringized JSON. - * - `${PATH_SEPARATOR}/${modelPath}/weight_specs`: Weight specs of the - * model, can be used to decode the saved binary weight values (see - * item below). - * - `${PATH_SEPARATOR}/${modelPath}/weight_data`: Concatenated binary - * weight values, stored as a base64-encoded string. - * - * Saving may throw an `Error` if the total size of the artifacts exceed the - * browser-specific quota. - * - * @param modelPath A unique identifier for the model to be saved. Must be a - * non-empty string. - * @returns An instance of `IOHandler`, which can be used with, e.g., - * `tf.Model.save`. - */ - function browserLocalStorage(modelPath) { - return new BrowserLocalStorage(modelPath); - } - class BrowserLocalStorageManager { - constructor() { - assert(env().getBool('IS_BROWSER'), () => 'Current environment is not a web browser'); - assert(typeof window === 'undefined' || - typeof window.localStorage !== 'undefined', () => 'Current browser does not appear to support localStorage'); - this.LS = window.localStorage; - } - async listModels() { - const out = {}; - const prefix = PATH_PREFIX + PATH_SEPARATOR; - const suffix = PATH_SEPARATOR + INFO_SUFFIX; - for (let i = 0; i < this.LS.length; ++i) { - const key = this.LS.key(i); - if (key.startsWith(prefix) && key.endsWith(suffix)) { - const modelPath = getModelPathFromKey(key); - out[modelPath] = JSON.parse(this.LS.getItem(key)); - } - } - return out; - } - async removeModel(path) { - path = maybeStripScheme$1(path); - const keys = getModelKeys(path); - if (this.LS.getItem(keys.info) == null) { - throw new Error(`Cannot find model at path '${path}'`); - } - const info = JSON.parse(this.LS.getItem(keys.info)); - this.LS.removeItem(keys.info); - this.LS.removeItem(keys.topology); - this.LS.removeItem(keys.weightSpecs); - this.LS.removeItem(keys.weightData); - return info; - } - } - - /** - * @license - * Copyright 2018 Google LLC. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - const URL_SCHEME_SUFFIX = '://'; - class ModelStoreManagerRegistry { - constructor() { - this.managers = {}; - } - static getInstance() { - if (ModelStoreManagerRegistry.instance == null) { - ModelStoreManagerRegistry.instance = new ModelStoreManagerRegistry(); - } - return ModelStoreManagerRegistry.instance; - } - /** - * Register a save-handler router. - * - * @param saveRouter A function that maps a URL-like string onto an instance - * of `IOHandler` with the `save` method defined or `null`. - */ - static registerManager(scheme, manager) { - assert(scheme != null, () => 'scheme must not be undefined or null.'); - if (scheme.endsWith(URL_SCHEME_SUFFIX)) { - scheme = scheme.slice(0, scheme.indexOf(URL_SCHEME_SUFFIX)); - } - assert(scheme.length > 0, () => 'scheme must not be an empty string.'); - const registry = ModelStoreManagerRegistry.getInstance(); - assert(registry.managers[scheme] == null, () => `A model store manager is already registered for scheme '${scheme}'.`); - registry.managers[scheme] = manager; - } - static getManager(scheme) { - const manager = this.getInstance().managers[scheme]; - if (manager == null) { - throw new Error(`Cannot find model manager for scheme '${scheme}'`); - } - return manager; - } - static getSchemes() { - return Object.keys(this.getInstance().managers); - } - } - /** - * Helper method for parsing a URL string into a scheme and a path. - * - * @param url E.g., 'localstorage://my-model' - * @returns A dictionary with two fields: scheme and path. - * Scheme: e.g., 'localstorage' in the example above. - * Path: e.g., 'my-model' in the example above. - */ - function parseURL(url) { - if (url.indexOf(URL_SCHEME_SUFFIX) === -1) { - throw new Error(`The url string provided does not contain a scheme. ` + - `Supported schemes are: ` + - `${ModelStoreManagerRegistry.getSchemes().join(',')}`); - } - return { - scheme: url.split(URL_SCHEME_SUFFIX)[0], - path: url.split(URL_SCHEME_SUFFIX)[1], - }; - } - async function cloneModelInternal(sourceURL, destURL, deleteSource = false) { - assert(sourceURL !== destURL, () => `Old path and new path are the same: '${sourceURL}'`); - const loadHandlers = IORouterRegistry.getLoadHandlers(sourceURL); - assert(loadHandlers.length > 0, () => `Copying failed because no load handler is found for source URL ${sourceURL}.`); - assert(loadHandlers.length < 2, () => `Copying failed because more than one (${loadHandlers.length}) ` + - `load handlers for source URL ${sourceURL}.`); - const loadHandler = loadHandlers[0]; - const saveHandlers = IORouterRegistry.getSaveHandlers(destURL); - assert(saveHandlers.length > 0, () => `Copying failed because no save handler is found for destination ` + - `URL ${destURL}.`); - assert(saveHandlers.length < 2, () => `Copying failed because more than one (${loadHandlers.length}) ` + - `save handlers for destination URL ${destURL}.`); - const saveHandler = saveHandlers[0]; - const sourceScheme = parseURL(sourceURL).scheme; - const sourcePath = parseURL(sourceURL).path; - const sameMedium = sourceScheme === parseURL(sourceURL).scheme; - const modelArtifacts = await loadHandler.load(); - // If moving within the same storage medium, remove the old model as soon as - // the loading is done. Without doing this, it is possible that the combined - // size of the two models will cause the cloning to fail. - if (deleteSource && sameMedium) { - await ModelStoreManagerRegistry.getManager(sourceScheme) - .removeModel(sourcePath); - } - const saveResult = await saveHandler.save(modelArtifacts); - // If moving between mediums, the deletion is done after the save succeeds. - // This guards against the case in which saving to the destination medium - // fails. - if (deleteSource && !sameMedium) { - await ModelStoreManagerRegistry.getManager(sourceScheme) - .removeModel(sourcePath); - } - return saveResult.modelArtifactsInfo; - } - /** - * List all models stored in registered storage mediums. - * - * For a web browser environment, the registered mediums are Local Storage and - * IndexedDB. - * - * ```js - * // First create and save a model. - * const model = tf.sequential(); - * model.add(tf.layers.dense( - * {units: 1, inputShape: [10], activation: 'sigmoid'})); - * await model.save('localstorage://demo/management/model1'); - * - * // Then list existing models. - * console.log(JSON.stringify(await tf.io.listModels())); - * - * // Delete the model. - * await tf.io.removeModel('localstorage://demo/management/model1'); - * - * // List models again. - * console.log(JSON.stringify(await tf.io.listModels())); - * ``` - * - * @returns A `Promise` of a dictionary mapping URLs of existing models to - * their model artifacts info. URLs include medium-specific schemes, e.g., - * 'indexeddb://my/model/1'. Model artifacts info include type of the - * model's topology, byte sizes of the topology, weights, etc. - * - * @doc { - * heading: 'Models', - * subheading: 'Management', - * namespace: 'io', - * ignoreCI: true - * } - */ - async function listModels() { - const schemes = ModelStoreManagerRegistry.getSchemes(); - const out = {}; - for (const scheme of schemes) { - const schemeOut = await ModelStoreManagerRegistry.getManager(scheme).listModels(); - for (const path in schemeOut) { - const url = scheme + URL_SCHEME_SUFFIX + path; - out[url] = schemeOut[path]; - } - } - return out; - } - /** - * Remove a model specified by URL from a reigstered storage medium. - * - * ```js - * // First create and save a model. - * const model = tf.sequential(); - * model.add(tf.layers.dense( - * {units: 1, inputShape: [10], activation: 'sigmoid'})); - * await model.save('localstorage://demo/management/model1'); - * - * // Then list existing models. - * console.log(JSON.stringify(await tf.io.listModels())); - * - * // Delete the model. - * await tf.io.removeModel('localstorage://demo/management/model1'); - * - * // List models again. - * console.log(JSON.stringify(await tf.io.listModels())); - * ``` - * - * @param url A URL to a stored model, with a scheme prefix, e.g., - * 'localstorage://my-model-1', 'indexeddb://my/model/2'. - * @returns ModelArtifactsInfo of the deleted model (if and only if deletion - * is successful). - * @throws Error if deletion fails, e.g., if no model exists at `path`. - * - * @doc { - * heading: 'Models', - * subheading: 'Management', - * namespace: 'io', - * ignoreCI: true - * } - */ - async function removeModel(url) { - const schemeAndPath = parseURL(url); - const manager = ModelStoreManagerRegistry.getManager(schemeAndPath.scheme); - return manager.removeModel(schemeAndPath.path); - } - /** - * Copy a model from one URL to another. - * - * This function supports: - * - * 1. Copying within a storage medium, e.g., - * `tf.io.copyModel('localstorage://model-1', 'localstorage://model-2')` - * 2. Copying between two storage mediums, e.g., - * `tf.io.copyModel('localstorage://model-1', 'indexeddb://model-1')` - * - * ```js - * // First create and save a model. - * const model = tf.sequential(); - * model.add(tf.layers.dense( - * {units: 1, inputShape: [10], activation: 'sigmoid'})); - * await model.save('localstorage://demo/management/model1'); - * - * // Then list existing models. - * console.log(JSON.stringify(await tf.io.listModels())); - * - * // Copy the model, from Local Storage to IndexedDB. - * await tf.io.copyModel( - * 'localstorage://demo/management/model1', - * 'indexeddb://demo/management/model1'); - * - * // List models again. - * console.log(JSON.stringify(await tf.io.listModels())); - * - * // Remove both models. - * await tf.io.removeModel('localstorage://demo/management/model1'); - * await tf.io.removeModel('indexeddb://demo/management/model1'); - * ``` - * - * @param sourceURL Source URL of copying. - * @param destURL Destination URL of copying. - * @returns ModelArtifactsInfo of the copied model (if and only if copying - * is successful). - * @throws Error if copying fails, e.g., if no model exists at `sourceURL`, or - * if `oldPath` and `newPath` are identical. - * - * @doc { - * heading: 'Models', - * subheading: 'Management', - * namespace: 'io', - * ignoreCI: true - * } - */ - async function copyModel(sourceURL, destURL) { - const deleteSource = false; - return cloneModelInternal(sourceURL, destURL, deleteSource); - } - /** - * Move a model from one URL to another. - * - * This function supports: - * - * 1. Moving within a storage medium, e.g., - * `tf.io.moveModel('localstorage://model-1', 'localstorage://model-2')` - * 2. Moving between two storage mediums, e.g., - * `tf.io.moveModel('localstorage://model-1', 'indexeddb://model-1')` - * - * ```js - * // First create and save a model. - * const model = tf.sequential(); - * model.add(tf.layers.dense( - * {units: 1, inputShape: [10], activation: 'sigmoid'})); - * await model.save('localstorage://demo/management/model1'); - * - * // Then list existing models. - * console.log(JSON.stringify(await tf.io.listModels())); - * - * // Move the model, from Local Storage to IndexedDB. - * await tf.io.moveModel( - * 'localstorage://demo/management/model1', - * 'indexeddb://demo/management/model1'); - * - * // List models again. - * console.log(JSON.stringify(await tf.io.listModels())); - * - * // Remove the moved model. - * await tf.io.removeModel('indexeddb://demo/management/model1'); - * ``` - * - * @param sourceURL Source URL of moving. - * @param destURL Destination URL of moving. - * @returns ModelArtifactsInfo of the copied model (if and only if copying - * is successful). - * @throws Error if moving fails, e.g., if no model exists at `sourceURL`, or - * if `oldPath` and `newPath` are identical. - * - * @doc { - * heading: 'Models', - * subheading: 'Management', - * namespace: 'io', - * ignoreCI: true - * } - */ - async function moveModel(sourceURL, destURL) { - const deleteSource = true; - return cloneModelInternal(sourceURL, destURL, deleteSource); - } - - /** - * @license - * Copyright 2019 Google LLC. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - class PlatformBrowser { - fetch(path, init) { - return fetch(path, init); - } - now() { - return performance.now(); - } - encode(text, encoding) { - if (encoding !== 'utf-8' && encoding !== 'utf8') { - throw new Error(`Browser's encoder only supports utf-8, but got ${encoding}`); - } - if (this.textEncoder == null) { - this.textEncoder = new TextEncoder(); - } - return this.textEncoder.encode(text); - } - decode(bytes, encoding) { - return new TextDecoder(encoding).decode(bytes); - } - } - if (env().get('IS_BROWSER')) { - env().setPlatform('browser', new PlatformBrowser()); - // Register LocalStorage IOHandler - try { - ModelStoreManagerRegistry.registerManager(BrowserLocalStorage.URL_SCHEME, new BrowserLocalStorageManager()); - } - catch (err) { - } - // Register IndexedDB IOHandler - try { - ModelStoreManagerRegistry.registerManager(BrowserIndexedDB.URL_SCHEME, new BrowserIndexedDBManager()); - } - catch (err) { - } - } - - /** - * @license - * Copyright 2019 Google LLC. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - // We are wrapping this within an object so it can be stubbed by Jasmine. - const getNodeFetch = { - // tslint:disable-next-line:no-require-imports - importFetch: () => require('node-fetch') - }; - let systemFetch; - // These getters and setters are for testing so we don't export a mutable - // variable. - function resetSystemFetch() { - systemFetch = null; - } - function setSystemFetch(fetchFn) { - systemFetch = fetchFn; - } - function getSystemFetch() { - return systemFetch; - } - class PlatformNode { - constructor() { - // tslint:disable-next-line:no-require-imports - this.util = require('util'); - // According to the spec, the built-in encoder can do only UTF-8 encoding. - // https://developer.mozilla.org/en-US/docs/Web/API/TextEncoder/TextEncoder - this.textEncoder = new this.util.TextEncoder(); - } - fetch(path, requestInits) { - if (env().global.fetch != null) { - return env().global.fetch(path, requestInits); - } - if (systemFetch == null) { - systemFetch = getNodeFetch.importFetch(); - } - return systemFetch(path, requestInits); - } - now() { - const time = process.hrtime(); - return time[0] * 1000 + time[1] / 1000000; - } - encode(text, encoding) { - if (encoding !== 'utf-8' && encoding !== 'utf8') { - throw new Error(`Node built-in encoder only supports utf-8, but got ${encoding}`); - } - return this.textEncoder.encode(text); - } - decode(bytes, encoding) { - if (bytes.length === 0) { - return ''; - } - return new this.util.TextDecoder(encoding).decode(bytes); - } - } - if (env().get('IS_NODE')) { - env().setPlatform('node', new PlatformNode()); - } - - /** - * @license - * Copyright 2020 Google Inc. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - /** - * Creates an empty `tf.TensorBuffer` with the specified `shape` and `dtype`. - * - * The values are stored in CPU as `TypedArray`. Fill the buffer using - * `buffer.set()`, or by modifying directly `buffer.values`. - * - * When done, call `buffer.toTensor()` to get an immutable `tf.Tensor` with - * those values. - * - * ```js - * // Create a buffer and set values at particular indices. - * const buffer = tf.buffer([2, 2]); - * buffer.set(3, 0, 0); - * buffer.set(5, 1, 0); - * - * // Convert the buffer back to a tensor. - * buffer.toTensor().print(); - * ``` - * - * @param shape An array of integers defining the output tensor shape. - * @param dtype The dtype of the buffer. Defaults to 'float32'. - * @param values The values of the buffer as `TypedArray`. Defaults to - * zeros. - * - * @doc {heading: 'Tensors', subheading: 'Creation'} - */ - function buffer(shape, dtype = 'float32', values) { - dtype = dtype || 'float32'; - assertNonNegativeIntegerDimensions(shape); - return new TensorBuffer(shape, dtype, values); - } - - /** - * @license - * Copyright 2020 Google Inc. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - /** - * Casts a `tf.Tensor` to a new dtype. - * - * ```js - * const x = tf.tensor1d([1.5, 2.5, 3]); - * tf.cast(x, 'int32').print(); - * ``` - * @param x The input tensor to be casted. - * @param dtype The dtype to cast the input tensor to. - * - * @doc {heading: 'Tensors', subheading: 'Transformations'} - */ - function cast_(x, dtype) { - const $x = convertToTensor(x, 'x', 'cast'); - // Sanity checks. - if (!isValidDtype(dtype)) { - throw new Error(`Failed to cast to unknown dtype ${dtype}`); - } - if (dtype === 'string' && $x.dtype !== 'string' || - dtype !== 'string' && $x.dtype === 'string') { - throw new Error('Only strings can be casted to strings'); - } - const inputs = { x: $x }; - const attrs = { dtype }; - return ENGINE.runKernelFunc(backend => backend.cast($x, dtype), inputs, null /* grad */, Cast, attrs); - } - const cast = op({ cast_ }); - - /** - * @license - * Copyright 2020 Google LLC. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - /** - * Creates a new tensor with the same values and shape as the specified - * tensor. - * - * ```js - * const x = tf.tensor([1, 2]); - * - * x.clone().print(); - * ``` - * - * @param x The tensor to clone. - * - * @doc {heading: 'Tensors', subheading: 'Creation'} - */ - function clone_(x) { - const $x = convertToTensor(x, 'x', 'clone', null); - const forward = () => ENGINE.makeTensorFromDataId($x.dataId, $x.shape, $x.dtype); - const inputs = { x: $x }; - // Note this op is called tf.identity in python. Hence the kernel name used - // here. - return ENGINE.runKernelFunc(forward, inputs, null /* grad */, Identity); - } - const clone = op({ clone_ }); - - /** - * @license - * Copyright 2020 Google Inc. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - /** - * Prints information about the `tf.Tensor` including its data. - * - * ```js - * const verbose = true; - * tf.tensor2d([1, 2, 3, 4], [2, 2]).print(verbose); - * ``` - * @param x The tensor to be printed. - * @param verbose Whether to print verbose information about the ` Tensor`, - * including dtype and size. - * - * @doc {heading: 'Tensors', subheading: 'Creation'} - */ - function print(x, verbose = false) { - console.log(x.toString(verbose)); - } - - /** - * @license - * Copyright 2020 Google Inc. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - getOrMakeEngine(); - const opHandler$1 = { - buffer, - cast, - clone, - print - }; - setOpHandler(opHandler$1); - - /** - * @license - * Copyright 2018 Google LLC. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - const DEFAULT_FILE_NAME_PREFIX = 'model'; - const DEFAULT_JSON_EXTENSION_NAME = '.json'; - const DEFAULT_WEIGHT_DATA_EXTENSION_NAME = '.weights.bin'; - function defer(f) { - return new Promise(resolve => setTimeout(resolve)).then(f); - } - class BrowserDownloads { - constructor(fileNamePrefix) { - if (!env().getBool('IS_BROWSER')) { - // TODO(cais): Provide info on what IOHandlers are available under the - // current environment. - throw new Error('browserDownloads() cannot proceed because the current environment ' + - 'is not a browser.'); - } - if (fileNamePrefix.startsWith(BrowserDownloads.URL_SCHEME)) { - fileNamePrefix = fileNamePrefix.slice(BrowserDownloads.URL_SCHEME.length); - } - if (fileNamePrefix == null || fileNamePrefix.length === 0) { - fileNamePrefix = DEFAULT_FILE_NAME_PREFIX; - } - this.modelTopologyFileName = fileNamePrefix + DEFAULT_JSON_EXTENSION_NAME; - this.weightDataFileName = - fileNamePrefix + DEFAULT_WEIGHT_DATA_EXTENSION_NAME; - } - async save(modelArtifacts) { - if (typeof (document) === 'undefined') { - throw new Error('Browser downloads are not supported in ' + - 'this environment since `document` is not present'); - } - const weightsURL = window.URL.createObjectURL(new Blob([modelArtifacts.weightData], { type: 'application/octet-stream' })); - if (modelArtifacts.modelTopology instanceof ArrayBuffer) { - throw new Error('BrowserDownloads.save() does not support saving model topology ' + - 'in binary formats yet.'); - } - else { - const weightsManifest = [{ - paths: ['./' + this.weightDataFileName], - weights: modelArtifacts.weightSpecs - }]; - const modelTopologyAndWeightManifest = { - modelTopology: modelArtifacts.modelTopology, - format: modelArtifacts.format, - generatedBy: modelArtifacts.generatedBy, - convertedBy: modelArtifacts.convertedBy, - weightsManifest - }; - const modelTopologyAndWeightManifestURL = window.URL.createObjectURL(new Blob([JSON.stringify(modelTopologyAndWeightManifest)], { type: 'application/json' })); - // If anchor elements are not provided, create them without attaching them - // to parents, so that the downloaded file names can be controlled. - const jsonAnchor = this.jsonAnchor == null ? document.createElement('a') : - this.jsonAnchor; - jsonAnchor.download = this.modelTopologyFileName; - jsonAnchor.href = modelTopologyAndWeightManifestURL; - // Trigger downloads by evoking a click event on the download anchors. - // When multiple downloads are started synchronously, Firefox will only - // save the last one. - await defer(() => jsonAnchor.dispatchEvent(new MouseEvent('click'))); - if (modelArtifacts.weightData != null) { - const weightDataAnchor = this.weightDataAnchor == null ? - document.createElement('a') : - this.weightDataAnchor; - weightDataAnchor.download = this.weightDataFileName; - weightDataAnchor.href = weightsURL; - await defer(() => weightDataAnchor.dispatchEvent(new MouseEvent('click'))); - } - return { modelArtifactsInfo: getModelArtifactsInfoForJSON(modelArtifacts) }; - } - } - } - BrowserDownloads.URL_SCHEME = 'downloads://'; - class BrowserFiles { - constructor(files) { - if (files == null || files.length < 1) { - throw new Error(`When calling browserFiles, at least 1 file is required, ` + - `but received ${files}`); - } - this.files = files; - } - async load() { - const jsonFile = this.files[0]; - const weightFiles = this.files.slice(1); - return new Promise((resolve, reject) => { - const jsonReader = new FileReader(); - jsonReader.onload = (event) => { - // tslint:disable-next-line:no-any - const modelJSON = JSON.parse(event.target.result); - const modelTopology = modelJSON.modelTopology; - if (modelTopology == null) { - reject(new Error(`modelTopology field is missing from file ${jsonFile.name}`)); - return; - } - if (weightFiles.length === 0) { - resolve({ modelTopology }); - } - const weightsManifest = modelJSON.weightsManifest; - if (weightsManifest == null) { - reject(new Error(`weightManifest field is missing from file ${jsonFile.name}`)); - return; - } - let pathToFile; - try { - pathToFile = - this.checkManifestAndWeightFiles(weightsManifest, weightFiles); - } - catch (err) { - reject(err); - return; - } - const weightSpecs = []; - const paths = []; - const perFileBuffers = []; - weightsManifest.forEach(weightsGroup => { - weightsGroup.paths.forEach(path => { - paths.push(path); - perFileBuffers.push(null); - }); - weightSpecs.push(...weightsGroup.weights); - }); - weightsManifest.forEach(weightsGroup => { - weightsGroup.paths.forEach(path => { - const weightFileReader = new FileReader(); - weightFileReader.onload = (event) => { - // tslint:disable-next-line:no-any - const weightData = event.target.result; - const index = paths.indexOf(path); - perFileBuffers[index] = weightData; - if (perFileBuffers.indexOf(null) === -1) { - resolve({ - modelTopology, - weightSpecs, - weightData: concatenateArrayBuffers(perFileBuffers), - format: modelJSON.format, - generatedBy: modelJSON.generatedBy, - convertedBy: modelJSON.convertedBy, - userDefinedMetadata: modelJSON.userDefinedMetadata - }); - } - }; - weightFileReader.onerror = error => reject(`Failed to weights data from file of path '${path}'.`); - weightFileReader.readAsArrayBuffer(pathToFile[path]); - }); - }); - }; - jsonReader.onerror = error => reject(`Failed to read model topology and weights manifest JSON ` + - `from file '${jsonFile.name}'. BrowserFiles supports loading ` + - `Keras-style tf.Model artifacts only.`); - jsonReader.readAsText(jsonFile); - }); - } - /** - * Check the compatibility between weights manifest and weight files. - */ - checkManifestAndWeightFiles(manifest, files) { - const basenames = []; - const fileNames = files.map(file => basename(file.name)); - const pathToFile = {}; - for (const group of manifest) { - group.paths.forEach(path => { - const pathBasename = basename(path); - if (basenames.indexOf(pathBasename) !== -1) { - throw new Error(`Duplicate file basename found in weights manifest: ` + - `'${pathBasename}'`); - } - basenames.push(pathBasename); - if (fileNames.indexOf(pathBasename) === -1) { - throw new Error(`Weight file with basename '${pathBasename}' is not provided.`); - } - else { - pathToFile[path] = files[fileNames.indexOf(pathBasename)]; - } - }); - } - if (basenames.length !== files.length) { - throw new Error(`Mismatch in the number of files in weights manifest ` + - `(${basenames.length}) and the number of weight files provided ` + - `(${files.length}).`); - } - return pathToFile; - } - } - const browserDownloadsRouter = (url) => { - if (!env().getBool('IS_BROWSER')) { - return null; - } - else { - if (!Array.isArray(url) && url.startsWith(BrowserDownloads.URL_SCHEME)) { - return browserDownloads(url.slice(BrowserDownloads.URL_SCHEME.length)); - } - else { - return null; - } - } - }; - IORouterRegistry.registerSaveRouter(browserDownloadsRouter); - /** - * Creates an IOHandler that triggers file downloads from the browser. - * - * The returned `IOHandler` instance can be used as model exporting methods such - * as `tf.Model.save` and supports only saving. - * - * ```js - * const model = tf.sequential(); - * model.add(tf.layers.dense( - * {units: 1, inputShape: [10], activation: 'sigmoid'})); - * const saveResult = await model.save('downloads://mymodel'); - * // This will trigger downloading of two files: - * // 'mymodel.json' and 'mymodel.weights.bin'. - * console.log(saveResult); - * ``` - * - * @param fileNamePrefix Prefix name of the files to be downloaded. For use with - * `tf.Model`, `fileNamePrefix` should follow either of the following two - * formats: - * 1. `null` or `undefined`, in which case the default file - * names will be used: - * - 'model.json' for the JSON file containing the model topology and - * weights manifest. - * - 'model.weights.bin' for the binary file containing the binary weight - * values. - * 2. A single string or an Array of a single string, as the file name prefix. - * For example, if `'foo'` is provided, the downloaded JSON - * file and binary weights file will be named 'foo.json' and - * 'foo.weights.bin', respectively. - * @param config Additional configuration for triggering downloads. - * @returns An instance of `BrowserDownloads` `IOHandler`. - * - * @doc { - * heading: 'Models', - * subheading: 'Loading', - * namespace: 'io', - * ignoreCI: true - * } - */ - function browserDownloads(fileNamePrefix = 'model') { - return new BrowserDownloads(fileNamePrefix); - } - /** - * Creates an IOHandler that loads model artifacts from user-selected files. - * - * This method can be used for loading from files such as user-selected files - * in the browser. - * When used in conjunction with `tf.loadLayersModel`, an instance of - * `tf.LayersModel` (Keras-style) can be constructed from the loaded artifacts. - * - * ```js - * // Note: This code snippet won't run properly without the actual file input - * // elements in the HTML DOM. - * - * // Suppose there are two HTML file input (``) - * // elements. - * const uploadJSONInput = document.getElementById('upload-json'); - * const uploadWeightsInput = document.getElementById('upload-weights'); - * const model = await tf.loadLayersModel(tf.io.browserFiles( - * [uploadJSONInput.files[0], uploadWeightsInput.files[0]])); - * ``` - * - * @param files `File`s to load from. Currently, this function supports only - * loading from files that contain Keras-style models (i.e., `tf.Model`s), for - * which an `Array` of `File`s is expected (in that order): - * - A JSON file containing the model topology and weight manifest. - * - Optionally, One or more binary files containing the binary weights. - * These files must have names that match the paths in the `weightsManifest` - * contained by the aforementioned JSON file, or errors will be thrown - * during loading. These weights files have the same format as the ones - * generated by `tensorflowjs_converter` that comes with the `tensorflowjs` - * Python PIP package. If no weights files are provided, only the model - * topology will be loaded from the JSON file above. - * @returns An instance of `Files` `IOHandler`. - * - * @doc { - * heading: 'Models', - * subheading: 'Loading', - * namespace: 'io', - * ignoreCI: true - * } - */ - function browserFiles(files) { - return new BrowserFiles(files); - } - - /** - * @license - * Copyright 2019 Google LLC. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - /** - * Monitor Promise.all progress, fire onProgress callback function. - * - * @param promises Promise list going to be monitored - * @param onProgress Callback function. Fired when a promise resolved. - * @param startFraction Optional fraction start. Default to 0. - * @param endFraction Optional fraction end. Default to 1. - */ - function monitorPromisesProgress(promises, onProgress, startFraction, endFraction) { - checkPromises(promises); - startFraction = startFraction == null ? 0 : startFraction; - endFraction = endFraction == null ? 1 : endFraction; - checkFraction(startFraction, endFraction); - let resolvedPromise = 0; - const registerMonitor = (promise) => { - promise.then(value => { - const fraction = startFraction + - ++resolvedPromise / promises.length * (endFraction - startFraction); - // pass fraction as parameter to callback function. - onProgress(fraction); - return value; - }); - return promise; - }; - function checkPromises(promises) { - assert(promises != null && Array.isArray(promises) && promises.length > 0, () => 'promises must be a none empty array'); - } - function checkFraction(startFraction, endFraction) { - assert(startFraction >= 0 && startFraction <= 1, () => `Progress fraction must be in range [0, 1], but ` + - `got startFraction ${startFraction}`); - assert(endFraction >= 0 && endFraction <= 1, () => `Progress fraction must be in range [0, 1], but ` + - `got endFraction ${endFraction}`); - assert(endFraction >= startFraction, () => `startFraction must be no more than endFraction, but ` + - `got startFraction ${startFraction} and endFraction ` + - `${endFraction}`); - } - return Promise.all(promises.map(registerMonitor)); - } - - /** - * @license - * Copyright 2018 Google LLC. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - /** - * Reads binary weights data from a number of URLs. - * - * @param fetchURLs URLs to send the HTTP requests at, using `fetch` calls. - * @param requestOptions RequestInit (options) for the HTTP requests. - * @param fetchFunc Optional overriding value for the `window.fetch` function. - * @param onProgress Optional, progress callback function, fired periodically - * before the load is completed. - * @returns A `Promise` of an Array of `ArrayBuffer`. The Array has the same - * length as `fetchURLs`. - */ - async function loadWeightsAsArrayBuffer(fetchURLs, loadOptions) { - if (loadOptions == null) { - loadOptions = {}; - } - const fetchFunc = loadOptions.fetchFunc == null ? env().platform.fetch : - loadOptions.fetchFunc; - // Create the requests for all of the weights in parallel. - const requests = fetchURLs.map(fetchURL => fetchFunc(fetchURL, loadOptions.requestInit, { isBinary: true })); - const fetchStartFraction = 0; - const fetchEndFraction = 0.5; - const responses = loadOptions.onProgress == null ? - await Promise.all(requests) : - await monitorPromisesProgress(requests, loadOptions.onProgress, fetchStartFraction, fetchEndFraction); - const bufferPromises = responses.map(response => response.arrayBuffer()); - const bufferStartFraction = 0.5; - const bufferEndFraction = 1; - const buffers = loadOptions.onProgress == null ? - await Promise.all(bufferPromises) : - await monitorPromisesProgress(bufferPromises, loadOptions.onProgress, bufferStartFraction, bufferEndFraction); - return buffers; - } - /** - * Reads a weights manifest JSON configuration, fetches the weights and - * returns them as `Tensor`s. - * - * @param manifest The weights manifest JSON. - * @param filePathPrefix The path prefix for filenames given in the manifest. - * Defaults to the empty string. - * @param weightNames The names of the weights to be fetched. - */ - async function loadWeights(manifest, filePathPrefix = '', weightNames, requestInit) { - // TODO(nsthorat): Groups are currently fetched atomically. If you need a - // single weight from a group, the whole group will be fetched. At a future - // date, we should support fetching only the individual shards within a - // group that are needed to reconstruct the requested weight. - // TODO(cais): Use `decodeWeights` for implementation. - const fetchWeights = (fetchUrls) => loadWeightsAsArrayBuffer(fetchUrls, { requestInit }); - const loadWeights = weightsLoaderFactory(fetchWeights); - return loadWeights(manifest, filePathPrefix, weightNames); - } - /** - * Creates a function, which reads a weights manifest JSON configuration, - * fetches the weight files using the specified function and returns them as - * `Tensor`s. - * - * ```js - * // example for creating a nodejs weight loader, which reads the weight files - * // from disk using fs.readFileSync - * - * import * as fs from 'fs' - * - * const fetchWeightsFromDisk = (filePaths: string[]) => - * filePaths.map(filePath => fs.readFileSync(filePath).buffer) - * - * const loadWeights = tf.io.weightsLoaderFactory(fetchWeightsFromDisk) - * - * const manifest = JSON.parse( - * fs.readFileSync('./my_model-weights_manifest').toString() - * ) - * const weightMap = await loadWeights(manifest, './') - * ``` - * @param fetchWeightsFunction The function used for fetching the weight files. - * @returns Weight loading function. - */ - function weightsLoaderFactory(fetchWeightsFunction) { - return async (manifest, filePathPrefix = '', weightNames) => { - // Collect all the groups, weights, and their relative offsets to be - // fetched. - const groupIndicesToFetchMap = manifest.map(() => false); - const groupWeightsToFetch = {}; - const weightsFound = weightNames != null ? weightNames.map(() => false) : []; - const allManifestWeightNames = []; - manifest.forEach((manifestGroupConfig, groupIndex) => { - let groupOffset = 0; - manifestGroupConfig.weights.forEach(weightsEntry => { - const rawDtype = ('quantization' in weightsEntry) ? - weightsEntry.quantization.dtype : - weightsEntry.dtype; - const weightsBytes = DTYPE_VALUE_SIZE_MAP[rawDtype] * - sizeFromShape(weightsEntry.shape); - const enqueueWeightsForFetchingFn = () => { - groupIndicesToFetchMap[groupIndex] = true; - if (groupWeightsToFetch[groupIndex] == null) { - groupWeightsToFetch[groupIndex] = []; - } - groupWeightsToFetch[groupIndex].push({ - manifestEntry: weightsEntry, - groupOffset, - sizeBytes: weightsBytes - }); - }; - if (weightNames != null) { - weightNames.forEach((weightName, weightIndex) => { - if (weightName === weightsEntry.name) { - enqueueWeightsForFetchingFn(); - weightsFound[weightIndex] = true; - } - }); - } - else { - enqueueWeightsForFetchingFn(); - } - allManifestWeightNames.push(weightsEntry.name); - groupOffset += weightsBytes; - }); - }); - if (!weightsFound.every(found => found)) { - const weightsNotFound = weightNames.filter((_, i) => !weightsFound[i]); - throw new Error(`Could not find weights in manifest with names: ` + - `${weightsNotFound.join(', ')}. \n` + - `Manifest JSON has weights with names: ` + - `${allManifestWeightNames.join(', ')}.`); - } - // Convert the one-hot boolean groupId => shouldFetch map to a list of group - // IDs. - const groupIndicesToFetch = groupIndicesToFetchMap.reduce((accumulator, shouldFetch, i) => { - if (shouldFetch) { - accumulator.push(i); - } - return accumulator; - }, []); - const fetchUrls = []; - groupIndicesToFetch.forEach(i => { - manifest[i].paths.forEach(filepath => { - const fetchUrl = filePathPrefix + - (!filePathPrefix.endsWith('/') ? '/' : '') + filepath; - fetchUrls.push(fetchUrl); - }); - }); - const buffers = await fetchWeightsFunction(fetchUrls); - const weightsTensorMap = {}; - let bufferIndexOffset = 0; - groupIndicesToFetch.forEach(i => { - const numBuffers = manifest[i].paths.length; - let groupBytes = 0; - for (let i = 0; i < numBuffers; i++) { - groupBytes += buffers[bufferIndexOffset + i].byteLength; - } - // Create a buffer for the whole group. - const groupBuffer = new ArrayBuffer(groupBytes); - const groupByteBuffer = new Uint8Array(groupBuffer); - let groupBufferOffset = 0; - for (let i = 0; i < numBuffers; i++) { - const buffer = new Uint8Array(buffers[bufferIndexOffset + i]); - groupByteBuffer.set(buffer, groupBufferOffset); - groupBufferOffset += buffer.byteLength; - } - const weightsEntries = groupWeightsToFetch[i]; - weightsEntries.forEach(weightsEntry => { - const byteBuffer = groupBuffer.slice(weightsEntry.groupOffset, weightsEntry.groupOffset + weightsEntry.sizeBytes); - const nameToTensorMap = decodeWeights(byteBuffer, [weightsEntry.manifestEntry]); - for (const name in nameToTensorMap) { - weightsTensorMap[name] = nameToTensorMap[name]; - } - }); - bufferIndexOffset += numBuffers; - }); - return weightsTensorMap; - }; - } - - /** - * @license - * Copyright 2018 Google LLC. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - const OCTET_STREAM_MIME_TYPE = 'application/octet-stream'; - const JSON_TYPE = 'application/json'; - class HTTPRequest { - constructor(path, loadOptions) { - this.DEFAULT_METHOD = 'POST'; - if (loadOptions == null) { - loadOptions = {}; - } - this.weightPathPrefix = loadOptions.weightPathPrefix; - this.onProgress = loadOptions.onProgress; - this.weightUrlConverter = loadOptions.weightUrlConverter; - if (loadOptions.fetchFunc != null) { - assert(typeof loadOptions.fetchFunc === 'function', () => 'Must pass a function that matches the signature of ' + - '`fetch` (see ' + - 'https://developer.mozilla.org/en-US/docs/Web/API/Fetch_API)'); - this.fetch = loadOptions.fetchFunc; - } - else { - this.fetch = env().platform.fetch; - } - assert(path != null && path.length > 0, () => 'URL path for http must not be null, undefined or ' + - 'empty.'); - if (Array.isArray(path)) { - assert(path.length === 2, () => 'URL paths for http must have a length of 2, ' + - `(actual length is ${path.length}).`); - } - this.path = path; - if (loadOptions.requestInit != null && - loadOptions.requestInit.body != null) { - throw new Error('requestInit is expected to have no pre-existing body, but has one.'); - } - this.requestInit = loadOptions.requestInit || {}; - } - async save(modelArtifacts) { - if (modelArtifacts.modelTopology instanceof ArrayBuffer) { - throw new Error('BrowserHTTPRequest.save() does not support saving model topology ' + - 'in binary formats yet.'); - } - const init = Object.assign({ method: this.DEFAULT_METHOD }, this.requestInit); - init.body = new FormData(); - const weightsManifest = [{ - paths: ['./model.weights.bin'], - weights: modelArtifacts.weightSpecs, - }]; - const modelTopologyAndWeightManifest = { - modelTopology: modelArtifacts.modelTopology, - format: modelArtifacts.format, - generatedBy: modelArtifacts.generatedBy, - convertedBy: modelArtifacts.convertedBy, - userDefinedMetadata: modelArtifacts.userDefinedMetadata, - weightsManifest - }; - init.body.append('model.json', new Blob([JSON.stringify(modelTopologyAndWeightManifest)], { type: JSON_TYPE }), 'model.json'); - if (modelArtifacts.weightData != null) { - init.body.append('model.weights.bin', new Blob([modelArtifacts.weightData], { type: OCTET_STREAM_MIME_TYPE }), 'model.weights.bin'); - } - const response = await this.fetch(this.path, init); - if (response.ok) { - return { - modelArtifactsInfo: getModelArtifactsInfoForJSON(modelArtifacts), - responses: [response], - }; - } - else { - throw new Error(`BrowserHTTPRequest.save() failed due to HTTP response status ` + - `${response.status}.`); - } - } - /** - * Load model artifacts via HTTP request(s). - * - * See the documentation to `tf.io.http` for details on the saved - * artifacts. - * - * @returns The loaded model artifacts (if loading succeeds). - */ - async load() { - const modelConfigRequest = await this.fetch(this.path, this.requestInit); - if (!modelConfigRequest.ok) { - throw new Error(`Request to ${this.path} failed with status code ` + - `${modelConfigRequest.status}. Please verify this URL points to ` + - `the model JSON of the model to load.`); - } - let modelConfig; - try { - modelConfig = await modelConfigRequest.json(); - } - catch (e) { - let message = `Failed to parse model JSON of response from ${this.path}.`; - // TODO(nsthorat): Remove this after some time when we're comfortable that - // .pb files are mostly gone. - if (this.path.endsWith('.pb')) { - message += ' Your path contains a .pb file extension. ' + - 'Support for .pb models have been removed in TensorFlow.js 1.0 ' + - 'in favor of .json models. You can re-convert your Python ' + - 'TensorFlow model using the TensorFlow.js 1.0 conversion scripts ' + - 'or you can convert your.pb models with the \'pb2json\'' + - 'NPM script in the tensorflow/tfjs-converter repository.'; - } - else { - message += ' Please make sure the server is serving valid ' + - 'JSON for this request.'; - } - throw new Error(message); - } - const modelTopology = modelConfig.modelTopology; - const weightsManifest = modelConfig.weightsManifest; - const generatedBy = modelConfig.generatedBy; - const convertedBy = modelConfig.convertedBy; - const format = modelConfig.format; - const userDefinedMetadata = modelConfig.userDefinedMetadata; - // We do not allow both modelTopology and weightsManifest to be missing. - if (modelTopology == null && weightsManifest == null) { - throw new Error(`The JSON from HTTP path ${this.path} contains neither model ` + - `topology or manifest for weights.`); - } - let weightSpecs; - let weightData; - if (weightsManifest != null) { - const results = await this.loadWeights(weightsManifest); - [weightSpecs, weightData] = results; - } - const artifacts = { - modelTopology, - weightSpecs, - weightData, - userDefinedMetadata, - generatedBy, - convertedBy, - format - }; - const initializer = modelConfig.modelInitializer; - if (initializer) { - artifacts.modelInitializer = initializer; - } - return artifacts; - } - async loadWeights(weightsManifest) { - const weightPath = Array.isArray(this.path) ? this.path[1] : this.path; - const [prefix, suffix] = parseUrl(weightPath); - const pathPrefix = this.weightPathPrefix || prefix; - const weightSpecs = []; - for (const entry of weightsManifest) { - weightSpecs.push(...entry.weights); - } - const fetchURLs = []; - const urlPromises = []; - for (const weightsGroup of weightsManifest) { - for (const path of weightsGroup.paths) { - if (this.weightUrlConverter != null) { - urlPromises.push(this.weightUrlConverter(path)); - } - else { - fetchURLs.push(pathPrefix + path + suffix); - } - } - } - if (this.weightUrlConverter) { - fetchURLs.push(...await Promise.all(urlPromises)); - } - const buffers = await loadWeightsAsArrayBuffer(fetchURLs, { - requestInit: this.requestInit, - fetchFunc: this.fetch, - onProgress: this.onProgress - }); - return [weightSpecs, concatenateArrayBuffers(buffers)]; - } - } - HTTPRequest.URL_SCHEME_REGEX = /^https?:\/\//; - /** - * Extract the prefix and suffix of the url, where the prefix is the path before - * the last file, and suffix is the search params after the last file. - * ``` - * const url = 'http://tfhub.dev/model/1/tensorflowjs_model.pb?tfjs-format=file' - * [prefix, suffix] = parseUrl(url) - * // prefix = 'http://tfhub.dev/model/1/' - * // suffix = '?tfjs-format=file' - * ``` - * @param url the model url to be parsed. - */ - function parseUrl(url) { - const lastSlash = url.lastIndexOf('/'); - const lastSearchParam = url.lastIndexOf('?'); - const prefix = url.substring(0, lastSlash); - const suffix = lastSearchParam > lastSlash ? url.substring(lastSearchParam) : ''; - return [prefix + '/', suffix]; - } - function isHTTPScheme(url) { - return url.match(HTTPRequest.URL_SCHEME_REGEX) != null; - } - const httpRouter = (url, loadOptions) => { - if (typeof fetch === 'undefined' && - (loadOptions == null || loadOptions.fetchFunc == null)) { - // `http` uses `fetch` or `node-fetch`, if one wants to use it in - // an environment that is not the browser or node they have to setup a - // global fetch polyfill. - return null; - } - else { - let isHTTP = true; - if (Array.isArray(url)) { - isHTTP = url.every(urlItem => isHTTPScheme(urlItem)); - } - else { - isHTTP = isHTTPScheme(url); - } - if (isHTTP) { - return http(url, loadOptions); - } - } - return null; - }; - IORouterRegistry.registerSaveRouter(httpRouter); - IORouterRegistry.registerLoadRouter(httpRouter); - /** - * Creates an IOHandler subtype that sends model artifacts to HTTP server. - * - * An HTTP request of the `multipart/form-data` mime type will be sent to the - * `path` URL. The form data includes artifacts that represent the topology - * and/or weights of the model. In the case of Keras-style `tf.Model`, two - * blobs (files) exist in form-data: - * - A JSON file consisting of `modelTopology` and `weightsManifest`. - * - A binary weights file consisting of the concatenated weight values. - * These files are in the same format as the one generated by - * [tfjs_converter](https://js.tensorflow.org/tutorials/import-keras.html). - * - * The following code snippet exemplifies the client-side code that uses this - * function: - * - * ```js - * const model = tf.sequential(); - * model.add( - * tf.layers.dense({units: 1, inputShape: [100], activation: 'sigmoid'})); - * - * const saveResult = await model.save(tf.io.http( - * 'http://model-server:5000/upload', {requestInit: {method: 'PUT'}})); - * console.log(saveResult); - * ``` - * - * If the default `POST` method is to be used, without any custom parameters - * such as headers, you can simply pass an HTTP or HTTPS URL to `model.save`: - * - * ```js - * const saveResult = await model.save('http://model-server:5000/upload'); - * ``` - * - * The following GitHub Gist - * https://gist.github.com/dsmilkov/1b6046fd6132d7408d5257b0976f7864 - * implements a server based on [flask](https://github.com/pallets/flask) that - * can receive the request. Upon receiving the model artifacts via the requst, - * this particular server reconsistutes instances of [Keras - * Models](https://keras.io/models/model/) in memory. - * - * - * @param path A URL path to the model. - * Can be an absolute HTTP path (e.g., - * 'http://localhost:8000/model-upload)') or a relative path (e.g., - * './model-upload'). - * @param requestInit Request configurations to be used when sending - * HTTP request to server using `fetch`. It can contain fields such as - * `method`, `credentials`, `headers`, `mode`, etc. See - * https://developer.mozilla.org/en-US/docs/Web/API/Request/Request - * for more information. `requestInit` must not have a body, because the - * body will be set by TensorFlow.js. File blobs representing the model - * topology (filename: 'model.json') and the weights of the model (filename: - * 'model.weights.bin') will be appended to the body. If `requestInit` has a - * `body`, an Error will be thrown. - * @param loadOptions Optional configuration for the loading. It includes the - * following fields: - * - weightPathPrefix Optional, this specifies the path prefix for weight - * files, by default this is calculated from the path param. - * - fetchFunc Optional, custom `fetch` function. E.g., in Node.js, - * the `fetch` from node-fetch can be used here. - * - onProgress Optional, progress callback function, fired periodically - * before the load is completed. - * @returns An instance of `IOHandler`. - * - * @doc { - * heading: 'Models', - * subheading: 'Loading', - * namespace: 'io', - * ignoreCI: true - * } - */ - function http(path, loadOptions) { - return new HTTPRequest(path, loadOptions); - } - /** - * Deprecated. Use `tf.io.http`. - * @param path - * @param loadOptions - */ - function browserHTTPRequest(path, loadOptions) { - return http(path, loadOptions); - } - - /** - * @license - * Copyright 2018 Google LLC. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - class PassthroughLoader { - constructor(modelArtifacts) { - this.modelArtifacts = modelArtifacts; - } - async load() { - return this.modelArtifacts; - } - } - class PassthroughSaver { - constructor(saveHandler) { - this.saveHandler = saveHandler; - } - async save(modelArtifacts) { - return this.saveHandler(modelArtifacts); - } - } - /** - * Creates an IOHandler that loads model artifacts from memory. - * - * When used in conjunction with `tf.loadLayersModel`, an instance of - * `tf.LayersModel` (Keras-style) can be constructed from the loaded artifacts. - * - * ```js - * const model = await tf.loadLayersModel(tf.io.fromMemory( - * modelTopology, weightSpecs, weightData)); - * ``` - * - * @param modelArtifacts a object containing model topology (i.e., parsed from - * the JSON format). - * @param weightSpecs An array of `WeightsManifestEntry` objects describing the - * names, shapes, types, and quantization of the weight data. - * @param weightData A single `ArrayBuffer` containing the weight data, - * concatenated in the order described by the weightSpecs. - * @param trainingConfig Model training configuration. Optional. - * - * @returns A passthrough `IOHandler` that simply loads the provided data. - */ - function fromMemory(modelArtifacts, weightSpecs, weightData, trainingConfig) { - if (arguments.length === 1) { - const isModelArtifacts = modelArtifacts.modelTopology != null || - modelArtifacts.weightSpecs != null; - if (isModelArtifacts) { - return new PassthroughLoader(modelArtifacts); - } - else { - // Legacy support: with only modelTopology. - // TODO(cais): Remove this deprecated API. - console.warn('Please call tf.io.fromMemory() with only one argument. ' + - 'The argument should be of type ModelArtifacts. ' + - 'The multi-argument signature of tf.io.fromMemory() has been ' + - 'deprecated and will be removed in a future release.'); - return new PassthroughLoader({ modelTopology: modelArtifacts }); - } - } - else { - // Legacy support. - // TODO(cais): Remove this deprecated API. - console.warn('Please call tf.io.fromMemory() with only one argument. ' + - 'The argument should be of type ModelArtifacts. ' + - 'The multi-argument signature of tf.io.fromMemory() has been ' + - 'deprecated and will be removed in a future release.'); - return new PassthroughLoader({ - modelTopology: modelArtifacts, - weightSpecs, - weightData, - trainingConfig - }); - } - } - /** - * Creates an IOHandler that passes saved model artifacts to a callback. - * - * ```js - * function handleSave(artifacts) { - * // ... do something with the artifacts ... - * return {modelArtifactsInfo: {...}, ...}; - * } - * - * const saveResult = model.save(tf.io.withSaveHandler(handleSave)); - * ``` - * - * @param saveHandler A function that accepts a `ModelArtifacts` and returns a - * `SaveResult`. - */ - function withSaveHandler(saveHandler) { - return new PassthroughSaver(saveHandler); - } - - /** - * @license - * Copyright 2018 Google LLC. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - - var io = /*#__PURE__*/Object.freeze({ - __proto__: null, - browserFiles: browserFiles, - browserHTTPRequest: browserHTTPRequest, - concatenateArrayBuffers: concatenateArrayBuffers, - decodeWeights: decodeWeights, - encodeWeights: encodeWeights, - fromMemory: fromMemory, - getLoadHandlers: getLoadHandlers, - getModelArtifactsInfoForJSON: getModelArtifactsInfoForJSON, - getSaveHandlers: getSaveHandlers, - http: http, - isHTTPScheme: isHTTPScheme, - loadWeights: loadWeights, - registerLoadRouter: registerLoadRouter, - registerSaveRouter: registerSaveRouter, - weightsLoaderFactory: weightsLoaderFactory, - withSaveHandler: withSaveHandler, - copyModel: copyModel, - listModels: listModels, - moveModel: moveModel, - removeModel: removeModel - }); - - /** - * @license - * Copyright 2020 Google LLC. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - /** - * Reshapes a `tf.Tensor` to a given shape. - * - * Given an input tensor, returns a new tensor with the same values as the - * input tensor with shape `shape`. - * - * If one component of shape is the special value -1, the size of that - * dimension is computed so that the total size remains constant. In - * particular, a shape of [-1] flattens into 1-D. At most one component of - * shape can be -1. - * - * If shape is 1-D or higher, then the operation returns a tensor with shape - * shape filled with the values of tensor. In this case, the number of - * elements implied by shape must be the same as the number of elements in - * tensor. - * - * ```js - * const x = tf.tensor1d([1, 2, 3, 4]); - * x.reshape([2, 2]).print(); - * ``` - * - * @param x The input tensor to be reshaped. - * @param shape An array of integers defining the output tensor shape. - * - * @doc {heading: 'Tensors', subheading: 'Transformations'} - */ - function reshape_(x, shape) { - const $x = convertToTensor(x, 'x', 'reshape', null); - const inputs = { x: $x }; - const attrs = { shape }; - const forward = (backend, save) => { - shape = inferFromImplicitShape(shape, $x.size); - assert($x.size === sizeFromShape(shape), () => 'new shape and old shape must have the same number of elements.'); - save([$x]); - return backend.reshape($x, shape); - }; - return ENGINE.runKernelFunc(forward, inputs, null /* grad */, Reshape, attrs); - } - const reshape = op({ reshape_ }); - - /** - * @license - * Copyright 2020 Google LLC. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - /** - * Computes the dot product of two matrices, A * B. These must be matrices. - * - * ```js - * const a = tf.tensor2d([1, 2], [1, 2]); - * const b = tf.tensor2d([1, 2, 3, 4], [2, 2]); - * - * a.matMul(b).print(); // or tf.matMul(a, b) - * ``` - * @param a First matrix in dot product operation. - * @param b Second matrix in dot product operation. - * @param transposeA If true, `a` is transposed before multiplication. - * @param transposeB If true, `b` is transposed before multiplication. - * - * @doc {heading: 'Operations', subheading: 'Matrices'} - */ - function matMul_(a, b, transposeA = false, transposeB = false) { - let $a = convertToTensor(a, 'a', 'matMul'); - let $b = convertToTensor(b, 'b', 'matMul'); - [$a, $b] = makeTypesMatch($a, $b); - assert($a.rank >= 2 && $b.rank >= 2 && $a.rank === $b.rank, () => `Error in matMul: inputs must have the same rank of at least 2, ` + - `got ranks ${$a.rank} and ${$b.rank}.`); - const forward = (backend, save) => { - save([$a, $b]); - const innerShapeA = transposeA ? $a.shape[$a.rank - 2] : $a.shape[$a.rank - 1]; - const innerShapeB = transposeB ? $b.shape[$b.rank - 1] : $b.shape[$b.rank - 2]; - const outerShapeA = transposeA ? $a.shape[$a.rank - 1] : $a.shape[$a.rank - 2]; - const outerShapeB = transposeB ? $b.shape[$b.rank - 2] : $b.shape[$b.rank - 1]; - const outerDimsA = $a.shape.slice(0, -2); - const outerDimsB = $b.shape.slice(0, -2); - assert(arraysEqual(outerDimsA, outerDimsB), () => `Error in matMul: outer dimensions (${outerDimsA}) and (` + - `${outerDimsB}) of Tensors with shapes ${$a.shape} and ` + - `${$b.shape} must match.`); - assert(innerShapeA === innerShapeB, () => `Error in matMul: inner shapes (${innerShapeA}) and (` + - `${innerShapeB}) of Tensors with shapes ${$a.shape} and ` + - `${$b.shape} and transposeA=${transposeA}` + - ` and transposeB=${transposeB} must match.`); - const outShape = $a.shape.slice(0, -2).concat([outerShapeA, outerShapeB]); - const batchDimA = sizeFromShape(outerDimsA); - const batchDimB = sizeFromShape(outerDimsB); - const a3D = transposeA ? - reshape($a, [batchDimA, innerShapeA, outerShapeA]) : - reshape($a, [batchDimA, outerShapeA, innerShapeA]); - const b3D = transposeB ? - reshape($b, [batchDimB, outerShapeB, innerShapeB]) : - reshape($b, [batchDimB, innerShapeB, outerShapeB]); - const res3d = backend.batchMatMul(a3D, b3D, transposeA, transposeB); - return reshape(res3d, outShape); - }; - const inputs = { a: $a, b: $b }; - const attrs = { transposeA, transposeB }; - return ENGINE.runKernelFunc(forward, inputs, null /* grad */, BatchMatMul, attrs); - } - const matMul = op({ matMul_ }); - - /** - * @license - * Copyright 2020 Google LLC. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - /** - * Creates a one-hot `tf.Tensor`. The locations represented by `indices` take - * value `onValue` (defaults to 1), while all other locations take value - * `offValue` (defaults to 0). If `indices` is rank `R`, the output has rank - * `R+1` with the last axis of size `depth`. - * - * ```js - * tf.oneHot(tf.tensor1d([0, 1], 'int32'), 3).print(); - * ``` - * - * @param indices `tf.Tensor` of indices with dtype `int32`. - * @param depth The depth of the one hot dimension. - * @param onValue A number used to fill in the output when the index matches - * the location. - * @param offValue A number used to fill in the output when the index does - * not match the location. - * - * @doc {heading: 'Tensors', subheading: 'Creation'} - */ - function oneHot_(indices, depth, onValue = 1, offValue = 0) { - if (depth < 2) { - throw new Error(`Error in oneHot: depth must be >=2, but it is ${depth}`); - } - const $indices = convertToTensor(indices, 'indices', 'oneHot', 'int32'); - const outShape = [...$indices.shape, depth]; - const forward = (backend, save) => { - save([$indices]); - return reshape(backend.oneHot(reshape($indices, [$indices.size]), depth, onValue, offValue), outShape); - }; - const inputs = { indices: $indices }; - const attrs = { depth, onValue, offValue }; - return ENGINE.runKernelFunc(forward, inputs, null /* grad */, OneHot, attrs); - } - const oneHot = op({ oneHot_ }); - - /** - * @license - * Copyright 2018 Google LLC. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - /** - * Transposes the `tf.Tensor`. Permutes the dimensions according to `perm`. - * - * The returned `tf.Tensor`'s dimension `i` will correspond to the input - * dimension `perm[i]`. If `perm` is not given, it is set to `[n-1...0]`, - * where `n` is the rank of the input `tf.Tensor`. Hence by default, this - * operation performs a regular matrix transpose on 2-D input `tf.Tensor`s. - * - * ```js - * const a = tf.tensor2d([1, 2, 3, 4, 5, 6], [2, 3]); - * - * a.transpose().print(); // or tf.transpose(a) - * ``` - * - * @param x The tensor to transpose. - * @param perm The permutation of the dimensions of a. - * - * @doc {heading: 'Operations', subheading: 'Matrices'} - */ - function transpose_(x, perm) { - const $x = convertToTensor(x, 'x', 'transpose'); - if (perm == null) { - perm = $x.shape.map((s, i) => i).reverse(); - } - assert($x.rank === perm.length, () => `Error in transpose: rank of input ${$x.rank} ` + - `must match length of perm ${perm}.`); - perm.forEach(axis => { - assert(axis >= 0 && axis < $x.rank, () => `All entries in 'perm' must be between 0 and ${$x.rank - 1}` + - ` but got ${perm}`); - }); - if ($x.rank <= 1) { - return $x.clone(); - } - const inputs = { x: $x }; - const attrs = { perm }; - return ENGINE.runKernelFunc(backend => backend.transpose($x, perm), inputs, null /* gradient */, Transpose, attrs); - } - const transpose = op({ transpose_ }); - - /** - * @license - * Copyright 2018 Google LLC. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - /** - * Computes the confusion matrix from true labels and predicted labels. - * - * ```js - * const labels = tf.tensor1d([0, 1, 2, 1, 0], 'int32'); - * const predictions = tf.tensor1d([0, 2, 2, 1, 0], 'int32'); - * const numClasses = 3; - * const out = tf.math.confusionMatrix(labels, predictions, numClasses); - * out.print(); - * // Expected output matrix: - * // [[2, 0, 0], - * // [0, 1, 1], - * // [0, 0, 1]] - * ``` - * - * @param labels The target labels, assumed to be 0-based integers - * for the classes. The shape is `[numExamples]`, where - * `numExamples` is the number of examples included. - * @param predictions The predicted classes, assumed to be - * 0-based integers for the classes. Must have the same shape as `labels`. - * @param numClasses Number of all classes, as an integer. - * Its value must be larger than the largest element in `labels` and - * `predictions`. - * @returns The confusion matrix as a int32-type 2D tensor. The value at - * row `r` and column `c` is the number of times examples of actual class - * `r` were predicted as class `c`. - * - * @doc {heading: 'Operations', subheading: 'Evaluation'} - */ - function confusionMatrix_(labels, predictions, numClasses) { - const $labels = convertToTensor(labels, 'labels', 'confusionMatrix'); - const $predictions = convertToTensor(predictions, 'predictions', 'confusionMatrix'); - assert(numClasses == null || numClasses > 0 && Number.isInteger(numClasses), () => `If provided, numClasses must be a positive integer, ` + - `but got ${numClasses}`); - assert($labels.rank === 1, () => `Expected the rank of labels to be 1, but got ${$labels.rank}`); - assert($predictions.rank === 1, () => `Expected the rank of predictions to be 1, ` + - `but got ${$predictions.rank}`); - assert($labels.shape[0] === $predictions.shape[0], () => `Mismatch in the number of examples: ` + - `${$labels.shape[0]} vs. ${$predictions.shape[0]}. ` + - `Labels and predictions should have the same number of elements.`); - assert(numClasses > 0 && Number.isInteger(numClasses), () => `numClasses is required to be a positive integer, but got ` + - `${numClasses}`); - // TODO(cais): In the future, if oneHot supports tensors inputs for - // `numClasses`, `confusionMatrix` can make `numClasses` optional. - const oneHotLabels = oneHot(cast($labels, 'int32'), numClasses); - const oneHotPredictions = oneHot(cast($predictions, 'int32'), numClasses); - const oneHotLabelsT = transpose(oneHotLabels); - return cast(matMul(oneHotLabelsT, oneHotPredictions), 'int32'); - } - const confusionMatrix = op({ confusionMatrix_ }); - - /** - * @license - * Copyright 2018 Google LLC. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - - var math = /*#__PURE__*/Object.freeze({ - __proto__: null, - confusionMatrix: confusionMatrix - }); - - /** - * @license - * Copyright 2018 Google LLC. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - /** - * Creates rank-3 `tf.Tensor` with the provided values, shape and dtype. - * - * The same functionality can be achieved with `tf.tensor`, but in general - * we recommend using `tf.tensor3d` as it makes the code more readable. - * - * ```js - * // Pass a nested array. - * tf.tensor3d([[[1], [2]], [[3], [4]]]).print(); - * ``` - * ```js - * // Pass a flat array and specify a shape. - * tf.tensor3d([1, 2, 3, 4], [2, 2, 1]).print(); - * ``` - * - * @param values The values of the tensor. Can be nested array of numbers, - * or a flat array, or a `TypedArray`. - * @param shape The shape of the tensor. If not provided, it is inferred from - * `values`. - * @param dtype The data type. - * - * @doc {heading: 'Tensors', subheading: 'Creation'} - */ - function tensor3d(values, shape, dtype) { - assertNonNull(values); - if (shape != null && shape.length !== 3) { - throw new Error('tensor3d() requires shape to have three numbers'); - } - const inferredShape = inferShape(values, dtype); - if (inferredShape.length !== 3 && inferredShape.length !== 1) { - throw new Error('tensor3d() requires values to be number[][][] or flat/TypedArray'); - } - if (inferredShape.length === 1 && shape == null) { - throw new Error('tensor3d() requires shape to be provided when `values` ' + - 'are a flat array'); - } - return makeTensor(values, shape, inferredShape, dtype); - } - - /** - * @license - * Copyright 2019 Google LLC. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - let fromPixels2DContext; - /** - * Creates a `tf.Tensor` from an image. - * - * ```js - * const image = new ImageData(1, 1); - * image.data[0] = 100; - * image.data[1] = 150; - * image.data[2] = 200; - * image.data[3] = 255; - * - * tf.browser.fromPixels(image).print(); - * ``` - * - * @param pixels The input image to construct the tensor from. The - * supported image types are all 4-channel. You can also pass in an image - * object with following attributes: - * `{data: Uint8Array; width: number; height: number}` - * @param numChannels The number of channels of the output tensor. A - * numChannels value less than 4 allows you to ignore channels. Defaults to - * 3 (ignores alpha channel of input image). - * - * @doc {heading: 'Browser', namespace: 'browser', ignoreCI: true} - */ - function fromPixels_(pixels, numChannels = 3) { - // Sanity checks. - if (numChannels > 4) { - throw new Error('Cannot construct Tensor with more than 4 channels from pixels.'); - } - if (pixels == null) { - throw new Error('pixels passed to tf.browser.fromPixels() can not be null'); - } - let isPixelData = false; - let isImageData = false; - let isVideo = false; - let isImage = false; - let isCanvasLike = false; - if (pixels.data instanceof Uint8Array) { - isPixelData = true; - } - else if (typeof (ImageData) !== 'undefined' && pixels instanceof ImageData) { - isImageData = true; - } - else if (typeof (HTMLVideoElement) !== 'undefined' && - pixels instanceof HTMLVideoElement) { - isVideo = true; - } - else if (typeof (HTMLImageElement) !== 'undefined' && - pixels instanceof HTMLImageElement) { - isImage = true; - // tslint:disable-next-line: no-any - } - else if (pixels.getContext != null) { - isCanvasLike = true; - } - else { - throw new Error('pixels passed to tf.browser.fromPixels() must be either an ' + - `HTMLVideoElement, HTMLImageElement, HTMLCanvasElement, ImageData ` + - `in browser, or OffscreenCanvas, ImageData in webworker` + - ` or {data: Uint32Array, width: number, height: number}, ` + - `but was ${pixels.constructor.name}`); - } - if (isVideo) { - const HAVE_CURRENT_DATA_READY_STATE = 2; - if (isVideo && - pixels.readyState < - HAVE_CURRENT_DATA_READY_STATE) { - throw new Error('The video element has not loaded data yet. Please wait for ' + - '`loadeddata` event on the